
This is the SAM (Sharpness-Aware Minimization) paper. Here's what it concretely does:
Pass 1 — Find the adversarial perturbation: Do a normal forward + backward pass on the batch at your current weights $w$. This gives you the gradient $\nabla_w L(w)$. But you don't use this gradient to update the weights. Instead you use it to compute a perturbation:
$$\hat{\epsilon} = \rho \cdot \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2}$$
This is just the gradient rescaled to have norm $\rho$ (a hyperparameter, typically around 0.05). It points in the direction in which the loss increases fastest, so to first order it is the "worst-case" perturbation within an $L_2$ ball of radius $\rho$ around $w$.
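The "to first order" qualifier can be made explicit. If $L$ is replaced by its linear Taylor expansion around $w$, the inner maximization has a closed-form solution. By Cauchy-Schwarz, $\epsilon^\top \nabla_w L(w) \le \|\epsilon\|_2 \, \|\nabla_w L(w)\|_2 \le \rho \, \|\nabla_w L(w)\|_2$, with equality exactly when $\epsilon$ is the gradient rescaled to norm $\rho$:

$$\hat{\epsilon}(w) = \arg\max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon) \approx \arg\max_{\|\epsilon\|_2 \le \rho} \left[ L(w) + \epsilon^\top \nabla_w L(w) \right] = \rho \cdot \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2}$$

So $\hat{\epsilon}$ is the exact worst case only for the linearized loss; on a curved loss surface it is an approximation.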
Pass 2 — Compute the gradient at the perturbed point: Temporarily move the weights to $w + \hat{\epsilon}$. Run forward + backward again on the same batch. This gives you a new gradient: $\nabla_w L(w)|_{w+\hat{\epsilon}}$.
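Strictly speaking, this is not the full gradient of $L(w + \hat{\epsilon}(w))$ with respect to $w$, because $\hat{\epsilon}$ itself depends on $w$. Writing $\partial \hat{\epsilon}(w) / \partial w$ for the Jacobian of the perturbation, the chain rule gives:

$$\nabla_w \Big[ L\big(w + \hat{\epsilon}(w)\big) \Big] = \nabla_w L(w)\big|_{w+\hat{\epsilon}} + \left( \frac{\partial \hat{\epsilon}(w)}{\partial w} \right)^{\top} \nabla_w L(w)\big|_{w+\hat{\epsilon}} \approx \nabla_w L(w)\big|_{w+\hat{\epsilon}}$$

The second term involves second-order (Hessian) information, since $\hat{\epsilon}$ is built from $\nabla_w L(w)$. SAM drops it, which is why a plain backward pass at $w + \hat{\epsilon}$ is enough.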
Update: Throw away the perturbation (snap back to $w$), then apply the second gradient to the original weights:
$$w_{t+1} = w_t - \eta \cdot \nabla_w L(w)|_{w+\hat{\epsilon}}$$
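The two passes and the update can be sketched in a few lines. This is a minimal illustration, not the paper's reference code. It assumes plain gradient descent as the base update (matching the formula above) and a hypothetical `grad_fn(w)` that stands in for "forward + backward on the same batch". The `1e-12` added to the norm is an assumed numerical guard against a zero gradient, not something the text specifies.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(
    w: Vector,
    grad_fn: Callable[[Vector], Vector],
    lr: float,
    rho: float = 0.05,
) -> Vector:
    """One SAM step with plain gradient descent as the base update.

    grad_fn (hypothetical) stands in for a forward + backward pass on the SAME batch.
    """
    # Pass 1: gradient at the current weights, used only to build the perturbation.
    g = grad_fn(w)
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    # eps_hat = rho * g / ||g||_2; the 1e-12 avoids division by zero (an assumption).
    scale = rho / (g_norm + 1e-12)
    eps_hat = [scale * gi for gi in g]

    # Pass 2: gradient evaluated at the perturbed weights w + eps_hat.
    w_adv = [wi + ei for wi, ei in zip(w, eps_hat)]
    g_adv = grad_fn(w_adv)

    # Update: the perturbation is discarded and the step is applied to the ORIGINAL w.
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


# Usage on a hypothetical toy loss L(w) = 0.5 * (2 * w0**2 + 0.5 * w1**2).
def toy_grad(w: Vector) -> Vector:
    return [2.0 * w[0], 0.5 * w[1]]


w = [1.0, -2.0]
w = sam_step(w, toy_grad, lr=0.1, rho=0.05)
print(w)
```

Each call to `grad_fn` corresponds to one forward/backward pass, so a SAM step makes two calls where ordinary gradient descent makes one. Setting `rho=0.0` makes `eps_hat` zero and recovers a plain gradient step, which is a handy sanity check.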
You're not descending on the loss at $w$; you're descending based on what the loss looks like at (a first-order estimate of) the worst nearby point. If you're sitting at the bottom of a sharp, narrow minimum, the perturbation kicks you up onto the steep wall, where the gradient is large. The resulting update is large, so the weights tend not to settle there: a sharp minimum has a high worst-case loss within radius $\rho$ even though $L(w)$ itself is low. If you're in a wide, flat basin, the perturbation barely changes the loss or the gradient, and you descend almost normally.
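To make the sharp-versus-flat intuition concrete, here is a hypothetical 1-D comparison. Two quadratic losses $L(w) = \tfrac{1}{2} a w^2$ share the same minimum value (zero at $w = 0$) but differ in curvature $a$. The sketch evaluates the loss and gradient at SAM's perturbed point $w + \hat{\epsilon}$ for a point slightly off the minimum. The curvatures, the starting point, and the helper names `quadratic` and `sam_view` are illustrative choices, not from the text.

```python
from typing import Callable, Tuple

Fn = Callable[[float], float]


def quadratic(a: float) -> Tuple[Fn, Fn]:
    """Hypothetical 1-D loss L(w) = 0.5 * a * w**2 with curvature a and minimum L(0) = 0."""
    def loss(w: float) -> float:
        return 0.5 * a * w * w

    def grad(w: float) -> float:
        return a * w

    return loss, grad


def sam_view(loss: Fn, grad: Fn, w: float, rho: float = 0.05) -> Tuple[float, float, float]:
    """Return L(w), L(w + eps_hat), and the gradient at w + eps_hat."""
    g = grad(w)
    # In 1-D the L2 norm is abs(g); the 1e-12 guards against g == 0 (an assumption).
    eps_hat = rho * g / (abs(g) + 1e-12)
    return loss(w), loss(w + eps_hat), grad(w + eps_hat)


w = 0.01  # slightly off the minimum so the perturbation direction is defined
for name, a in [("sharp", 100.0), ("flat", 1.0)]:
    loss, grad = quadratic(a)
    l, l_adv, g_adv = sam_view(loss, grad, w)
    print(f"{name:>5}: L(w)={l:.5f}  L(w+eps)={l_adv:.5f}  grad(w+eps)={g_adv:.3f}")
```

Both minima have a loss of zero, but near the sharp one the perturbed loss and the gradient SAM steps with are 100 times larger (the ratio of the curvatures). That larger perturbed loss is exactly what SAM's objective penalizes.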
Net effect: SAM tends to converge to flatter minima, which the paper argues (and reports empirically) generalize better. Cost: two forward/backward passes per step, so roughly 2× the compute of a standard step.