This is about the nuts and bolts of training ML models.
Training/optimization advice/wisdom
Coding mistakes
Common data processing tasks
| | Fixes high bias | Fixes high var |
|---|---|---|
| Data | | More |
| Features | More | Fewer |
| Params | More | Fewer |
| Regularization | Less | More |
Overfitting
ML project structure
Training issues and optimizations
Dealing with vanishing/exploding gradients
Debugging activations
For a simple NN like this:
layers = [
  Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size),
]
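The diagnostics below read each layer's output from a .out attribute, so they assume makemore-style hand-rolled layer classes rather than torch.nn modules. A minimal sketch of one such layer, as an assumption about the setup:
import torch

class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x) # stash the activation so the histograms below can read it
    return self.out
  def parameters(self):
    return [] # tanh has no trainable parameters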
Expect these patterns in the activations, grads, and weights: the tanh saturation should stay small but not go all the way to 0 (zero saturation means the activations are shrinking), and all layers should look roughly the same rather than vanishing or exploding:
plt.figure(figsize=(20, 4)) # width and height of the plot
legends = []
for i, layer in enumerate(layers[:-1]): # note: exclude the output layer
  if isinstance(layer, Tanh):
    t = layer.out
    print('layer %d (%10s): mean %+.2f, std %.2f, saturated: %.2f%%' % (i, layer.__class__.__name__, t.mean(), t.std(), (t.abs() > 0.97).float().mean()*100))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('activation distribution')
# visualize histograms of the gradients flowing into each tanh's output
plt.figure(figsize=(20, 4)) # width and height of the plot
legends = []
for i, layer in enumerate(layers[:-1]): # note: exclude the output layer
  if isinstance(layer, Tanh):
    t = layer.out.grad
    print('layer %d (%10s): mean %+f, std %e' % (i, layer.__class__.__name__, t.mean(), t.std()))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('gradient distribution')
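One prerequisite for the gradient histograms above: layer.out is a non-leaf tensor, so PyTorch discards its .grad unless you ask it to keep it before calling backward. A hedged sketch of that part of the step (x, Yb, layers, and parameters stand in for whatever the surrounding training code uses):
import torch.nn.functional as F

# forward pass through the stack (x is the flattened input batch, Yb the targets)
for layer in layers:
  x = layer(x)
loss = F.cross_entropy(x, Yb)

# keep gradients on the intermediate outputs so the histograms can read them
for layer in layers:
  layer.out.retain_grad()
for p in parameters:
  p.grad = None
loss.backward()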
# visualize histograms of the gradients on the weight matrices
plt.figure(figsize=(20, 4)) # width and height of the plot
legends = []
for i, p in enumerate(parameters):
  t = p.grad
  if p.ndim == 2: # only the weight matrices; skip biases / batchnorm gains
    print('weight %10s | mean %+f | std %e | grad:data ratio %e' % (tuple(p.shape), t.mean(), t.std(), t.std() / p.std()))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'{i} {tuple(p.shape)}')
plt.legend(legends)
plt.title('weights gradient distribution');
Visualize log10( SD of the update step (lr * grad) / SD of the weights ) for the linear layers, over training steps. Inside the training loop, right after the parameter update, collect the ratios:
ud.append([((lr*p.grad).std() / p.data.std()).log10().item() for p in parameters])
...
plt.figure(figsize=(20, 4))
legends = []
for i, p in enumerate(parameters):
  if p.ndim == 2:
    plt.plot([ud[j][i] for j in range(len(ud))])
    legends.append('param %d' % i)
plt.plot([0, len(ud)], [-3, -3], 'k') # these ratios should be ~1e-3; indicate that on the plot
plt.legend(legends);
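For context, a rough sketch of where the ud ("update-to-data ratio") collection sits in the training loop; the step count, learning-rate schedule, and plain-SGD update are placeholders for whatever the actual loop uses:
ud = []
for step in range(max_steps): # max_steps is a placeholder
  # ... forward pass, loss, loss.backward() as above ...
  lr = 0.1 if step < 100000 else 0.01 # illustrative step-decay schedule
  with torch.no_grad():
    for p in parameters:
      p += -lr * p.grad # plain SGD update
    # log10 of (size of this update) / (size of the data it updates), per parameter tensor
    ud.append([((lr * p.grad).std() / p.data.std()).log10().item() for p in parameters])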
Initialization
Usually Gaussian
If X and W are unit normal, then XW will have SD $\sqrt{d} > 1$, but we want SD = 1, so scale the initial W by $\sqrt{1/d}$, where d is the fan-in dimension
But there’s also the activation nonlinearity, and you want the activation to have SD 1 too.
Parameterized Kaiming init is probably the most common init now (and is what PyTorch uses by default). The gain depends on the nonlinearity: e.g. for ReLU, Kaiming He's init uses SD $\sqrt{2/d}$ (because ReLU discards half the distribution). See the sketch after this list.
GPT-NeoX uses “small init” $\sqrt{\frac{2}{5d}}$ from Transformers without Tears
Biases are usually fine to initialize to zero
Residual branches: initialize the last layer of each residual branch to zero (or scale it down), so each block starts close to the identity
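A minimal sketch of the fan-in scaling above (the tanh gain 5/3 and ReLU gain sqrt(2) come from PyTorch's calculate_gain; the layer sizes are placeholders):
import torch

fan_in, fan_out = 200, 200 # placeholder layer sizes
gain = torch.nn.init.calculate_gain('tanh') # 5/3 for tanh, sqrt(2) for relu
W = torch.randn(fan_in, fan_out) * gain / fan_in**0.5 # SD ~ gain / sqrt(fan_in)

# or, with nn.Linear and the built-in initializer:
lin = torch.nn.Linear(fan_in, fan_out, bias=False)
torch.nn.init.kaiming_normal_(lin.weight, nonlinearity='relu') # SD = sqrt(2 / fan_in)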
Distributed training
Pretraining
Approaches
Triplet/contrastive objective: require $\lVert \text{anchor} - \text{pos} \rVert - \lVert \text{anchor} - \text{neg} \rVert < -\alpha$, i.e. the anchor must be closer to the positive than to the negative by at least the margin $\alpha$
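A hedged sketch of the corresponding hinge loss in PyTorch (the embedding tensors and margin value are illustrative):
import torch
import torch.nn.functional as F

def triplet_loss(anchor, pos, neg, margin=0.2):
  d_pos = (anchor - pos).norm(dim=-1) # distance to the positive
  d_neg = (anchor - neg).norm(dim=-1) # distance to the negative
  return F.relu(d_pos - d_neg + margin).mean() # zero once d_pos - d_neg < -margin

# PyTorch also ships this directly: torch.nn.TripletMarginLoss(margin=0.2)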
LR schedules
Visualizations (source)
Even though Adam effectively gives each parameter its own dynamic learning rate, it can still make sense to use a scheduler on what is effectively the global cap on those learning rates (source); see the sketch below
From https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/guide2/Research_Projects.html
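A minimal sketch of that combination, assuming a placeholder model and a cosine schedule (any of the torch.optim.lr_scheduler classes would do):
import torch

model = torch.nn.Linear(128, 10) # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=3e-4) # per-parameter adaptive steps
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000) # decays the global base lr

for step in range(1000):
  loss = model(torch.randn(32, 128)).pow(2).mean() # dummy objective
  opt.zero_grad()
  loss.backward()
  opt.step()
  sched.step() # Adam adapts within the (now decaying) base lr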
Regularization techniques
Instruction fine tuning
Techniques for scaling to models larger than available memory
https://huggingface.co/docs/transformers/v4.18.0/en/performance
Activation checkpointing: rather than materializing all activations for backprop, recompute some of them during the backward pass
Gradient accumulation: accumulate gradients over several micro-batches before each optimizer step, so the effective batch size is larger than what fits in memory (see the sketch after this list)
Gradient checkpointing (another name for activation checkpointing): train models ~10x larger than your available memory
Mixed precision training: compute the forward and backward passes in fp16/bf16 rather than fp32
Low-memory optimizers: Adafactor, 8-bit Adam
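A rough sketch combining gradient accumulation with bf16 mixed precision via autocast; the model, batch shapes, and accumulation factor are placeholders, a CUDA device is assumed, and fp16 would additionally want a GradScaler:
import torch

model = torch.nn.Linear(512, 512).cuda() # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
accum_steps = 8 # effective batch = micro-batch size * 8

for i, micro_batch in enumerate(torch.randn(64, 16, 512).cuda()):
  with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    loss = model(micro_batch).pow(2).mean() # forward runs in bf16 where safe
  (loss / accum_steps).backward() # gradients accumulate in the fp32 parameters' .grad
  if (i + 1) % accum_steps == 0:
    opt.step()
    opt.zero_grad(set_to_none=True) # free gradient memory between updates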
Floating point types