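Why TPUs (systolic arrays) are very fast at matmuls: animation, paper
Certain kinds of compute, such as sorting, run poorly on TPUs: a systolic array is built for dense, regular multiply-accumulate work, while sorting is dominated by data-dependent comparisons and data movement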
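A rough way to see the systolic idea in code: the sketch below simulates a small grid of multiply-accumulate cells in plain Python. It assumes an output-stationary layout (each cell keeps one output element while operands stream past), chosen only because it is easy to follow; the dataflow in real TPU hardware may differ. The name `systolic_matmul` and the cycle accounting are illustrative, not taken from the paper.

```python
def systolic_matmul(A, B):
    """Simulate an M x N grid of MAC cells computing A (M x K) @ B (K x N).

    Output-stationary sketch: cell (i, j) accumulates C[i][j]. Values of A
    enter from the left edge and move one cell right per cycle; values of B
    enter from the top edge and move one cell down per cycle. Row i of A and
    column j of B are delayed by i and j cycles so matching operands meet.
    """
    M, K, N = len(A), len(B), len(B[0])
    acc = [[0] * N for _ in range(M)]    # one accumulator per cell
    a_reg = [[0] * N for _ in range(M)]  # A operand currently held by each cell
    b_reg = [[0] * N for _ in range(M)]  # B operand currently held by each cell
    cycles = M + N + K - 2               # last operands reach the far corner

    for t in range(cycles):
        # Shift A operands one cell to the right, then feed the left edge.
        for i in range(M):
            for j in range(N - 1, 0, -1):
                a_reg[i][j] = a_reg[i][j - 1]
            k = t - i
            a_reg[i][0] = A[i][k] if 0 <= k < K else 0
        # Shift B operands one cell down, then feed the top edge.
        for j in range(N):
            for i in range(M - 1, 0, -1):
                b_reg[i][j] = b_reg[i - 1][j]
            k = t - j
            b_reg[0][j] = B[k][j] if 0 <= k < K else 0
        # Every cell performs one multiply-accumulate per cycle, in parallel
        # on real hardware (sequentially here).
        for i in range(M):
            for j in range(N):
                acc[i][j] += a_reg[i][j] * b_reg[i][j]
    return acc, cycles


C, cycles = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(C, cycles)
```

The point of the sketch: each operand is read from memory once and then reused as it passes from cell to cell, so the grid does M x N multiply-accumulates per cycle with only M + N values entering at the edges.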
| | TPUv1 | TPUv2 | TPUv3 | TPUv4[17][19] | TPUv5e[20] | TPUv5p[21][22] | v6e (Trillium)[23][24] |
|---|---|---|---|---|---|---|---|
| Date introduced | 2015 | 2017 | 2018 | 2021 | 2023 | 2023 | |
| Process node | 28 nm | 16 nm | 16 nm | 7 nm | Unstated | Unstated | |
| Die size (mm²) | 331 | < 625 | < 700 | < 400 | 300-350 | Unstated | |
| On-chip memory (MiB) | 28 | 32 | 32 | 32 | 48 | 112 | |
| Clock speed (MHz) | 700 | 700 | 940 | 1050 | Unstated | 1750 | |
| Memory | 8 GiB DDR3 | 16 GiB HBM | 32 GiB HBM | 32 GiB HBM | 16 GB HBM | 95 GB HBM | |
| Memory bandwidth | 34 GB/s | 600 GB/s | 900 GB/s | 1200 GB/s | 819 GB/s | 2765 GB/s | |
| TDP (W) | 75 | 280 | 220 | 170 | Not listed | Not listed | |
| TOPS (tera operations per second) | 23 | 45 | 123 | 275 | 197 (bf16), 393 (int8) | 459 (bf16), 918 (int8) | |
| TOPS/W | 0.31 | 0.16 | 0.56 | 1.62 | Not listed | Not listed | |
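A quick sanity check on the table, as a sketch: the TOPS/W row is just TOPS divided by TDP, and dividing peak TOPS by memory bandwidth gives a rough "operations available per byte fetched" figure. Only TPUv1 through TPUv4 are included because those are the columns that list both TOPS and TDP. The dictionary layout and function names are made up for this example, and the table does not state the numeric precision behind these TOPS figures, so the ratios are only indicative.

```python
# Figures copied from the table: peak TOPS, TDP in watts, memory bandwidth in GB/s.
CHIPS = {
    "TPUv1": {"tops": 23, "tdp_w": 75, "bw_gb_s": 34},
    "TPUv2": {"tops": 45, "tdp_w": 280, "bw_gb_s": 600},
    "TPUv3": {"tops": 123, "tdp_w": 220, "bw_gb_s": 900},
    "TPUv4": {"tops": 275, "tdp_w": 170, "bw_gb_s": 1200},
}


def tops_per_watt(chip):
    """Peak tera-operations per second per watt of TDP."""
    return chip["tops"] / chip["tdp_w"]


def ops_per_byte(chip):
    """Peak operations available per byte of off-chip memory traffic."""
    return (chip["tops"] * 1e12) / (chip["bw_gb_s"] * 1e9)


for name, chip in CHIPS.items():
    print(f"{name}: {tops_per_watt(chip):.2f} TOPS/W, "
          f"{ops_per_byte(chip):.0f} ops per byte")
```

The ops-per-byte figure is in the hundreds for every generation listed, which is one way to read the two notes at the top: work with heavy operand reuse (large matmuls) can approach peak TOPS, while work that touches each byte only a few times is limited by memory bandwidth instead.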
Questions
Try jax.profiler.save_device_memory_profile(): the output is very confusing to interpret
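A minimal sketch of calling it, assuming JAX is installed and any backend (CPU is enough to try it). The helper name `dump_memory_profile`, the array size, and the file name `memory.prof` are placeholders chosen for this example.

```python
import jax
import jax.numpy as jnp
import jax.profiler


def dump_memory_profile(path):
    """Allocate some device memory, then write a device memory profile to path."""
    x = jnp.ones((1024, 1024))
    # Force the computation to finish so its buffers exist when the
    # snapshot is taken.
    y = (x @ x).block_until_ready()
    # Writes a snapshot of current device allocations in pprof format.
    jax.profiler.save_device_memory_profile(path)
    return y


if __name__ == "__main__":
    dump_memory_profile("memory.prof")
```

The file is in pprof format, so it is meant to be opened with the pprof tool (for example `pprof --web memory.prof`) rather than read directly, which may be part of why the raw output is hard to interpret.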