See also
General resources
DDP: simple, but requires full model on each node
Model parallelism: ambiguous term; refers to either pipeline or tensor parallelism, usually tensor
Pipeline parallelism: inter-layer



Tensor parallelism: intra-layer, chatty


Piece by piece (source)
3D parallelism: DP, MP, PP combined


Maybe helpful:

ZeRO/FSDP: partitions optimizer state, weights, gradients
This is a form of data parallelism
Partitioning weights may interfere with MP/PP?
Animations (source)
Stage 1: just optimizer state
Stage 2: gradients
Stage 3: weights, very chatty?
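The three stages are easiest to see as arithmetic. A rough per-device memory calc, assuming the usual mixed-precision Adam accounting (2 bytes/param fp16 weights, 2 bytes/param fp16 grads, 12 bytes/param fp32 master weight + two Adam moments); the function name and byte counts are illustrative assumptions, not from any particular framework:

```python
def zero_memory_per_device(n_params, n_devices, stage):
    """Rough per-device memory (bytes) for mixed-precision Adam.

    Per param: 2 B fp16 weights, 2 B fp16 grads, 12 B optimizer state
    (fp32 master weight + 2 Adam moments).
    Stage 1 shards optimizer state, stage 2 also shards grads,
    stage 3 also shards weights.
    """
    w, g, opt = 2.0, 2.0, 12.0
    if stage >= 3:
        w /= n_devices
    if stage >= 2:
        g /= n_devices
    if stage >= 1:
        opt /= n_devices
    return n_params * (w + g + opt)

# 7B params on 64 devices: ~112 GB/device unsharded, ~1.75 GB at stage 3
for s in range(4):
    print(f"stage {s}: {zero_memory_per_device(7e9, 64, s) / 1e9:.2f} GB")
```

The big win of stage 1 alone falls out of the numbers: the 12 B/param optimizer state dominates, so sharding just that cuts memory by ~3x.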

FSDP and TP
I'm storing weights as [in, out], so the forward is literally Y = X @ W, no transpose. (If you're thinking PyTorch nn.Linear, which stores [out, in], then my W is your weight.T — same math, I'm just declaring the storage layout to match the math.)
B = 32 global batch
d = 8 d_model
f = 24 d_ff
X : [B, d] = [32, 8] input
W1 : [d, f] = [ 8, 24] up-proj, H = gelu(X @ W1) : [32, 24]
W2 : [f, d] = [24, 8] down-proj, Z = H @ W2 : [32, 8]
Mesh: dp = 4, tp = 2 (8 devices)
| tensor | global | after TP split | after FSDP split (at-rest per device) |
|---|---|---|---|
| X | [32, 8] | replicated | [8, 8] — local batch, full features |
| W1 | [8, 24] | [8, 12] (split f) | [2, 12] (split d) |
| W2 | [24, 8] | [12, 8] (split f) | [3, 8] (split f/tp) |
| H | [32, 24] | [8, 12] | — (activation, not FSDP'd) |
| Z | [32, 8] | [8, 8] partial → AR | — |
Every number is now unique. If you see 12 you know it's f/tp. If you see 8 it's either d or B/dp — those do match here (8 = 8), sorry, unavoidable without making the numbers ugly. Mentally tag them by position: leading dim = batch, trailing = features.
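The table can be sanity-checked in a single process that just indexes out what each (dp, tp) coordinate would hold; the `local_shapes` helper is a hypothetical stand-in for the mesh, not real sharding code:

```python
import numpy as np

B, d, f = 32, 8, 24   # batch, d_model, d_ff
dp, tp = 4, 2          # mesh: 8 devices
rng = np.random.default_rng(0)
X = rng.normal(size=(B, d))
W1 = rng.normal(size=(d, f))
W2 = rng.normal(size=(f, d))

def local_shapes(i, j):
    """Shapes seen by the device at dp-row i, tp-col j at compute time,
    i.e. after the FSDP all-gather has rematerialized the full TP shard."""
    X_loc = X[i * (B // dp):(i + 1) * (B // dp)]        # dp slice of batch
    W1_loc = W1[:, j * (f // tp):(j + 1) * (f // tp)]   # tp slice of f
    W2_loc = W2[j * (f // tp):(j + 1) * (f // tp), :]   # tp slice of f
    return X_loc.shape, W1_loc.shape, W2_loc.shape

assert local_shapes(0, 0) == ((8, 8), (8, 12), (12, 8))
assert local_shapes(3, 1) == ((8, 8), (8, 12), (12, 8))
```

Every device sees the same local shapes; only the slice offsets differ, which is what the diagrams below draw out.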
W1 [8, 24] tiled onto the mesh — at rest
Each block below is the [8, 24] matrix: 8 rows, 24 cols. █ = the piece this device stores.
rows: d=8, split 4 ways by FSDP ↕
cols: f=24, split 2 ways by TP ↔
tp=0 tp=1
G0 ████████████············ G4 ············████████████ ← rows 0:2
························ ························
dp=0 ························ ························
························ ························
G1 ························ G5 ························
████████████············ ············████████████ ← rows 2:4
dp=1 ························ ························
························ ························
G2 ························ G6 ························
························ ························
dp=2 ████████████············ ············████████████ ← rows 4:6
························ ························
G3 ························ G7 ························
························ ························
dp=3 ························ ························
████████████············ ············████████████ ← rows 6:8
└── cols 0:12 ──┘└ 12:24 ┘ └── 0:12 ──┘└ cols 12:24 ┘
Each device stores a [2, 12] tile. Superimpose all eight → exact cover, no overlap.
FSDP all-gather ↕ fills in a column: G0 ends up with the full [8, 12] left half.
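The at-rest tiling and the column-wise all-gather can be checked with plain slicing (a single-process sketch; the `tile` indexing is assumed to match the diagram's layout):

```python
import numpy as np

d, f, dp, tp = 8, 24, 4, 2
W1 = np.arange(d * f).reshape(d, f)  # distinct entries so tiles are traceable

# at rest: device (dp=i, tp=j) stores rows i*2:(i+1)*2 of TP column j
def tile(i, j):
    return W1[i * (d // dp):(i + 1) * (d // dp),
              j * (f // tp):(j + 1) * (f // tp)]

assert tile(0, 0).shape == (2, 12)

# all eight tiles together cover W1 exactly once (exact cover, no overlap)
total = sum(tile(i, j).size for i in range(dp) for j in range(tp))
assert total == W1.size

# FSDP all-gather down a tp column concatenates the dp shards:
gathered = np.concatenate([tile(i, 0) for i in range(dp)], axis=0)
assert gathered.shape == (8, 12)
assert np.array_equal(gathered, W1[:, :12])  # G0 ends up with the full left half
```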
gelu(X @ W1) — H [32, 24] on the mesh
Local op on each device is [8, 8] @ [8, 12] → [8, 12]. The inner d=8 is fully local (just materialized by the FSDP all-gather); the outer dims come from X's and W1's sharded axes respectively:
rows: B=32, split 4 ways by dp ↕
cols: f=24, split 2 ways by tp ↔
tp=0 tp=1
G0 ████████████············ G4 ············████████████ H[ 0:8, 0:12] | H[ 0:8, 12:24]
dp=0 ························ ························
························ ························
························ ························
G1 ························ G5 ························
dp=1 ████████████············ ············████████████ H[ 8:16, 0:12] | H[ 8:16, 12:24]
························ ························
························ ························
G2 ························ G6 ························
dp=2 ························ ························
████████████············ ············████████████ H[16:24, 0:12] | H[16:24, 12:24]
························ ························
G3 ························ G7 ························
dp=3 ························ ························
························ ························
████████████············ ············████████████ H[24:32, 0:12] | H[24:32, 12:24]
Clean 2D tiling, every element of H held once. Zero comms so far (the FSDP all-gather was prefetched before this layer started).
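The tiling claim is checkable: compute every device's local matmul and compare against the global product (single-process sketch; gelu here is the tanh approximation, close enough for a layout check):

```python
import numpy as np

def gelu(x):  # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

B, d, f, dp, tp = 32, 8, 24, 4, 2
rng = np.random.default_rng(0)
X, W1 = rng.normal(size=(B, d)), rng.normal(size=(d, f))
H_global = gelu(X @ W1)                                  # [32, 24] reference

for i in range(dp):
    for j in range(tp):
        X_loc = X[i * 8:(i + 1) * 8]                     # [8, 8] dp slice
        W1_loc = W1[:, j * 12:(j + 1) * 12]              # [8, 12] after all-gather
        H_loc = gelu(X_loc @ W1_loc)                     # [8, 12] local tile
        assert np.allclose(H_loc, H_global[i * 8:(i + 1) * 8, j * 12:(j + 1) * 12])
```

Each local result is exactly one [8, 12] tile of H, with no communication needed, because the contracted dim d was fully local.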
W2 [24, 8] and the partial sum
W2 is split on its input dim f=24:
W2[0:12, :] — FSDP shards to [3, 8] per device
W2[12:24, :] — same
After the FSDP all-gather of W2, the local matmul on each device is [8, 12] @ [12, 8] → [8, 8]. But look at the contraction:
$$Z_{bi} = \sum_{k=0}^{23} H_{bk}\, W^{(2)}_{ki} \;=\; \underbrace{\sum_{k=0}^{11} H_{bk}\, W^{(2)}_{ki}}_{\text{tp=0 computes this}} \;+\; \underbrace{\sum_{k=12}^{23} H_{bk}\, W^{(2)}_{ki}}_{\text{tp=1 computes this}}$$
Same [b, i] footprint on both sides; each side holds half of the sum.
Z [32, 8] — rows split by dp, cols NOT split
tp=0 tp=1
G0 ▒▒▒▒▒▒▒▒ ←──+──→ G4 ▒▒▒▒▒▒▒▒ Σ k=0:12 + Σ k=12:24
dp=0 ········ ········
········ ········ same region of Z,
········ ········ disjoint halves of the sum
G1 ········ ←──+──→ G5 ········
dp=1 ▒▒▒▒▒▒▒▒ ▒▒▒▒▒▒▒▒
········ ········
········ ········
...
All-reduce ↔ per row of the mesh sums them, ▒ → █, and now Z is replicated on tp, sharded on dp — same layout as X in frame A. Loop closed.
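The partial-sum-then-all-reduce step, as a single-process numpy sketch (the `sum(partials)` stands in for the tp all-reduce):

```python
import numpy as np

B, d, f, dp, tp = 32, 8, 24, 4, 2
rng = np.random.default_rng(1)
H, W2 = rng.normal(size=(B, f)), rng.normal(size=(f, d))
Z_global = H @ W2                                        # [32, 8] reference

for i in range(dp):
    H_row = H[i * 8:(i + 1) * 8]                         # this dp row's batch slice
    partials = [
        H_row[:, j * 12:(j + 1) * 12] @ W2[j * 12:(j + 1) * 12, :]
        for j in range(tp)
    ]                                                    # each [8, 8], half the sum
    Z_row = sum(partials)                                # the tp all-reduce
    assert np.allclose(Z_row, Z_global[i * 8:(i + 1) * 8])
```

After the reduce, Z is replicated across tp and sharded across dp, i.e. laid out exactly like X, so the next layer can start immediately.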
Sequence parallelism aka context parallelism (e.g. in DS Ulysses): partition the input sequence. (paper)

Doesn’t partition any model memory, so it's closer to data parallelism (pure compute parallelism, chopping up inputs and activations), but with tensor-parallelism-like communication for attention (all-to-all needed)
Ring self-attention algorithm to compute attention by passing around keys then values in a circle



Actually overcomputes (explainer) — with a causal mask, some of the key/value blocks passed around the ring are fully masked and never needed
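The pass-keys-then-values pattern can be sketched in one process (hedged: non-causal, no comm/compute overlap, no online softmax; the ring rotation is simulated by index arithmetic):

```python
import numpy as np

def ring_self_attention(Q, K, V, P):
    """Non-causal ring self-attention, simulated in one process.

    Each of P 'devices' owns a contiguous sequence chunk of Q, K, V.
    Phase 1 rotates K chunks around the ring so every device fills in
    its full [s, S] score matrix; phase 2 rotates V chunks to build
    each device's output chunk.
    """
    S, d = Q.shape
    s = S // P
    chunk = lambda M, p: M[p * s:(p + 1) * s]
    # phase 1: rotate keys; after P steps each device has all score blocks
    scores = [np.empty((s, S)) for _ in range(P)]
    for step in range(P):
        for p in range(P):
            src = (p - step) % P          # key chunk held by device p this step
            scores[p][:, src * s:(src + 1) * s] = (
                chunk(Q, p) @ chunk(K, src).T / np.sqrt(d))
    attn = []
    for sc in scores:
        e = np.exp(sc - sc.max(-1, keepdims=True))
        attn.append(e / e.sum(-1, keepdims=True))
    # phase 2: rotate values; accumulate each device's output chunk
    out = [np.zeros((s, d)) for _ in range(P)]
    for step in range(P):
        for p in range(P):
            src = (p - step) % P
            out[p] += attn[p][:, src * s:(src + 1) * s] @ chunk(V, src)
    return np.concatenate(out)

# sanity check against plain full attention
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(16, 4)) for _ in range(3))
Sc = Q @ K.T / np.sqrt(4)
ref = np.exp(Sc - Sc.max(-1, keepdims=True))
ref = (ref / ref.sum(-1, keepdims=True)) @ V
assert np.allclose(ring_self_attention(Q, K, V, P=4), ref)
```

The causal-mask waste is visible here: with a mask, devices holding early query chunks would still receive every later key chunk and throw the scores away, which is the load imbalance striped attention reorders chunks to fix.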

Ring attention
The use of a ring topology for computing self-attention has also been studied in prior work [21] but it incurs non-overlapped communication overheads similar to sequence parallelism, making it infeasible for large context sizes …


Striped attention
Flash decoding (blog post)
MoE: mixture of experts (foundations)
More on MoE routing schemes
MoE expert parallelism
How parallel groups work, worked out
Offload, ZeRO Offload