https://arxiv.org/abs/2307.03170

Focused Transformer (FoT): The Core Technique

This paper (NeurIPS 2023) introduces a method to extend a transformer's context length by fixing how attention keys are learned, not by changing the architecture. It has two parts: an inference mechanism (memory attention) and a training trick (crossbatch). The training trick is the actual contribution.

The problem they're solving: "distraction"

When you give a transformer access to a huge external memory of (key, value) pairs (as in Memorizing Transformer), attention degrades. Why? During standard training, the model never has to distinguish keys from its own document from keys belonging to unrelated documents. Nothing pushes them apart, so keys from different documents end up overlapping in embedding space.

Empirically, they show that if you put 1 relevant document and $d-1$ irrelevant documents in memory, a vanilla transformer spreads its attention roughly evenly across all of them, so the relevant document gets only about $1/d$ of the attention mass: the model can't tell signal from noise. As memory grows, the useful information drowns.
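A toy softmax calculation makes the $1/d$ effect concrete. It assumes, purely for illustration, that the relevant key beats each of the $d-1$ distractors by a fixed score `margin` and that all distractors tie. `relevant_attention_mass` is a hypothetical helper, not something from the paper.

```python
import math


def relevant_attention_mass(margin: float, d: int) -> float:
    """Softmax weight on the single relevant key among d keys in memory.

    Toy assumption: the relevant key's score exceeds each of the d - 1
    distractor scores by `margin`, and all distractors tie.
    """
    if d < 1:
        raise ValueError("need at least one key in memory")
    return math.exp(margin) / (math.exp(margin) + (d - 1))


if __name__ == "__main__":
    # Compare indistinguishable keys (margin 0) with well-separated ones (margin 5).
    for d in (2, 16, 256, 4096):
        print(d, relevant_attention_mass(0.0, d), relevant_attention_mass(5.0, d))
```

With `margin = 0` (keys indistinguishable) the result is exactly $1/d$. A positive margin helps, but for any fixed margin the mass still decays toward zero as $d$ grows, so retrieval only scales if keys from the right document stay well separated from everything else in memory.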

Part 1: Memory attention layers (inference)

A small subset of layers $L$ (e.g., layers 6, 12, 18 in a 3B model) are designated "memory layers." At inference:

  1. As tokens are processed, their (key, value) pairs at layer $\ell \in L$ are dumped into an external FAISS index.
  2. For each new query at layer $\ell$, attention runs over the local context (normal) plus the top-$k$ nearest keys from memory (retrieved by inner-product kNN).
  3. Memory keys get no positional encoding (or position 0 in LongLLaMA), which is what allows extrapolation to arbitrary lengths.
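The three inference steps above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not FoT's or LongLLaMA's code: `BruteForceMemory` is a hypothetical exact inner-product search standing in for the FAISS index, memory keys are stored without positional encoding (step 3), and local and retrieved keys share a single softmax. That last point is a simplification; Memorizing Transformer, for instance, combines local and memory attention with a learned gate.

```python
import heapq
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


class BruteForceMemory:
    """Hypothetical stand-in for a FAISS inner-product index (exact search)."""

    def __init__(self) -> None:
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def add(self, keys: List[Vector], values: List[Vector]) -> None:
        # Per step 3, memory keys are stored without positional encoding.
        self.keys.extend(keys)
        self.values.extend(values)

    def topk(self, query: Vector, k: int) -> Tuple[List[Vector], List[Vector]]:
        # Indices of the k stored keys with the largest inner product with `query`.
        best = heapq.nlargest(k, range(len(self.keys)), key=lambda i: dot(query, self.keys[i]))
        return [self.keys[i] for i in best], [self.values[i] for i in best]


def softmax(scores: List[float]) -> List[float]:
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def memory_layer_attention(
    query: Vector,
    local_keys: List[Vector],
    local_values: List[Vector],
    memory: BruteForceMemory,
    k: int,
) -> Vector:
    """Attend over the local context plus the top-k keys retrieved from memory.

    Simplification (assumption): local and retrieved keys share one softmax.
    """
    mem_keys, mem_values = memory.topk(query, k)
    keys = local_keys + mem_keys
    values = local_values + mem_values
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([dot(query, key) * scale for key in keys])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
```

In a real model this runs per head and per memory layer $\ell \in L$, with the memory filled by earlier chunks of the same input as it is processed.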

This part is essentially Memorizing Transformer. The novelty is below.

Part 2: Crossbatch training (the core idea)

This is a contrastive-learning-inspired data pipeline that reshapes the key space so kNN retrieval actually works.

Setup: Arrange the batch so each batch element is a different document. Split each document into a previous chunk $C_{\text{prev}}$ and current chunk $C_{\text{curr}}$.
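A minimal sketch of that batch layout, assuming token-id lists and two consecutive chunks of equal length `chunk_len` taken from the start of each document (a real pipeline would slide this window along the whole document). `make_crossbatch_batch` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def make_crossbatch_batch(
    documents: Sequence[Sequence[int]], chunk_len: int
) -> List[Tuple[List[int], List[int]]]:
    """Turn one token sequence per document into (C_prev, C_curr) pairs.

    Assumptions: both chunks have length chunk_len and are taken from the
    start of each document; every batch element is a different document.
    """
    batch = []
    for tokens in documents:
        if len(tokens) < 2 * chunk_len:
            raise ValueError("document is shorter than two chunks")
        c_prev = list(tokens[:chunk_len])
        c_curr = list(tokens[chunk_len : 2 * chunk_len])
        batch.append((c_prev, c_curr))
    return batch
```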

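The trick: At memory layer $\ell$, when computing attention for tokens in $C_{\text{curr}}$ of document $\delta$, the attention context is built from:

  1. the tokens of $C_{\text{curr}}$ of $\delta$ up to the current position (the usual causal local context);
  2. the (key, value) pairs from $C_{\text{prev}}$ of $\delta$ itself (positives);
  3. the (key, value) pairs from $C_{\text{prev}}$ of $d-1$ other documents in the batch (negatives).

So each query attends over its own history and deliberately injected distractors from unrelated documents, all in one differentiable softmax. Because positives and negatives compete for the same attention mass, the language-modeling loss pushes the model to give its own document's keys higher scores than the distractors', which is the contrastive pressure that shapes the key space for kNN retrieval at inference.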
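A forward-pass sketch of the crossbatch context for a single query, in plain Python. The assumptions here are mine rather than the paper's: the $d-1$ negative documents are sampled at random from the other batch elements, values are omitted because the point is where the attention mass lands, and `crossbatch_memory_keys` / `crossbatch_attention` are hypothetical names. A real implementation would build this as batched tensors with a mask and backpropagate through the softmax.

```python
import math
import random
from typing import Dict, List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def crossbatch_memory_keys(
    prev_keys: List[List[Vector]], delta: int, d: int, rng: random.Random
) -> List[Tuple[Vector, str]]:
    """Memory-side keys for batch document `delta`, tagged "pos" or "neg".

    prev_keys[i] holds the layer-l keys computed from C_prev of batch element i.
    Assumption: the d - 1 negative documents are sampled uniformly from the
    other batch elements (rng.sample raises ValueError if there are too few).
    """
    others = [i for i in range(len(prev_keys)) if i != delta]
    tagged = [(key, "pos") for key in prev_keys[delta]]
    for i in rng.sample(others, d - 1):
        tagged.extend((key, "neg") for key in prev_keys[i])
    return tagged


def crossbatch_attention(
    query: Vector, t: int, curr_keys: List[Vector], memory_keys: List[Tuple[Vector, str]]
) -> Dict[str, float]:
    """Attention mass per key group for the query at position t of C_curr.

    Positives, negatives, and the causal local keys curr_keys[0..t] all share
    one scaled dot-product softmax.
    """
    tagged = memory_keys + [(key, "local") for key in curr_keys[: t + 1]]
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, key) * scale for key, _ in tagged]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    mass = {"pos": 0.0, "neg": 0.0, "local": 0.0}
    for (_, tag), e in zip(tagged, exps):
        mass[tag] += e / total
    return mass
```

Training never optimizes `mass["pos"]` directly; it only minimizes the language-modeling loss, and the shared softmax is what ties that loss to separating positives from negatives.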