Overview

Shampoo is a preconditioned gradient descent algorithm, similar to Adam, but it handles the parameters in their natural tensor shape rather than flattening everything into a vector. The key insight is that parameters in deep learning often have meaningful structure (like matrices for linear layers or 4D tensors for conv layers) and Shampoo exploits this structure.


Key Components:

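  1. Instead of a single preconditioner over all parameters (Adam keeps a diagonal one over the flattened parameter vector), Shampoo maintains a separate preconditioner for each dimension of the parameter tensor. For a matrix W ∈ ℝ^{m×n}, it maintains a left preconditioner L ∈ ℝ^{m×m} for the row dimension and a right preconditioner R ∈ ℝ^{n×n} for the column dimension.
  2. The preconditioners are updated by accumulating gradient statistics (similar to Adam's second moment), but separately for each dimension.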
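The same per-dimension idea extends to tensors of any order, such as the 4D kernels of conv layers. The sketch below is a minimal illustration, assuming the general form in which an order-k tensor gets one statistic per mode (built from the mode-i unfolding of G) and each is raised to the power -1/(2k), which gives -1/4 for matrices. The helper names `init_stats` and `shampoo_tensor_step` are hypothetical.

```python
import numpy as np

def matrix_power(M, p):
    # Power of a symmetric positive-definite matrix via its eigendecomposition
    w, V = np.linalg.eigh(M)
    return (V * w**p) @ V.T

def init_stats(shape, eps=1e-4):
    # One d_i x d_i statistic per tensor dimension (mode), started at eps * I
    return [eps * np.eye(d) for d in shape]

def shampoo_tensor_step(W, G, stats, lr=0.1):
    k = G.ndim
    new_stats = []
    for i, H in enumerate(stats):
        # Mode-i unfolding: axis i becomes the rows, all other axes are flattened into columns
        U = np.moveaxis(G, i, 0).reshape(G.shape[i], -1)
        new_stats.append(H + U @ U.T)
    # Apply each preconditioner along its own mode with exponent -1/(2k)
    P = G
    for i, H in enumerate(new_stats):
        Hp = matrix_power(H, -1.0 / (2 * k))
        # tensordot puts the contracted result axis first; move it back to position i
        P = np.moveaxis(np.tensordot(Hp, P, axes=([1], [i])), 0, i)
    return W - lr * P, new_stats

# Usage: a hypothetical conv kernel with 8 output channels, 3 input channels, 3x3 spatial size
rng = np.random.default_rng(0)
kernel = np.zeros((8, 3, 3, 3))
stats = init_stats(kernel.shape)
kernel, stats = shampoo_tensor_step(kernel, rng.standard_normal(kernel.shape), stats)
print([s.shape for s in stats])
```

For a matrix, the mode-0 statistic is G Gᵀ (the L above) and the mode-1 statistic is Gᵀ G (the R above), so this reduces to the two-sided matrix update. Storage grows with the sum of d_i² rather than with the square of the total parameter count.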

Let me show a simplified version for the matrix case:

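import numpy as np

def matrix_power(M, p):
    # M is symmetric positive definite (epsilon * I plus a sum of G G^T terms),
    # so its power can be taken through the eigendecomposition M = V diag(w) V^T
    w, V = np.linalg.eigh(M)
    return (V * w**p) @ V.T

# Initialize (m, n, epsilon, learning_rate, num_steps and compute_gradient are assumed defined)
W = np.zeros((m, n))           # Parameters
L = epsilon * np.eye(m)        # Left preconditioner (m x m)
R = epsilon * np.eye(n)        # Right preconditioner (n x n)
eta = learning_rate

for t in range(num_steps):
    # Get gradient
    G = compute_gradient(W)

    # Update preconditioners
    L = L + G @ G.T            # Accumulate left stats
    R = R + G.T @ G            # Accumulate right stats

    # Update parameters with preconditioned gradient
    W = W - eta * matrix_power(L, -0.25) @ G @ matrix_power(R, -0.25)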
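To run the loop end to end, here is a self-contained version on a hypothetical toy objective ½‖W - T‖², whose gradient is simply W - T. The target T, the function name `shampoo_matrix`, and the values lr = 0.5, eps = 1e-4 and 200 steps are illustrative assumptions, not tuned settings.

```python
import numpy as np

def matrix_power(M, p):
    # Power of a symmetric positive-definite matrix via its eigendecomposition
    w, V = np.linalg.eigh(M)
    return (V * w**p) @ V.T

def shampoo_matrix(grad_fn, m, n, lr=0.5, eps=1e-4, num_steps=200):
    # Matrix-case Shampoo loop with the placeholders filled in
    W = np.zeros((m, n))
    L = eps * np.eye(m)
    R = eps * np.eye(n)
    for _ in range(num_steps):
        G = grad_fn(W)
        L = L + G @ G.T
        R = R + G.T @ G
        W = W - lr * matrix_power(L, -0.25) @ G @ matrix_power(R, -0.25)
    return W

# Hypothetical target with singular values 3, 1 and 0.5 in random orthogonal bases
rng = np.random.default_rng(0)
Q1, _ = np.linalg.qr(rng.standard_normal((3, 3)))
Q2, _ = np.linalg.qr(rng.standard_normal((3, 3)))
T = Q1 @ np.diag([3.0, 1.0, 0.5]) @ Q2.T

loss = lambda W: 0.5 * np.sum((W - T) ** 2)
W_hat = shampoo_matrix(lambda W: W - T, 3, 3)
print(loss(np.zeros((3, 3))), loss(W_hat))
```

The starting loss is 0.5 · (9 + 1 + 0.25) = 5.125, and it should shrink to nearly zero even though the three singular directions differ in scale by a factor of 6.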

Simple Example: Let's say you have a 2×2 matrix parameter W. At each step:

  1. You compute the gradient G
  2. Update L by adding G Gᵀ, the sum of outer products of G's columns (a 2×2 matrix)
  3. Update R by adding Gᵀ G, the sum of outer products of G's rows (also 2×2)
  4. Precondition the gradient by applying L^(-1/4) from left and R^(-1/4) from right
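The four steps can be traced on concrete numbers. This sketch performs the very first update from L = R = εI for two hypothetical 2×2 gradients; `first_shampoo_step` is an illustrative helper and ε = 1e-8 is an assumed value. For a full-rank square G with SVD G = U S Vᵀ, the first step works out to approximately U Vᵀ: L^(-1/4) ≈ U S^(-1/2) Uᵀ and R^(-1/4) ≈ V S^(-1/2) Vᵀ cancel S, so every singular value of the step becomes 1.

```python
import numpy as np

def matrix_power(M, p):
    # Power of a symmetric positive-definite matrix via its eigendecomposition
    w, V = np.linalg.eigh(M)
    return (V * w**p) @ V.T

def first_shampoo_step(G, eps=1e-8):
    # Steps 2-4 for the very first update, starting from L = R = eps * I
    L = eps * np.eye(G.shape[0]) + G @ G.T   # step 2: left statistics
    R = eps * np.eye(G.shape[1]) + G.T @ G   # step 3: right statistics
    return matrix_power(L, -0.25) @ G @ matrix_power(R, -0.25)   # step 4

# Step 1: hypothetical gradients for a 2x2 parameter W
G = np.array([[1.0, 0.0], [0.0, 2.0]])
P = first_shampoo_step(G)      # approximately the identity: both entries rescaled to 1

G2 = np.array([[2.0, 1.0], [1.0, 3.0]])
P2 = first_shampoo_step(G2)
U, S, Vt = np.linalg.svd(G2)
print(np.allclose(P2, U @ Vt, atol=1e-6))   # True
```

In the diagonal case, L = R = diag(1, 4) + εI, so the entry 2 is divided by 4^(1/4) · 4^(1/4) = 2 and the entry 1 by 1: large and small gradient directions get steps of the same size.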

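The key advantage over Adam is that Shampoo captures correlations between parameters, which Adam's diagonal preconditioner ignores, while staying far cheaper than a full-matrix preconditioner, because it works with the natural tensor structure. For a matrix W ∈ ℝ^{m×n}, a full-matrix preconditioner over the flattened parameters would need (mn)² entries, whereas Shampoo stores only m² + n² (in L and R). Its update L^(-1/4) G R^(-1/4) is equivalent to applying the Kronecker product L^(-1/4) ⊗ R^(-1/4) to the flattened gradient.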
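The Kronecker equivalence and the storage comparison can both be checked numerically. This sketch assumes row-major flattening (NumPy's default `reshape`), for which vec(A X B) = (A ⊗ Bᵀ) vec(X), and uses R's symmetry; the random gradients and the helper name `preconditioner_entries` are hypothetical.

```python
import numpy as np

def matrix_power(M, p):
    # Power of a symmetric positive-definite matrix via its eigendecomposition
    w, V = np.linalg.eigh(M)
    return (V * w**p) @ V.T

rng = np.random.default_rng(0)
m, n, eps = 3, 4, 1e-4
L, R = eps * np.eye(m), eps * np.eye(n)
for _ in range(5):                          # accumulate a few random gradients
    G = rng.standard_normal((m, n))
    L = L + G @ G.T
    R = R + G.T @ G

Lp, Rp = matrix_power(L, -0.25), matrix_power(R, -0.25)
shampoo = Lp @ G @ Rp                       # uses only an m x m and an n x n matrix
full = np.kron(Lp, Rp) @ G.reshape(-1)      # the same update as one (mn x mn) preconditioner
print(np.allclose(shampoo.reshape(-1), full))   # True

def preconditioner_entries(m, n):
    # Stored preconditioner entries for an m x n weight matrix
    return {"diagonal (Adam-style)": m * n,
            "full-matrix": (m * n) ** 2,
            "Shampoo (L and R)": m * m + n * n}

print(preconditioner_entries(1000, 1000))
```

For a 1000×1000 layer that is 10^6 entries for a diagonal second moment, 10^12 for a full-matrix preconditioner and 2·10^6 for Shampoo, which still mixes information across rows and columns.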

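The -1/4 power is applied on both sides, so the two preconditioners together amount to a -1/2 power overall. This mirrors AdaGrad and Adam, which divide the gradient by the square root of the accumulated (for Adam, exponentially averaged) squared gradients, and it gives Shampoo similar theoretical properties.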
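One way to see where the -1/4 comes from: for a 1×1 parameter, L and R are the same scalar ε + Σg², so L^(-1/4) · g · R^(-1/4) = g / √(ε + Σg²), which is exactly AdaGrad's update. Adam differs by replacing the running sum with an exponential moving average plus bias correction. The gradient sequence, learning rate and ε below are arbitrary illustrative values.

```python
import math

def compare_shampoo_adagrad(grads, lr=0.1, eps=1e-8):
    w_sh, L, R = 0.0, eps, eps    # Shampoo on a 1x1 parameter: L and R are scalars
    w_ada, acc = 0.0, eps         # AdaGrad accumulator of squared gradients
    for g in grads:
        L += g * g
        R += g * g
        w_sh -= lr * L ** -0.25 * g * R ** -0.25
        acc += g * g
        w_ada -= lr * g / math.sqrt(acc)
    return w_sh, w_ada

w_sh, w_ada = compare_shampoo_adagrad([0.5, -2.0, 1.0, 3.0])
print(math.isclose(w_sh, w_ada))  # True
```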

Intuition

  1. MOTIVATION: Think about training a neural network layer with weight matrix W ∈ ℝ^{m×n}. Some input features (or neurons) might have large values while others are small, and some output neurons might need big updates while others need small ones. We want to scale our updates adaptively to account for this.