- Paper
- A reversible Transformer based on RevNets (reversible residual networks)
Explainer
1. Locality-Sensitive Hashing (LSH) Attention
The main innovation is replacing the O(L²) attention mechanism with an O(L log L) approximation:
How it works:
- Problem: Standard attention computes all query-key pairs, creating an L×L score matrix (e.g., 64K×64K ≈ 16 GB in float32 for long sequences; worked out below)
- Solution: Use LSH to find each query's approximate nearest neighbors, so attention is only computed against the keys most likely to matter
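As a quick sanity check on the 16 GB figure, a back-of-the-envelope calculation assuming a single float32 attention matrix:

```python
# Memory cost of one full attention matrix for a 64K-token sequence.
seq_len = 64 * 1024              # 65,536 tokens
bytes_per_float = 4              # float32
attn_matrix_bytes = seq_len * seq_len * bytes_per_float
print(f"{attn_matrix_bytes / 2**30:.1f} GiB")   # -> 16.0 GiB
```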
The LSH mechanism:
- Shared Q-K: Set queries and keys to be identical (Q = K) and unit-normalize them
- Hash function: Use random projections to assign vectors to buckets (a sketch follows this list)
- Similar vectors get same hash with high probability
- Distant vectors get different hashes
- Attention within buckets: Only compute attention between items in the same hash bucket
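To make the bucketing concrete, here is a minimal NumPy sketch of an angular-LSH scheme of this kind (random projections, bucket id = argmax over the concatenation of xR and -xR). The function name `lsh_hash`, the dimensions, and the bucket count are illustrative choices rather than the paper's hyperparameters, and only a single hash round is shown:

```python
import numpy as np

def lsh_hash(vecs, n_buckets, rng):
    """Angular LSH: project onto random directions and take the argmax
    over the concatenation [xR, -xR] as the bucket id."""
    d = vecs.shape[-1]
    proj = rng.standard_normal((d, n_buckets // 2))   # one random projection (single hash round)
    rotated = vecs @ proj                             # shape [L, n_buckets/2]
    return np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)

rng = np.random.default_rng(0)
qk = rng.standard_normal((8, 16))                     # 8 shared query/key vectors, d = 16
qk /= np.linalg.norm(qk, axis=-1, keepdims=True)      # unit-normalize, since Q = K
buckets = lsh_hash(qk, n_buckets=4, rng=rng)
print(buckets)   # vectors separated by a small angle tend to share a bucket id
```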
Implementation details:
- Sort queries/keys by hash bucket
- Process the sorted sequence in chunks of size m for efficient batching (sketched after this list)
- Use multiple hash rounds (typically 4-8) to reduce the probability of missing similar pairs
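A rough, self-contained sketch of how the sort-then-chunk step might look; `chunked_bucket_attention`, the random bucket ids, and the chunk size are stand-ins, and the paper's look-back to the previous chunk, self-attention masking, and multi-round averaging are all omitted:

```python
import numpy as np

def chunked_bucket_attention(qk, v, buckets, chunk_size):
    """Sort positions by bucket id, then attend only within fixed-size
    chunks of the sorted order (a simplified, single-round variant)."""
    order = np.argsort(buckets, kind="stable")        # group same-bucket positions together
    sq, sv = qk[order], v[order]                      # shared Q = K, and values, in sorted order
    out = np.zeros_like(v)
    for start in range(0, len(order), chunk_size):
        idx = order[start:start + chunk_size]         # original positions in this chunk
        q_c = sq[start:start + chunk_size]
        scores = q_c @ q_c.T / np.sqrt(qk.shape[-1])  # attention scores within the chunk only
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[idx] = weights @ sv[start:start + chunk_size]
    return out

rng = np.random.default_rng(0)
qk = rng.standard_normal((8, 16))
qk /= np.linalg.norm(qk, axis=-1, keepdims=True)
v = rng.standard_normal((8, 16))
buckets = rng.integers(0, 4, size=8)                  # stand-in for the LSH bucket ids above
print(chunked_bucket_attention(qk, v, buckets, chunk_size=4).shape)   # (8, 16)
```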
2. Reversible Layers
Eliminates the need to store activations for backpropagation:
Standard Transformer:
- Must store all N layer activations for backprop