• Quantization

    • Types
      • 8-bit / int8
      • 4-bit
    • Various algos, like GPTQ, AWQ, …
    • Check for HW/kernel support for the low-precision format, otherwise there's no speedup
    • What's being quantized:
      • Weights: This is the most common target of quantization. The model's learned parameters are converted from higher precision to lower precision representations.
      • Activations: In some cases, the intermediate outputs (activations) during inference are also quantized.
      • Inputs: Input quantization is less common but can be part of an end-to-end quantized pipeline.
    • Floating-point to floating-point conversion: Converting from FP32 (32-bit floating-point) to FP16 (16-bit floating-point) or BF16 (16-bit brain floating-point) is often referred to as "precision reduction" rather than quantization. However, the line can be blurry, and some people might include this under the broader umbrella of quantization techniques.
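
      A minimal PyTorch sketch of the distinction (tensor names here are made up): precision reduction is just a dtype cast, while quantization maps values onto an integer grid via a scale factor.

      import torch

      w = torch.randn(4, 4)                  # FP32 weights

      # Precision reduction: a dtype cast, still floating point
      w_bf16 = w.to(torch.bfloat16)

      # Symmetric per-tensor int8 quantization: integer grid + scale
      scale = w.abs().max() / 127
      w_int8 = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
      w_dequant = w_int8.float() * scale     # dequantize to inspect the error
      print((w - w_dequant).abs().max())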
  • Quantization aware training (QAT)

    • vs Post-training quantization (PTQ)

    • Simulate quantization (let’s say you’re just quantizing weights) during the forward pass while keeping the underlying weights in full precision. The computation continues in full precision, but using these quantized-then-dequantized values.

    • Example:

      • Original weight (full precision): 0.7364
      • Input value: 1.5
      • Learning rate: 0.1
      • 4-bit quantization range: [-1, 1] split into 16 levels

      Step 1: Quantization

      First, let's quantize our weight:

      • Quantization step: 2 / (2^4 - 1) = 2 / 15 ≈ 0.1333
      • Quantized weight: round(0.7364 / 0.1333) * 0.1333 = 6 * 0.1333 = 0.7998

      Step 2: Forward Pass

      In the forward pass, we use the quantized weight:

      Output = Input * Quantized Weight = 1.5 * 0.7998 = 1.1997

      Step 3: Backward Pass

      Let's assume the gradient of the loss with respect to the output is 2.0.

      Without STE:

      • Gradient w.r.t. quantized weight = Input * Output gradient = 1.5 * 2.0 = 3.0
      • Gradient w.r.t. original weight = 0 (because quantization is not differentiable)

      With STE:

      • Gradient w.r.t. original weight = Gradient w.r.t. quantized weight = 3.0

      Step 4: Weight Update

      Now we update the original full-precision weight:

      New weight = 0.7364 - (Learning rate * Gradient) = 0.7364 - (0.1 * 3.0) = 0.4364
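
      The same arithmetic can be checked in a few lines of PyTorch (a quick sketch; the detach trick below is a common one-line straight-through estimator):

      import torch

      w = torch.tensor(0.7364, requires_grad=True)
      step = 2 / (2**4 - 1)                       # 4-bit grid on [-1, 1]

      # Fake quantization with STE: forward uses the rounded value,
      # backward treats the rounding as the identity
      w_q = w + (torch.round(w / step) * step - w).detach()
      out = 1.5 * w_q                             # forward pass -> 1.1997
      out.backward(torch.tensor(2.0))             # upstream gradient of 2.0
      print(w_q.item(), w.grad.item())            # ~0.7998, 3.0

      with torch.no_grad():
          w -= 0.1 * w.grad                       # SGD step -> 0.4364
      print(w.item())

      In a real QAT setup this is wrapped in an autograd Function and a quantized layer, as in the following code.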

      import math
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      
      class FakeQuantize(torch.autograd.Function):
          @staticmethod
          def forward(ctx, x, num_bits):
              # Simulate quantization
              scale = x.abs().max() / (2**(num_bits-1) - 1)
              x_quant = torch.round(x / scale) * scale
              x_quant = torch.clamp(x_quant, -x.abs().max(), x.abs().max())
              return x_quant
      
          @staticmethod
          def backward(ctx, grad_output):
              # Straight-Through Estimator
              # Pass gradients straight through
              return grad_output, None
      
      fake_quantize = FakeQuantize.apply
      
      class QuantizedLinear(nn.Module):
          def __init__(self, in_features, out_features, num_bits=8):
              super().__init__()
              self.weight = nn.Parameter(torch.empty(out_features, in_features))
              self.num_bits = num_bits
              # Same default initialization as nn.Linear
              nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
      
          def forward(self, input):
              # Use fake quantized weights in forward pass
              quantized_weight = fake_quantize(self.weight, self.num_bits)
              return F.linear(input, quantized_weight)
      
      # Usage in a model
      model = nn.Sequential(
          QuantizedLinear(784, 128, num_bits=8),
          nn.ReLU(),
          QuantizedLinear(128, 10, num_bits=8)
      )
      
      # In the training loop
      optimizer = torch.optim.Adam(model.parameters())
      criterion = nn.CrossEntropyLoss()
      
      # `dataloader` is assumed to be a DataLoader yielding (inputs, targets) batches
      for inputs, targets in dataloader:
          optimizer.zero_grad()
          outputs = model(inputs)  # Forward pass uses quantized weights
          loss = criterion(outputs, targets)
          loss.backward()  # Backward pass in full precision
          optimizer.step()  # Update weights in full precision
      
      
  • Pruning: requires HW support for sparse operations

  • Distillation: training on the teacher's logits carries more information per example than training on hard labels alone (see the loss sketch below)
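
    A minimal sketch of a standard distillation loss (tensor names are placeholders): the student matches the teacher's temperature-softened distribution with a KL term, in addition to the usual cross-entropy on labels.

      import torch.nn.functional as F

      def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
          # Soft targets: KL divergence to the teacher's softened distribution;
          # the T**2 factor keeps the gradient scale comparable across temperatures.
          soft = F.kl_div(
              F.log_softmax(student_logits / T, dim=-1),
              F.softmax(teacher_logits / T, dim=-1),
              reduction="batchmean",
          ) * (T * T)
          # Hard targets: ordinary cross-entropy on the ground-truth labels.
          hard = F.cross_entropy(student_logits, labels)
          return alpha * soft + (1 - alpha) * hard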

  • Misc engineering: fused kernels, etc.

    • Batching, continuous batching
    • Paged attention: add indirection in where chunks (pages) of the KV cache live (a la normal paged memory), distinguishing physical vs. logical pages. This means (1) you can allocate only what you actually use and overcommit memory instead of pessimistically pre-allocating the max context, and (2) you can share pages (toy sketch below).
    • KV caching

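      A toy sketch of the paged-KV idea (all names here are made up, not vLLM's API): a shared pool of physical pages plus a per-sequence block table, so pages are allocated on demand and a common prompt prefix can be shared across sequences.

      import torch

      BLOCK_SIZE, NUM_BLOCKS, NUM_HEADS, HEAD_DIM = 16, 1024, 8, 64

      # One physical pool of KV pages shared by all sequences (keys only, for brevity)
      kv_pool = torch.zeros(NUM_BLOCKS, BLOCK_SIZE, NUM_HEADS, HEAD_DIM)
      free_blocks = list(range(NUM_BLOCKS))

      class BlockTable:
          """Maps a sequence's logical pages to physical pages in kv_pool."""
          def __init__(self):
              self.blocks = []        # logical page i -> physical block id
              self.num_tokens = 0

          def append_kv(self, k):
              if self.num_tokens % BLOCK_SIZE == 0:
                  self.blocks.append(free_blocks.pop())   # allocate only on demand
              block = self.blocks[self.num_tokens // BLOCK_SIZE]
              kv_pool[block, self.num_tokens % BLOCK_SIZE] = k
              self.num_tokens += 1

      # Two requests with the same prompt can point at the same physical prefix page
      seq_a, seq_b = BlockTable(), BlockTable()
      for _ in range(BLOCK_SIZE):
          seq_a.append_kv(torch.randn(NUM_HEADS, HEAD_DIM))
      seq_b.blocks, seq_b.num_tokens = list(seq_a.blocks), seq_a.num_tokens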

    • Speculative decoding: a small draft model proposes a few tokens, and the target model verifies them in a single pass using modified rejection sampling: accept a draft token x with probability min(1, p(x)/q(x)) (p = target model's probability, q = draft's), and on rejection resample from the normalized residual max(p - q, 0). This keeps the output distribution exactly that of the target model (accept/resample sketch after this list).

      • Two papers, from Google Research and DeepMind, coincidentally published this at ~the same time: https://arxiv.org/abs/2211.17192 and https://arxiv.org/abs/2302.01318
    • MultiLoRA: concurrently serve multiple adapters on same GPU
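
    A minimal sketch of the speculative decoding accept/resample rule from the papers above (p_target and q_draft are assumed to be the two models' distributions over the vocabulary at one position):

      import torch

      def verify_token(x, p_target, q_draft):
          # Accept the draft token x with probability min(1, p(x)/q(x)).
          if torch.rand(()) < torch.clamp(p_target[x] / q_draft[x], max=1.0):
              return x, True
          # On rejection, resample from the normalized residual max(p - q, 0);
          # this correction keeps samples distributed exactly according to p_target.
          residual = torch.clamp(p_target - q_draft, min=0.0)
          residual = residual / residual.sum()
          return torch.multinomial(residual, 1).item(), False

      # Usage: sample x from q_draft with the draft model, score the same position
      # with the target model to get p_target, then accept or correct the token.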

  • Examples

    • https://pytorch.org/blog/accelerating-generative-ai-2/ and https://github.com/pytorch-labs/gpt-fast
      • torch.compile
      • int8
      • speculative decoding
      • int4
      • tensor parallelism
  • Inference systems

    • DeepSpeed Inference (DS Inference): Quentin says this has the fastest numbers
    • https://github.com/pytorch-labs/gpt-fast: super fast demo
      • Horace said:

        I wouldn't recommend using it for a batch serving setting today. One crucial optimization for batched serving (which you need if you have a large number of requests) is continual batching, which this implementation doesn't have.

      • llama.cpp: within 20% of gpt-fast (source)

    • https://github.com/NVIDIA/TensorRT-LLM: proprietary; requires reimplementing the model. Replicate uses this currently (2024). NVIDIA-backed, fast kernels; 170 tok/s vs gpt-fast's 196
    • https://github.com/mlc-ai/mlc-llm: Replicate used this second, was within 20% of TRT
    • https://github.com/turboderp/exllamav2: Replicate started with this, one guy, fast kernels, but quality degradation that they never dug into
    • https://github.com/vllm-project/vllm: paged attention, continuous batching