Straight-through estimators allow optimization even on flat gradient surfaces

I get that this is non-differentiable at exactly zero, but everywhere else it is smooth, right? At negative values the gradient is zero, and at positive values it's also zero.

This is precisely the problem! If we used these "true" derivatives:

  1. For x < 0: gradient = 0
  2. For x > 0: gradient = 0
  3. At x = 0: undefined

Then during backpropagation we would get zero gradients everywhere (except at exactly 0, where the derivative is undefined). The parameters that produced this input would receive no gradient signal about whether they should change to push x higher or lower.
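To see the problem in code, here is a small PyTorch check. torch.sign is used as a stand-in for the step function because PyTorch defines its autograd derivative as zero everywhere, which matches the step function's behavior away from 0:

import torch

x = torch.tensor([-1.0, 2.0], requires_grad=True)
step = (torch.sign(x) + 1) / 2   # 0 for negative inputs, 1 for positive
step.sum().backward()
print(x.grad)                    # tensor([0., 0.]) -- no signal to push x up or down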

The straight-through estimator isn't about handling the case where x is exactly zero. Instead, it's about providing useful gradient information everywhere by pretending the function had a meaningful slope.

Typically, it pretends the step function was actually the identity function during the backward pass. This way, even though the real function has zero gradient almost everywhere, we still get useful learning signals that indicate whether increasing or decreasing x would have helped.

import torch

class StepFunctionWithSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()  # Real (hard) step function in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # Pretend we did f(x) = x: pass the gradient straight through
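As a quick sanity check (the input values here are arbitrary), the forward pass produces hard 0/1 outputs while the backward pass delivers the identity gradient:

x = torch.tensor([-1.5, 0.3, 2.0], requires_grad=True)
y = StepFunctionWithSTE.apply(x)  # forward: tensor([0., 1., 1.])
y.sum().backward()
print(x.grad)                     # identity gradient: tensor([1., 1., 1.])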

Forced exploration is the key to optimizing the router using only the samples actually taken

Wait, but only the chosen expert gets compute. It's not like we are computing all the experts and then downweighting the rest. So how could we have received signal about which expert would have been better?

The STE here only helps propagate gradients through the argmax selection mechanism; it doesn't magically give us information about counterfactuals (what other experts would have done).
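To make that concrete, here is a minimal sketch (not any particular library's implementation) of a top-1 gate using the straight-through trick through argmax; the names hard_top1_gate and router_logits are made up for illustration. The forward value is a hard one-hot, the backward pass behaves as if the softmax probabilities had been used, and since every unchosen expert is multiplied by zero, the only upstream gradient the router sees comes from the expert that actually ran:

import torch
import torch.nn.functional as F

def hard_top1_gate(router_logits):
    # Differentiable gate probabilities from the router.
    probs = F.softmax(router_logits, dim=-1)
    # Non-differentiable hard choice of a single expert.
    index = probs.argmax(dim=-1, keepdim=True)
    one_hot = torch.zeros_like(probs).scatter_(-1, index, 1.0)
    # Straight-through: forward value is the hard one-hot,
    # backward gradient flows as if we had returned probs.
    return one_hot + probs - probs.detach()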

This leads to several practical challenges:

  1. Exploration Problem