https://arxiv.org/pdf/2506.09477v1
The paper identifies several concrete implementation issues with KL divergence gradient estimation:
The most common pitfall is implementing the loss as:
loss = KL_estimate  # then autodiff this directly
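Concretely, this pattern often looks like the sketch below (an illustration, not the paper's code; `logp` and `logp_ref` are assumed to be log π(y) and log πref(y) for responses y sampled from the current policy):

import torch

def naive_kl_penalty(logp, logp_ref, kind="k1"):
    # Monte Carlo KL estimate that is then backpropagated through directly (the pitfall)
    log_ratio = logp - logp_ref.detach()                    # log(π(y)/πref(y)); reference is frozen
    if kind == "k1":
        estimate = log_ratio                                 # unbiased KL value, but its autodiff gradient has zero mean
    else:
        estimate = log_ratio + torch.exp(-log_ratio) - 1.0  # lower-variance "k3" value, but its autodiff gradient points at the reverse KL
    return estimate.mean()

# loss = naive_kl_penalty(logp, logp_ref); loss.backward()  # the pattern being criticized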
The problems:
- If the estimate is the plain log-ratio log(π(y)/πref(y)), autodiffing it gives a gradient with zero expectation - essentially just adding noise.
- If the estimate is the lower-variance form log(π(y)/πref(y)) + πref(y)/π(y) - 1 (often called the k3 estimator), autodiffing it gives the gradient of the reverse KL, so you accidentally minimize KL(πref, π) instead of the target KL(π, πref).

For sequences (like in LLMs), the common implementation is:
loss = sum([token_KL(π(yt|y<t), πref(yt|y<t)) for t in range(T)])
The problem: this doesn't account for how changing the distribution at token t affects future tokens t+1, t+2, etc. The sequence-level KL is an expectation over whole responses sampled from π, so the distribution of the prefix y<t itself depends on the policy parameters; differentiating only the per-token terms ignores that dependence and computes just a "partial gradient" of the full sequence-level KL (the toy check below makes this concrete).
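To see the missing term, here is a small numerical check (not from the paper; a PyTorch sketch with a made-up 2-token, 2-symbol policy). It computes the exact sequence-level KL by enumeration and compares its true gradient with the gradient obtained when the prefix distribution is treated as fixed, which is what autodiffing the per-token sum amounts to in expectation:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 2                                        # toy vocabulary size
a = torch.randn(V, requires_grad=True)       # policy logits for token 1
B = torch.randn(V, V, requires_grad=True)    # policy logits for token 2, one row per first token
a_ref = torch.randn(V)                       # frozen reference logits
B_ref = torch.randn(V, V)

def kl_categorical(logits_p, logits_q):
    # exact KL(p || q) between two categorical distributions given as logits
    log_p = F.log_softmax(logits_p, dim=-1)
    log_q = F.log_softmax(logits_q, dim=-1)
    return (log_p.exp() * (log_p - log_q)).sum()

p1 = F.softmax(a, dim=-1)                    # distribution of the first token
kl_pos1 = kl_categorical(a, a_ref)
kl_pos2 = torch.stack([kl_categorical(B[y1], B_ref[y1]) for y1 in range(V)])

# true sequence-level KL: the prefix distribution p1 is part of the computation graph
true_kl = kl_pos1 + (p1 * kl_pos2).sum()
grad_true = torch.autograd.grad(true_kl, (a, B), retain_graph=True)

# per-token implementation: prefixes come from sampling, so they carry no gradient
# (modelled here by detaching p1); this is the "partial gradient"
partial_kl = kl_pos1 + (p1.detach() * kl_pos2).sum()
grad_partial = torch.autograd.grad(partial_kl, (a, B))

print("d/da true:   ", grad_true[0])
print("d/da partial:", grad_partial[0])      # differs: missing the ∇p1(y1) * (position-2 KL) term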
The paper shows the proper gradient estimates should be:
For single samples:
# Vanilla gradient (correct)
gradient = log(π(y)/πref(y)) * ∇log π(y)
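In autodiff terms, one way to get exactly this gradient from a scalar loss is to detach the log-ratio and use it as a weight on log π(y). A minimal sketch, assuming `logp` and `logp_ref` are per-response log-probabilities (the names are not from the paper):

import torch

def kl_surrogate(logp, logp_ref):
    # gradient of this scalar is the mean of log(π(y)/πref(y)) * ∇ log π(y)
    log_ratio = (logp - logp_ref).detach()   # weight only; no gradient flows through it
    return (log_ratio * logp).mean()

Because E[∇ log π(y)] = 0 for y sampled from π, this estimate is unbiased for ∇ KL(π, πref).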
For sequences:
# Account for future dependencies
gradient = sum([sum([log(π(ys)/πref(ys)) for s in range(t, T)]) * ∇log π(yt) for t in range(T)])
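The same detach trick extends to sequences via a reverse cumulative sum of the per-token log-ratios (reward-to-go style). A sketch under the assumption that `logp_tok` and `logp_ref_tok` hold per-token log-probabilities of the sampled response along the last dimension, ignoring padding and any KL coefficient:

import torch

def seq_kl_surrogate(logp_tok, logp_ref_tok):
    # gradient of this scalar is sum_t ( sum_{s>=t} log(π(y_s)/πref(y_s)) ) * ∇ log π(y_t|y_<t)
    log_ratio = (logp_tok - logp_ref_tok).detach()
    weight = log_ratio.flip(-1).cumsum(-1).flip(-1)   # suffix sums: log-ratios of tokens s >= t
    return (weight * logp_tok).sum(-1).mean()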