Sparse Autoencoders
Architecture // Training
L1 Regularization
Top K Sparsity
L0 Regularization
Hardware // Scalability
Interpretability Metrics
Circuit Analysis
Challenges
References
- Critical for interpreting model states (Why did GPT-4 output X?).
Intermediate activations (MLP layers) encode features in superposition (Elhage et al.).
- SAEs disentangle these via sparse bottlenecks.
Decoding internal states allows auditing and steering models (Olah et al.) -> important for safety
Input (d-dim) → Encoder (W_e: d×m) → ReLU → Latent (m-dim, sparse) → Decoder (W_d: m×d) → Output
ReLU induces sparsity by zeroing negatives (Nair & Hinton)
- Leaky ReLU -> gradients stay non-zero for negative inputs (fewer dead latents), but no exact zeros, so the code is less sparse.
Loss Function
L = ||x - x̂||² + λ⋅||h||₁ + γ⋅||W_e||²
Reconstruction term: MSE // cosine similarity to preserve feature direction (Bricken et al.)
- Decoder Weights -> often tied (W_d = W_eᵀ) to reduce parameters (Rumelhart et al.)
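A minimal PyTorch sketch of this setup, assuming the tied-weight variant above; the class and argument names (TiedSAE, lam, d_model, n_latents) are illustrative, and the γ⋅||W_e||² term is left to the optimizer's weight decay.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedSAE(nn.Module):
    """Sparse autoencoder with tied decoder weights (W_d = W_e^T)."""
    def __init__(self, d_model: int, n_latents: int):
        super().__init__()
        self.W_e = nn.Parameter(torch.randn(n_latents, d_model) * 0.01)  # encoder rows double as decoder columns
        self.b_e = nn.Parameter(torch.zeros(n_latents))
        self.b_d = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        h = F.relu(x @ self.W_e.T + self.b_e)   # sparse latent code (m-dim); ReLU zeroes negatives
        x_hat = h @ self.W_e + self.b_d         # tied decoder maps back to d-dim
        return x_hat, h

def sae_loss(x, x_hat, h, lam=1e-3):
    """L = ||x - x_hat||^2 + lam * ||h||_1; the encoder weight penalty is
    handled by the optimizer's weight_decay."""
    mse = (x - x_hat).pow(2).sum(dim=-1).mean()
    l1 = h.abs().sum(dim=-1).mean()
    return mse + lam * l1
```

Untied decoders (a separate W_d) are also common; tying simply halves the parameter count, as noted above.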
Adds λ⋅||h||₁ (the absolute latent activations) to the loss function.
Penalizes non-zero activations; many latents end up firing rarely across all inputs (global sparsity).
But:
- Pulls all activations toward zero, including important ones
- Requires tuning λ (tradeoff: sparsity vs. reconstruction loss)
- Latents permanently "off" due to over-penalization -> dead neurons (detection sketch below)
- L1 struggles to resolve this, leading to polysemantic latents
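A hedged sketch of how dead latents can be surfaced in practice: count how often each latent fires over a window of batches and flag the ones that never do. The threshold, window, and helper names are assumptions, and `sae` refers to the earlier sketch.

```python
import torch

@torch.no_grad()
def update_fire_counts(fire_counts: torch.Tensor, h: torch.Tensor, eps: float = 1e-8):
    """Accumulate, per latent, how many inputs in this batch activated it."""
    fire_counts += (h.abs() > eps).sum(dim=0)
    return fire_counts

# Usage sketch: after a window of batches, latents that never fired are "dead"
# and can be re-initialized (e.g., resampled toward poorly reconstructed inputs).
# fire_counts = torch.zeros(n_latents)
# for x in activation_batches:
#     _, h = sae(x)
#     fire_counts = update_fire_counts(fire_counts, h)
# dead = (fire_counts == 0).nonzero(as_tuple=True)[0]
```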
Force only the K largest latents to stay active; zero the others. Active latents retain their original magnitudes.
Gradients can be passed through the hard mask with a straight-through estimator on the backward pass (see the sketch below).
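A minimal sketch of a top-K activation with a straight-through estimator; the function name and the exact STE formulation (hard mask on the forward pass, identity gradient on the backward pass) are illustrative assumptions rather than a specific published implementation.

```python
import torch

def topk_ste(h: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest latents per example, zero the rest.
    Forward: hard top-k mask (active latents keep their magnitudes).
    Backward: straight-through, gradients flow as if no mask were applied."""
    _, idx = h.topk(k, dim=-1)
    mask = torch.zeros_like(h).scatter_(-1, idx, 1.0)
    h_hard = h * mask
    return h + (h_hard - h).detach()   # value = h_hard, gradient w.r.t. h = identity
```

Without the STE (just returning h * mask), gradients only reach the selected latents; the STE additionally gives inactive latents a training signal.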
Set K directly (e.g., 256 active/16M latents)
Lower reconstruction error at same sparsity level vs L1
Higher monosemanticity
Penalize the number of non-zero activations
Direct sparsity control
L0 is non-differentiable, so it needs a stochastic relaxation (hard-concrete gates; Louizos et al.) → noisy gradients
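A sketch of the hard-concrete gate from Louizos et al. (2018), the usual stochastic relaxation behind differentiable L0 penalties; the default hyperparameters follow the paper, but wiring the gate into an SAE (one gate per latent, shared across the batch) is an assumption here.

```python
import math
import torch
import torch.nn as nn

class HardConcreteGate(nn.Module):
    """Relaxed Bernoulli gates with exact zeros (Louizos et al., 2018)."""
    def __init__(self, n_latents: int, beta: float = 2/3, gamma: float = -0.1, zeta: float = 1.1):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.zeros(n_latents))   # per-latent gate logits
        self.beta, self.gamma, self.zeta = beta, gamma, zeta

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if self.training:
            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
            s = torch.sigmoid((u.log() - (1 - u).log() + self.log_alpha) / self.beta)
        else:
            s = torch.sigmoid(self.log_alpha)
        z = (s * (self.zeta - self.gamma) + self.gamma).clamp(0.0, 1.0)  # stretch + clip -> exact zeros
        return h * z

    def expected_l0(self) -> torch.Tensor:
        """Differentiable surrogate for the expected number of open gates (the L0 penalty)."""
        return torch.sigmoid(self.log_alpha - self.beta * math.log(-self.gamma / self.zeta)).sum()
```

The expected_l0() term is added to the loss with its own coefficient; the Monte Carlo sampling of u is what makes the gradients noisy.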
TPUs/GPUs are built for dense matmuls -> unstructured sparsity is a poor fit
Sparse matrix ops (block-sparse kernels) reduce compute/memory
16M-latent SAE on GPT-4 activations → ~1% active latents per input → only ~160K non-zero latents (decoder rows) per token; feasible on TPUv4.
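A sketch of where the savings come from: with a sparse code you only gather and combine the decoder rows of the active latents instead of doing the full m×d matmul. The gather-based function below is illustrative, not a tuned block-sparse kernel (cf. Gray et al.).

```python
import torch

def sparse_decode(h: torch.Tensor, W_d: torch.Tensor, b_d: torch.Tensor) -> torch.Tensor:
    """Reconstruct x_hat for one example using only its non-zero latents.
    h: (m,) sparse code, W_d: (m, d) decoder, b_d: (d,) bias.
    Cost scales with the number of active latents (~K), not with m."""
    active = h.nonzero(as_tuple=True)[0]   # indices of active latents
    return h[active] @ W_d[active] + b_d   # (K,) @ (K, d) -> (d,)
```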
Scale matters because
- Larger SAEs capture long-tail features (rare programming syntax, niche cultural concepts, obscure APIs).
- Scaling laws for interpretability? Early results suggest SAE performance improves predictably with size (Anthropic; cf. Kaplan et al.)
Diversity Score: Number of unique features per latent (Goh).
Importance Analysis: Ablate latent → measure downstream effect (Wang); sketch below.
Top-K SAEs achieve ~60% monosemanticity vs ~30% for L1 (Bricken) - a striking gap
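A hedged sketch of ablation-based importance: zero one latent in the code, decode, and measure how much worse the reconstruction gets. Reconstruction error is used here as a cheap proxy for the downstream effect; patching the reconstruction back into the model is sketched later. `sae` refers to the earlier tied-weight sketch.

```python
import torch

@torch.no_grad()
def latent_importance(sae, x: torch.Tensor, latent_idx: int) -> torch.Tensor:
    """Increase in reconstruction error when one latent is ablated (zeroed)."""
    x_hat, h = sae(x)
    base_err = (x - x_hat).pow(2).sum(dim=-1).mean()

    h_ablated = h.clone()
    h_ablated[:, latent_idx] = 0.0                 # ablate the latent for every input
    x_hat_ablated = h_ablated @ sae.W_e + sae.b_d  # decode with the tied weights
    ablated_err = (x - x_hat_ablated).pow(2).sum(dim=-1).mean()
    return ablated_err - base_err
```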
SAE + Attention Heads -> map latents to the attention heads driving specific behaviors (e.g., citation formatting)
Detect "backdoor" latents activated by adversarial prompts (Ilyas).
Linking latents to model outputs is NP-hard in general (the space of latent combinations to test grows combinatorially)
Tools
- Causal Scrubbing: Edit latents → measure output changes (Wang); sketch below.
- Shapley Values: Quantify each latent's marginal contribution (Lundberg & Lee)
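A heavily hedged sketch of the "edit latents → measure output changes" idea: substitute the SAE reconstruction (with one latent zeroed) into a layer's forward pass via a hook and measure the KL divergence of the output logits. The HuggingFace-style `model(...).logits` interface and the `layer_module` handle are assumptions; this is not any particular library's causal-scrubbing API.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def latent_effect_on_logits(model, sae, input_ids, layer_module, latent_idx):
    """KL divergence of the model's logits after ablating one SAE latent in-place.
    Assumes layer_module returns a plain (batch, seq, d_model) tensor, e.g. an MLP block."""
    def patch_hook(module, inputs, output):
        _, h = sae(output)                 # encode the layer's activation
        h[:, :, latent_idx] = 0.0          # ablate one latent
        return h @ sae.W_e + sae.b_d       # decode and substitute for the output

    clean_logits = model(input_ids).logits

    handle = layer_module.register_forward_hook(patch_hook)
    patched_logits = model(input_ids).logits
    handle.remove()

    return F.kl_div(F.log_softmax(patched_logits, dim=-1),
                    F.softmax(clean_logits, dim=-1),
                    reduction="batchmean")
```

For a cleaner comparison, the baseline run can also substitute the un-ablated reconstruction, so the KL isolates the ablation rather than the SAE's reconstruction error.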
Initialization Matters
Orthogonal decoder weights prevent feature collapse (Saxe)
Avoiding local minima -> curriculum learning (start dense, increase sparsity).
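A short sketch of both points, assuming the TiedSAE from earlier: orthogonal initialization of the shared weight matrix, plus a simple linear warm-up of the L1 coefficient so training starts dense and sparsifies gradually. The schedule shape and constants are assumptions.

```python
import torch.nn as nn

def init_sae(sae) -> None:
    """Orthogonal rows for the tied weight matrix to discourage feature collapse."""
    nn.init.orthogonal_(sae.W_e)

def l1_schedule(step: int, warmup_steps: int = 10_000, lam_max: float = 1e-3) -> float:
    """Curriculum on sparsity: ramp the L1 penalty from 0 up to lam_max."""
    return lam_max * min(1.0, step / warmup_steps)
```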
- Hinton & Salakhutdinov (2006) - Reducing dimensionality with autoencoders.
- Tibshirani (1996) - L1 regularization (Lasso).
- Bricken et al. (2023) - Towards monosemanticity: decomposing LM activations with SAEs/dictionary learning.
- Cunningham et al. (2023) - Sparse autoencoders find highly interpretable features in language models.
- Elhage et al. (2022) - Superposition in neural networks.
- Louizos et al. (2018) - L0 regularization.
- Olah et al. (2020) - Circuits in transformers.
- Lundberg & Lee (2017) - SHAP values for interpretability.
- Jouppi et al. (2017) - TPU architecture.
- Gray et al. (2017) - Block-sparse matrix ops.