Fréchet inception distance (FID) score: lower is better
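A hedged sketch of the computation, assuming you already have pooled Inception features for the real and generated image sets (feature extraction omitted): FID fits a Gaussian to each set and measures the Fréchet distance between the two Gaussians.

import numpy as np
from scipy.linalg import sqrtm

def fid(feats_real, feats_gen):
    # Fit a Gaussian (mean, covariance) to each (n, d) feature set
    mu1, mu2 = feats_real.mean(0), feats_gen.mean(0)
    cov1 = np.cov(feats_real, rowvar=False)
    cov2 = np.cov(feats_gen, rowvar=False)
    covmean = sqrtm(cov1 @ cov2).real  # sqrtm can leak tiny imaginary parts
    return ((mu1 - mu2)**2).sum() + np.trace(cov1 + cov2 - 2*covmean)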
DALL-E 3, OpenAI 2023
Stable Diffusion 3, Stability 2024
Diffusion transformer architecture, but with separate MLPs for text and image embeddings feeding into the same attention. Reminiscent of MoEs (but with no routing).
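A rough sketch of that block shape (my simplification, not the SD3 code; norms and adaLN modulation omitted): each modality keeps its own MLP, but attention runs over the concatenated token sequence.

import torch, torch.nn as nn

class TwoStreamBlock(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.mlp_txt = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
        self.mlp_img = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, txt, img):            # txt: (b,s,d), img: (b,n,d)
        x = torch.cat([txt, img], dim=1)    # one joint sequence
        x = x + self.attn(x, x, x)[0]       # shared attention across modalities
        txt, img = x[:, :txt.shape[1]], x[:, txt.shape[1]:]
        return txt + self.mlp_txt(txt), img + self.mlp_img(img)  # separate MLPs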
Uses “flow matching”, a newer and more general framework
Rectified flow for straight paths
Start from ODE
They formulate all other design points within the same framework
Fortunately it’s also one of the simpler formulations, compared to EDM (Karras), cosine (Nichol), and log-normal; uniform is actually the simplest of all
Classic diffusion paths are curvy, so sampling needs more steps than it would if the path were straighter; rectified flow straightens the path (see the sketch below).
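A hedged sketch of the rectified-flow objective and sampler (conventions vary; here $x_t=(1-t)\,x_0+t\,\epsilon$, so the target velocity is the constant $\epsilon-x_0$ and sampling is plain Euler integration; model(x, t) returning a velocity is an assumed signature):

import torch

def rf_loss(model, x0):
    eps = torch.randn_like(x0)                       # noise endpoint
    t = torch.rand(x0.shape[0], device=x0.device).view(-1, 1, 1, 1)
    x_t = (1 - t)*x0 + t*eps                         # straight-line interpolation
    v = model(x_t, t.view(-1))                       # predicted velocity
    return ((v - (eps - x0))**2).mean()              # regress the constant velocity

@torch.no_grad()
def rf_sample(model, sz, n_steps, device='cuda'):
    x = torch.randn(sz, device=device)               # start at t=1 (pure noise)
    dt = 1.0/n_steps
    for i in range(n_steps):
        t = torch.full((sz[0],), 1 - i*dt, device=device)
        x = x - model(x, t)*dt                       # Euler step toward t=0
    return x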
Trained on ImageNet, CC12M. Recaptioned with CogVLM. Deduped with faiss.
Autoencoder uses more latent channels (16, vs. 4 in earlier SD); important since autoencoder reconstruction quality is an upper bound on how good your generated images can be
Pretrain on low res, fine-tune on high res, same as in SD2
Pretrain → fine-tune → DPO on 128 captions from PartiPrompts (simple but realistic captions → high-quality images)
The paper is comprehensive: they grid-searched the design space
Diffusion Transformers (DiT), 2023
VQ-VAE
ControlNet, 2023
Train with paired data, e.g. edges, depth maps, normal maps, pose, style, etc.
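A hedged sketch of the mechanism (module names are mine): freeze the base block, train a copy that sees the condition, and wire it in through zero-initialized convs so training starts as a no-op.

import copy
import torch.nn as nn

def zero_conv(ch):
    conv = nn.Conv2d(ch, ch, kernel_size=1)
    nn.init.zeros_(conv.weight); nn.init.zeros_(conv.bias)
    return conv

class ControlledBlock(nn.Module):
    # base_block is assumed to map (b,ch,h,w) -> (b,ch,h,w)
    def __init__(self, base_block, ch):
        super().__init__()
        self.copy = copy.deepcopy(base_block)        # trainable copy
        self.base = base_block
        for p in self.base.parameters(): p.requires_grad_(False)  # frozen
        self.zin, self.zout = zero_conv(ch), zero_conv(ch)

    def forward(self, x, cond):
        # Zero convs output 0 at init, so we start exactly at the frozen base
        return self.base(x) + self.zout(self.copy(x + self.zin(cond)))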
DDIM
Can thus choose different sampling steps/skip steps
Altogether: $x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t) + \sigma_t z$, with $z\sim\mathcal{N}(0,I)$
Also parameterizes the stochasticity (re-added noise) with $\sigma_t$. $\eta$ is a knob to control it: $\eta=0$ means $\sigma_t=0$ (deterministic sampling), $\eta=1$ means $\sigma_t$ takes its original DDPM value, $\sigma_t=\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\left(1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)}$.
Code—notice how we parameterize the skip size:
import torch
from fastprogress import progress_bar  # fast.ai-style progress bar

def ddim_step(x_t, t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta):
    # DDIM variance; eta scales it (0 = deterministic, 1 = DDPM-level noise)
    vari = (bbar_t1/bbar_t) * (1 - abar_t/abar_t1)
    sig = vari.sqrt()*eta
    # Estimate x_0 from the model's noise prediction
    x_0_hat = (x_t - bbar_t.sqrt()*noise) / abar_t.sqrt()
    # Jump directly to the (possibly skipped-to) earlier timestep
    x_t = abar_t1.sqrt()*x_0_hat + (bbar_t1 - sig**2).sqrt()*noise
    if t > 0: x_t += sig * torch.randn(x_t.shape).to(x_t)  # re-add noise
    return x_t

@torch.no_grad()
def sample(f, model, sz, n_steps, skips=1, eta=1.):
    # `skips` strides through the schedule: n_steps/skips denoising steps.
    # `abar` is the global cumulative-product noise schedule (1 - bbar).
    tsteps = list(reversed(range(0, n_steps, skips)))
    x_t = torch.randn(sz).to(model.device)
    preds = []
    for i, t in enumerate(progress_bar(tsteps)):
        abar_t1 = abar[tsteps[i+1]] if t > 0 else torch.tensor(1)
        noise = model(x_t, t).sample   # diffusers-style noise prediction
        x_t = f(x_t, t, noise, abar[t], abar_t1, 1-abar[t], 1-abar_t1, eta)
        preds.append(x_t.float().cpu())
    return preds
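Hypothetical usage, assuming the global abar schedule and diffusers-style model above: skips=10 takes 100 denoising steps instead of 1,000, and eta=0. makes sampling deterministic.

preds = sample(ddim_step, model, (16, 3, 32, 32), n_steps=1000, skips=10, eta=0.)
imgs = preds[-1]   # the last prediction is the fully denoised batch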
Classifier-free guidance
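Train with the conditioning randomly dropped so one model learns both conditional and unconditional predictions; at sampling time, extrapolate from the unconditional prediction past the conditional one. A hedged sketch (model signature assumed):

import torch

@torch.no_grad()
def cfg_pred(model, x_t, t, cond, null_cond, g=7.5):
    eps_cond = model(x_t, t, cond)
    eps_uncond = model(x_t, t, null_cond)          # e.g. empty-prompt embedding
    return eps_uncond + g*(eps_cond - eps_uncond)  # g>1 strengthens the condition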
GLIDE
VAE
VQ-VAE
UNet: just a popular CNN encoder-decoder architecture with skip connections, typically built from wide ResNet blocks
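A toy sketch of the shape (a real diffusion UNet adds time embeddings, attention, and many more blocks at each resolution):

import torch, torch.nn as nn

class TinyUNet(nn.Module):
    def __init__(self, ch=64):
        super().__init__()
        self.down1 = nn.Conv2d(3, ch, 3, padding=1)
        self.down2 = nn.Conv2d(ch, ch*2, 3, stride=2, padding=1)          # downsample
        self.mid   = nn.Conv2d(ch*2, ch*2, 3, padding=1)
        self.up1   = nn.ConvTranspose2d(ch*2, ch, 4, stride=2, padding=1) # upsample
        self.out   = nn.Conv2d(ch, 3, 3, padding=1)
        self.act   = nn.SiLU()

    def forward(self, x):
        d1 = self.act(self.down1(x))
        d2 = self.act(self.down2(d1))
        m  = self.act(self.mid(d2))
        u1 = self.act(self.up1(m))
        return self.out(u1 + d1)    # skip connection across the resolution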
Google Imagen, 2022
DALL-E 2 (aka unCLIP), OpenAI 2022
DALL-E, OpenAI 2021
Cascaded diffusion models, 2022
Stable Diffusion, Stability 2022