Quantization
Quantization-aware training (QAT) vs. post-training quantization (PTQ)
Simulate quantization (let's say you're just quantizing weights) during the forward pass while keeping the underlying weights in full precision: the weights are quantized and immediately dequantized ("fake quantization"), and the computation continues in full precision using these quantized-then-dequantized values.
Example:
Step 1: Quantization
First, let's quantize our weight. Say the original full-precision weight is 0.7364 and we fake-quantize it to 4 bits over the range [-1, 1], so the step size is 2/15 ≈ 0.1333:
Quantized weight = round(0.7364 / 0.1333) * 0.1333 = 6 * 0.1333 = 0.7998
Step 2: Forward Pass
In the forward pass, we use the quantized weight (say the input is 1.5):
Output = Input * Quantized Weight
Output = 1.5 * 0.7998 = 1.1997
Step 3: Backward Pass
Let's assume the gradient of the loss with respect to the output is 2.0.
Without STE:
Gradient w.r.t. quantized weight = Input * Output gradient = 1.5 * 2.0 = 3.0
Gradient w.r.t. original weight = 0 (the rounding step has zero derivative almost everywhere, so no learning signal reaches the full-precision weight)
With STE:
Gradient w.r.t. original weight = Gradient w.r.t. quantized weight = 3.0
Step 4: Weight Update
Now we update the original full-precision weight (with a learning rate of 0.1):
New weight = 0.7364 - (Learning rate * Gradient)
New weight = 0.7364 - (0.1 * 3.0) = 0.4364
In PyTorch, this fake-quantization-with-STE scheme can be written as a custom autograd function:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits):
        # Simulate quantization: round onto the integer grid, then dequantize back to float.
        qmax = 2 ** (num_bits - 1) - 1
        scale = x.abs().max() / qmax
        x_int = torch.clamp(torch.round(x / scale), -qmax, qmax)
        return x_int * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-Through Estimator: pass gradients straight through the
        # non-differentiable rounding op; no gradient for num_bits.
        return grad_output, None

fake_quantize = FakeQuantize.apply

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, num_bits=8):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.num_bits = num_bits
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # same init as nn.Linear

    def forward(self, input):
        # Use fake-quantized weights in the forward pass; the stored weights stay full precision.
        quantized_weight = fake_quantize(self.weight, self.num_bits)
        return F.linear(input, quantized_weight)

# Usage in a model
model = nn.Sequential(
    QuantizedLinear(784, 128, num_bits=8),
    nn.ReLU(),
    QuantizedLinear(128, 10, num_bits=8),
)

# In the training loop (assuming `dataloader` yields (inputs, targets) batches)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)   # forward pass uses the fake-quantized weights
    loss = criterion(outputs, targets)
    loss.backward()           # gradients flow to the full-precision weights via STE
    optimizer.step()          # weights are updated in full precision
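For contrast, post-training quantization skips training entirely: you take an already-trained model and quantize its weights afterwards (often with a small calibration set for activations). A minimal sketch of symmetric per-tensor int8 weight quantization; the helper name here is illustrative, not any library's API:

import torch
import torch.nn as nn

def ptq_quantize_weight(w, num_bits=8):
    # Symmetric per-tensor quantization: map [-max|w|, +max|w|] onto signed integers.
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    scale = w.abs().max() / qmax
    w_int = torch.clamp(torch.round(w / scale), -qmax, qmax).to(torch.int8)
    return w_int, scale  # store integers + scale; dequantize as w_int * scale

# Apply to every Linear layer of a trained model (simulated here by dequantizing in place):
trained_model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
for module in trained_model.modules():
    if isinstance(module, nn.Linear):
        w_int, scale = ptq_quantize_weight(module.weight.data)
        module.weight.data = w_int.float() * scale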
Pruning: requires hardware/kernel support for sparse operations to turn the sparsity into an actual speedup
Distillation: training the student on the teacher's logits (soft targets) carries more information per example than training on hard labels; see the loss sketch below
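A minimal sketch of a standard distillation loss, mixing a temperature-scaled KL term against the teacher's logits with the usual hard-label cross-entropy (the temperature T and mixing weight alpha are illustrative defaults, not values from this note):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=2.0, alpha=0.5):
    # Soft targets: the teacher's full output distribution, softened by temperature T.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradient magnitude stays comparable across temperatures
    # Hard targets: the ordinary label loss.
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1 - alpha) * hard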
Misc engineering: fused kernels, etc.
Speculative decoding: a small draft model proposes tokens and the target model verifies them in parallel; acceptance uses modified rejection sampling (accept a drafted token with probability min(1, p/q); on rejection, resample from the adjusted residual distribution), which keeps the output distribution identical to the target model's; see the sketch below
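A sketch of that acceptance rule for a single drafted token, assuming you already have the target-model distribution p and the draft-model distribution q over the vocabulary (running the two models to obtain them is omitted here):

import torch

def accept_or_resample(p, q, x):
    # p: target-model probabilities, q: draft-model probabilities, x: token index sampled from q.
    accept_prob = torch.clamp(p[x] / q[x], max=1.0)
    if torch.rand(()) < accept_prob:
        return x  # accept the drafted token
    # Rejected: resample from the adjusted residual distribution max(0, p - q), renormalized.
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1))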
MultiLoRA: concurrently serve multiple LoRA adapters on the same GPU, sharing the base model's weights
Examples
Inference systems
I wouldn't recommend using it for a batch-serving setting today. One crucial optimization for batched serving (which you need if you have a large number of requests) is continuous batching, which this implementation doesn't have.
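The idea behind continuous batching is purely a scheduling change: instead of waiting for an entire batch of requests to finish decoding, you evict finished sequences and admit queued ones at every decoding step. A rough sketch, where step and is_finished are hypothetical placeholders rather than any particular library's API:

from collections import deque

def continuous_batching_loop(waiting: deque, step, is_finished, max_batch_size=32):
    active = []  # requests currently being decoded together
    while waiting or active:
        # Admit queued requests as soon as slots free up, instead of waiting for the whole batch.
        while waiting and len(active) < max_batch_size:
            active.append(waiting.popleft())
        step(active)  # one decode step for every active request
        active = [r for r in active if not is_finished(r)]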