-
Problem
- For standard language model training, we don't need to backprop through sampling at all. We use cross-entropy loss and backprop through the softmax (the model outputs logits over the entire vocabulary, so no sampled token sits inside the graph).
- But sometimes you have “hard” routing/sampling, such as with sparse MoE, discrete/VQ-VAEs, hard attention, or neural architecture search.
- There we need discrete decisions IN THE MIDDLE of a computation graph that we still want to train end to end.
- Side note: the policy gradient theorem doesn’t need this, since the score-function (REINFORCE) estimator only requires ∇ log p of the sampled action, not a gradient through the sample itself.
-
Straight-through estimator
-
Problem: In networks with discrete operations (like quantization or argmax), the gradient is zero almost everywhere and undefined at the jumps. Backpropagation therefore delivers no useful signal to anything upstream of the operation.
-
Idea: use a proxy gradient to let signal backpropagate.
-
Forward pass: use the actual discrete operation.
Backward pass: pretend the operation was the identity function.
-
Mathematical Formulation:
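Suppose a layer applies a non-differentiable operation g(x) (e.g., rounding or a step function) whose true derivative is zero almost everywhere.
- Forward: y = g(x)
- Backward: replace ∂g(x)/∂x with the derivative of the identity, i.e. use ∂y/∂x ≈ 1, so ∂L/∂x ≈ ∂L/∂y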
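The formulation above can also be written without a custom autograd class, using the common "detach trick" y = x + sg(g(x) - x), where sg is stop-gradient (`.detach()` in PyTorch). The forward value is exactly g(x), but autograd only differentiates the leading x term. A minimal sketch, assuming PyTorch and using `torch.round` as the example g; the helper name `ste` is hypothetical:

```python
import torch

def ste(x: torch.Tensor, g) -> torch.Tensor:
    """Forward: g(x). Backward: identity, i.e. dy/dx = 1 (hypothetical helper)."""
    # (g(x) - x).detach() contributes its value but no gradient,
    # so the only differentiable path is the leading `x` term.
    return x + (g(x) - x).detach()

x = torch.tensor([0.3, 1.7, -2.4], requires_grad=True)

# True gradient: PyTorch defines d(round)/dx as 0 everywhere.
torch.round(x).sum().backward()
true_grad = x.grad.clone()
x.grad = None

# STE gradient: the backward pass sees the identity, so every entry gets 1.
y = ste(x, torch.round)
y.sum().backward()
ste_grad = x.grad.clone()

print(y.detach(), true_grad, ste_grad)
```

The forward output matches `torch.round(x)`, while the gradient switches from all zeros to all ones.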
-
In VQ-VAE Context:
- Forward: z_q = e_{k*}, where k* = argmin_k ||z_e - e_k||₂
(find the nearest codebook vector)
- Backward: treat z_q as if it were equal to z_e, i.e. copy the gradient ∂L/∂z_q directly to z_e
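A minimal sketch of the nearest-codebook lookup with the straight-through copy, assuming PyTorch, a flat (N, D) batch of encoder outputs, and a randomly initialized codebook. `quantize_ste` is a hypothetical helper name, and the squared-sum loss is only a stand-in for a decoder loss:

```python
import torch

def quantize_ste(z_e: torch.Tensor, codebook: torch.Tensor):
    """z_e: (N, D) encoder outputs; codebook: (K, D) vectors e_k."""
    # Distance from every z_e to every codebook vector: shape (N, K)
    dists = torch.cdist(z_e, codebook)
    k_star = dists.argmin(dim=1)           # index of the nearest code (non-differentiable)
    z_q = codebook[k_star]                 # z_q = e_{k*}
    # Straight-through: forward value is z_q, gradient w.r.t. z_e is copied from z_q
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, k_star

torch.manual_seed(0)
z_e = torch.randn(4, 2, requires_grad=True)
codebook = torch.randn(3, 2)

z_q, idx = quantize_ste(z_e, codebook)
loss = (z_q ** 2).sum()                    # stand-in for a decoder/reconstruction loss
loss.backward()

# dL/dz_q = 2 * z_q, and the STE hands exactly that to the encoder output.
print(torch.allclose(z_e.grad, 2 * z_q.detach()))
```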
-
Torch:
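import torch

class StraightThroughEstimator(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward: hard binarization (step function)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: identity, pass the upstream gradient through unchanged
        return grad_output

# Usage: y = StraightThroughEstimator.apply(x)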
-
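Applications: VQ-VAEs, Switch Transformer (hard MoE routing; strictly, its router learns through the chosen expert's gate probability rather than a literal identity backward pass)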
-
Reparameterization trick
- A way to make non-differentiable sampling differentiable wrt some params, by expressing a random variable as a deterministic function of those params and some noise
- Example: instead of sampling x ~ N(μ, σ²) directly, sample ε ~ N(0,1) and compute x = μ + σε. Gradients now flow to μ and σ, so we can still learn them.
- Why it works:
- The randomness is moved to an input (ε) that doesn't depend on the network parameters.
- This trick allows gradients to flow through the sampling process.
- The sample x is still distributed according to N(μ, σ²), but now it's a differentiable function of μ and σ.
- Applications: VAEs
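A minimal sketch of the trick, assuming PyTorch. σ is parameterized as exp(log_sigma) to keep it positive (a common choice, not required by the trick), and the toy objective E[x²] = μ² + σ² is picked only because its gradients are known in closed form: ∂/∂μ = 2μ and ∂/∂log σ = 2σ². The Monte Carlo gradients taken through the samples should land close to those values:

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)    # sigma = exp(0) = 1

def sample(n: int) -> torch.Tensor:
    eps = torch.randn(n)              # the randomness: independent of mu and sigma
    sigma = log_sigma.exp()
    return mu + sigma * eps           # x ~ N(mu, sigma^2), differentiable in mu and sigma

x = sample(200_000)
loss = (x ** 2).mean()                # Monte Carlo estimate of E[x^2] = mu^2 + sigma^2
loss.backward()

# Analytic values for comparison: 2*mu = 3.0 and 2*sigma**2 = 2.0
print(mu.grad.item(), log_sigma.grad.item())
```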
-
Gumbel-softmax reparameterization trick
- The starting point is the Gumbel-Max trick, which draws an exact sample from a categorical distribution: $z = \text{onehot}(\arg\max_i[g_i + \log \pi_i])$, where $g_i$ are i.i.d. samples from Gumbel(0,1) and $\pi_i$ are the categorical probabilities. The argmax is what makes it non-differentiable.
-
E.g., here’s how to sample from a categorical using just an argmax operation!
import numpy as np

log_probs = np.log([0.7, 0.2, 0.1])
gumbel_noise = -np.log(-np.log(np.random.uniform(size=3)))  # Gumbel(0, 1) samples
sample = np.argmax(gumbel_noise + log_probs)  # 0 w.p. 0.7, 1 w.p. 0.2, 2 w.p. 0.1
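A quick empirical check that Gumbel-Max really samples from the intended categorical, assuming NumPy's `default_rng` with an arbitrary seed. All 100,000 draws are computed at once: each row gets independent noise and a single argmax, with no per-sample loop over buckets:

```python
import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.7, 0.2, 0.1])
log_probs = np.log(probs)

n = 100_000
u = rng.uniform(size=(n, 3))                  # one row of uniforms per sample
gumbel_noise = -np.log(-np.log(u))            # Gumbel(0, 1) noise
samples = np.argmax(gumbel_noise + log_probs, axis=1)

freqs = np.bincount(samples, minlength=3) / n
print(freqs)                                  # close to [0.7, 0.2, 0.1]
```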
-
Why is this useful instead of just generating a random number in [0,1] and seeing which bucket it falls in?
-
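By itself this is not super interesting, though it does offer parallelization: every category gets independent noise, and a single argmax picks the winner. The cumulative method has to build a running sum over the buckets and find where the random number lands, which is naturally sequential.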
-
But the big one is that it leads naturally to the Gumbel-Softmax relaxation (next), which enables differentiable computation.
- Combine above ideas to create a differentiable approximation of categorical sampling:
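- Start with logits $\log \pi_i$ (unnormalized log-probabilities also work, since softmax and argmax are shift-invariant)
- Add Gumbel noise to the logits: $z_i = \log \pi_i + g_i$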
- Apply softmax with temperature: $y_i = \frac{\exp(z_i/\tau)}{\sum_j \exp(z_j/\tau)}$
- Result: an "almost one-hot" sample (e.g., an expert selection in MoE) that's differentiable
- This replaces the argmax of Gumbel-Max with a temperature-scaled softmax.
- Properties
- As $\tau \to 0$, this approaches a discrete one-hot sample
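- As $\tau \to \infty$, the output approaches the uniform vector $(1/K, \dots, 1/K)$ over the $K$ categories, regardless of the logits
- Most importantly, the output is differentiable with respect to the logits for any $\tau > 0$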
- Applications: discrete latent VAEs, sparse MoE?
- Code
import torch

def gumbel_softmax(logits, temperature=1.0, eps=1e-10):
    # Sample Gumbel(0, 1) noise; eps guards against log(0)
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + eps) + eps)
    # Add noise to logits and apply a temperature-scaled softmax
    return torch.softmax((logits + g) / temperature, dim=-1)
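A minimal sketch, assuming PyTorch, that shows the temperature effect and the common "hard" variant, which combines the Gumbel-softmax with the straight-through estimator: the forward pass emits an exact one-hot argmax, while gradients come from the soft sample. `gumbel_softmax_st` is a hypothetical name; PyTorch also ships a built-in `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=...)` with this behavior.

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_st(logits, tau=1.0, hard=False, eps=1e-10):
    """Gumbel-softmax over the last dim; hard=True gives a straight-through one-hot."""
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + eps) + eps)          # Gumbel(0, 1) noise
    y_soft = F.softmax((logits + g) / tau, dim=-1)
    if not hard:
        return y_soft
    # Forward value: one-hot argmax. Gradient: that of y_soft (the STE trick).
    index = y_soft.argmax(dim=-1)
    y_hard = F.one_hot(index, num_classes=logits.shape[-1]).to(y_soft.dtype)
    return y_hard - y_soft.detach() + y_soft

torch.manual_seed(0)
# 1000 copies of the same 3-way categorical, so averages are meaningful
logits = torch.log(torch.tensor([0.7, 0.2, 0.1])).repeat(1000, 1).requires_grad_()

# Average size of the largest entry: closer to 1 means closer to one-hot
max_entry = {}
for tau in (5.0, 1.0, 0.1):
    with torch.no_grad():
        max_entry[tau] = gumbel_softmax_st(logits, tau=tau).max(dim=-1).values.mean().item()
print(max_entry)                                    # grows toward 1.0 as tau shrinks

y = gumbel_softmax_st(logits, tau=0.5, hard=True)
loss = (y * torch.arange(3.0)).sum()                # arbitrary downstream loss
loss.backward()
print(y[:3])                                        # one-hot rows (up to float rounding)
print(logits.grad.abs().sum().item() > 0)           # gradients still reach the logits
```

The hard variant is one way to get "hard" samples while still training end to end.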
-
Applications
- Sparse MoE
- Need to choose top k experts
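- Switch Transformer routes each token to its top-1 expert and scales that expert's output by the router probability; the argmax is not differentiated, so the router learns through this gate value rather than through a literal STE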
- Standard Continuous VAE:
- Latent space: Continuous, typically using multivariate normal distributions
- Sampling method: Reparameterization trick for normal distributions
z = μ + σ ⊙ ε, where ε ~ N(0, I)
- Training: End-to-end with backpropagation
- Gradients: Flow smoothly through the sampling process
- Use case: General-purpose image generation, representation learning
- Discrete Latent VAE with Gumbel-Softmax:
- Latent space: Discrete, but made differentiable (“smoothed”) via the Gumbel-Softmax temperature
- Sampling method: Gumbel-Softmax trick
y_i = exp((log(π_i) + g_i) / τ) / Σ_j exp((log(π_j) + g_j) / τ)
where g_i are Gumbel(0,1) samples and τ is temperature
- Training: End-to-end with backpropagation
- Gradients: Flow through a differentiable approximation of discrete sampling
- Use case: Tasks benefiting from discrete representations, e.g., language modeling
- The discreteness is "soft" during training and can be made "hard" during inference.
- VQ-VAE (Vector Quantized-Variational AutoEncoder):
- Latent space: Strictly discrete, using a codebook of embeddings (“vector quantization in the latent space”)
- Sampling method: Nearest neighbor lookup in the codebook
z_q = e_{k*}, where k* = argmin_k ||z_e - e_k||₂
where z_e is the encoder output and e_k are codebook vectors
- Training: Uses straight-through estimator for backpropagation
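- Gradients: Do not flow through the argmin. The decoder's gradient w.r.t. z_q is copied straight through to the encoder output z_e; the codebook is trained by a separate codebook loss ||sg[z_e] - e_k||², plus a commitment loss β||z_e - sg[e_k]||² that keeps the encoder near its chosen code (sg = stop-gradient)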
- Use case: High-quality image and audio generation, especially for multi-stage models
- The discreteness is "hard" both during training and inference.
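The VQ-VAE gradient routing can be sketched as three loss terms. This is a minimal sketch assuming PyTorch, β = 0.25 (a common choice), and toy stand-ins for the encoder output and the decoder loss; `vq_losses` and `grads` are hypothetical helper names. The check at the end shows which tensor each term actually reaches:

```python
import torch
import torch.nn.functional as F

def vq_losses(z_e, codebook, beta=0.25):
    """z_e: (N, D) encoder outputs; codebook: (K, D) vectors e_k."""
    with torch.no_grad():
        k_star = torch.cdist(z_e, codebook).argmin(dim=1)   # nearest code, no gradient
    z_q = codebook[k_star]                                   # e_{k*}, differentiable w.r.t. codebook
    codebook_loss = F.mse_loss(z_q, z_e.detach())            # ||sg[z_e] - e||^2: moves the codes
    commit_loss = beta * F.mse_loss(z_e, z_q.detach())       # beta*||z_e - sg[e]||^2: moves the encoder
    z_q_st = z_e + (z_q - z_e).detach()                      # straight-through input to the decoder
    return z_q_st, codebook_loss, commit_loss

torch.manual_seed(0)
z_e = torch.randn(8, 4, requires_grad=True)          # stand-in for encoder outputs
codebook = torch.randn(16, 4, requires_grad=True)
z_q_st, codebook_loss, commit_loss = vq_losses(z_e, codebook)
recon_loss = (z_q_st ** 2).mean()                    # stand-in for the reconstruction loss

def grads(loss):
    # Gradient of one loss term w.r.t. (encoder output, codebook); None means "not reached"
    return torch.autograd.grad(loss, [z_e, codebook], retain_graph=True, allow_unused=True)

recon_grads = grads(recon_loss)          # reaches z_e only
codebook_grads = grads(codebook_loss)    # reaches the codebook only
commit_grads = grads(commit_loss)        # reaches z_e only
print([g is not None for g in recon_grads],
      [g is not None for g in codebook_grads],
      [g is not None for g in commit_grads])
```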
Straight-through estimators allow optimization even on flat gradient surfaces
I get that the step function is non-differentiable at exactly zero, but everywhere else it's smooth, right? At negative values the gradient is zero, and at positive values it's also zero.
This is precisely the problem! If we used these "true" derivatives:
- For x < 0: gradient = 0
- For x > 0: gradient = 0
- At x = 0: undefined
Then during backpropagation, we would get zero gradients everywhere (except the mathematically undefined case at exactly 0). This means the parameters that led to this input would receive no gradient signal about whether they should change to push x higher or lower.
The straight-through estimator isn't about handling the case where x is exactly zero. Instead, it's about providing useful gradient information everywhere by pretending the function had a meaningful slope.
Typically, it pretends the step function was actually the identity function during the backward pass. This way, even though the real function has zero gradient almost everywhere, we still get useful learning signals that indicate whether increasing or decreasing x would have helped.
import torch

class StepFunctionWithSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()  # real step function

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # pretend we computed f(x) = x
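To see the flat-surface problem and the fix concretely, here is a toy optimization, a minimal sketch assuming PyTorch: a single scalar w goes through the step function with target output 1. `StepFunctionTrueGrad` is a hypothetical class that backpropagates the true zero derivative; `StepFunctionWithSTE` is the class above, repeated so the block runs on its own. Only the STE version moves w:

```python
import torch

class StepFunctionWithSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                      # pretend f(x) = x

class StepFunctionTrueGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()
    @staticmethod
    def backward(ctx, grad_output):
        return torch.zeros_like(grad_output)    # the real derivative: 0 almost everywhere

def train(step_fn, steps=20, lr=0.1):
    w = torch.tensor(-1.0, requires_grad=True)  # starts on the wrong side of the step
    opt = torch.optim.SGD([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = (step_fn.apply(w) - 1.0) ** 2    # we want output 1, i.e. w > 0
        loss.backward()
        opt.step()
    return w.item()

w_true = train(StepFunctionTrueGrad)
w_ste = train(StepFunctionWithSTE)
print(w_true, w_ste)   # w_true stays at -1.0; w_ste ends up positive
```

With the STE, the gradient says "increase w" until the step flips to 1, after which the loss and gradient are both zero.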
Forced exploration is key to training the router, because it only learns from the experts it actually picks
Wait, but only the chosen expert gets compute. It's not like we are computing all the experts and then downweighting the rest. So how could we have received signal about which expert would have been better?
The STE here only helps with propagating gradients through the argmax selection mechanism - it doesn't magically give us information about counterfactuals (what other experts would have done).
This leads to several practical challenges:
- Exploration Problem
- How do we discover if other experts would be better if we never try them?
- This is similar to the exploration-exploitation dilemma in reinforcement learning
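A minimal sketch of the point above, assuming PyTorch and top-1 routing in the style of Switch Transformer (the chosen expert's output is scaled by its router probability); the sizes and layer choices are made up for illustration. After one backward pass, only the selected expert has gradients, and the router learns only whether to raise or lower the probability of the expert it picked:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_experts = 8, 4
router = nn.Linear(d_model, n_experts)
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_experts)])

x = torch.randn(1, d_model)                   # a single token
probs = torch.softmax(router(x), dim=-1)      # router probabilities, shape (1, n_experts)
k = probs.argmax(dim=-1).item()               # hard top-1 choice; argmax is not differentiated
y = probs[0, k] * experts[k](x)               # only expert k runs; the gate keeps the router in the graph
y.sum().backward()                            # stand-in for a real downstream loss

expert_has_grad = [e.weight.grad is not None for e in experts]
print("chosen expert:", k, "experts with gradients:", expert_has_grad)
print("router received gradient:", router.weight.grad is not None)
```

Nothing in this backward pass says what the skipped experts would have produced, which is exactly the counterfactual gap described above.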