https://arxiv.org/pdf/2506.09477v1

Summary

The paper identifies several concrete implementation issues with KL divergence gradient estimation:

1. Differentiating Through KL Estimates Doesn't Give Correct Gradients

The most common pitfall is implementing the loss as:

loss = -KL_estimate  # Then autodiff this

The problems: the KL estimate is computed on samples drawn from π itself, and naive autodiff treats those samples as constants, ignoring how the sampling distribution depends on the policy parameters. For the plain log-ratio estimate log(π(y)/πref(y)), the resulting gradient is just ∇log π(y), which has zero expectation under y ~ π, so on average it carries no information about the KL at all.
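
To make the pitfall concrete, here is a minimal sketch (assuming a small categorical policy in PyTorch; the variable names are illustrative, not from the paper):

# Minimal sketch of the pitfall with a categorical policy (illustrative names).
import torch

logits = torch.randn(8, requires_grad=True)   # trainable policy parameters
ref_logits = torch.randn(8)                   # frozen reference policy

pi = torch.distributions.Categorical(logits=logits)
pi_ref = torch.distributions.Categorical(logits=ref_logits)

y = pi.sample()                               # sampling itself is not differentiable

# Naive per-sample KL estimate, then autodiff:
kl_estimate = pi.log_prob(y) - pi_ref.log_prob(y)
loss = -kl_estimate
loss.backward()
# logits.grad is -∇log π(y); averaged over samples y ~ π this is zero,
# so the gradient obtained this way contains no KL information.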

2. Token-Level Losses Miss Temporal Dependencies

For sequences (like in LLMs), the common implementation is:

loss = sum([token_KL(π(yt|y<t), πref(yt|y<t)) for t in range(T)])

The problem: This doesn't account for how changing the policy's distribution at token t shifts the distribution of the future tokens t+1, t+2, etc. It only computes a "partial gradient" of the full sequence-level KL.
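
For reference, a minimal sketch of this token-level pattern, assuming the per-token log-probabilities of one sampled sequence are already gathered into tensors and using the sampled-token log-ratio as the per-token term (names are illustrative, not from the paper):

# Token-level KL loss over one sampled sequence (the common pattern above).
# policy_logprobs[t] = log π(y_t | y_<t), ref_logprobs[t] = log π_ref(y_t | y_<t).
import torch

def token_level_kl_loss(policy_logprobs: torch.Tensor,
                        ref_logprobs: torch.Tensor) -> torch.Tensor:
    per_token = policy_logprobs - ref_logprobs   # sampled-token log-ratio at each position
    # Autodiff of this sum only sees each log π(y_t | y_<t) in isolation;
    # it ignores how changing the policy at step t shifts y_{t+1}, ..., y_T.
    return per_token.sum()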

3. The "Correct" Implementations

The paper shows the proper gradient estimates should be:

For single samples:

# Vanilla gradient (correct)
gradient = log(π(y)/πref(y)) * ∇log π(y)
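
In an autodiff framework, one common way to realize this is a surrogate loss whose gradient matches the expression above; a minimal sketch, assuming PyTorch and illustrative names:

# Surrogate whose autodiff gradient is log(π(y)/πref(y)) * ∇log π(y).
import torch

def vanilla_kl_surrogate(policy_logprob: torch.Tensor,
                         ref_logprob: torch.Tensor) -> torch.Tensor:
    # policy_logprob = log π(y) (differentiable), ref_logprob = log π_ref(y).
    log_ratio = (policy_logprob - ref_logprob).detach()   # treated as a constant weight
    return log_ratio * policy_logprob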

For sequences:

# Account for future dependencies
gradient = sum([sum([log(π(ys|y<s)/πref(ys|y<s)) for s in range(t, T)]) * ∇log π(yt|y<t) for t in range(T)])
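
A minimal sketch of this sequence-level estimate, again assuming PyTorch and per-token log-probabilities for one sampled sequence; a reversed cumulative sum collects the log-ratios from position t onward (names are illustrative, not from the paper):

# Surrogate whose autodiff gradient matches the sequence-level expression above.
import torch

def sequence_kl_surrogate(policy_logprobs: torch.Tensor,
                          ref_logprobs: torch.Tensor) -> torch.Tensor:
    # policy_logprobs[t] = log π(y_t | y_<t), ref_logprobs[t] = log π_ref(y_t | y_<t).
    log_ratio = (policy_logprobs - ref_logprobs).detach()
    # future[t] = sum of log-ratios over positions s = t, ..., T-1
    future = torch.flip(torch.cumsum(torch.flip(log_ratio, dims=[0]), dim=0), dims=[0])
    # Autodiff gives sum_t future[t] * ∇log π(y_t | y_<t).
    return (future * policy_logprobs).sum()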