Why TPUs (systolic arrays) are very fast at matmuls: animation, paper
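
To make the systolic-array idea concrete, here is a minimal simulation sketch. It models an output-stationary grid in which processing element (i, j) accumulates one output value while skewed operands stream past it. This is a simplification chosen for readability and is not necessarily the dataflow the TPU hardware uses; the function name `systolic_matmul` is made up for these notes.

```python
import jax.numpy as jnp


def systolic_matmul(a, b):
    """Simulate an output-stationary systolic array computing a @ b.

    Returns the result and the number of cycles the grid was stepped.
    """
    n, k = a.shape
    k2, m = b.shape
    assert k == k2, "inner dimensions must match"

    rows = jnp.arange(n)[:, None]  # PE row index i, shape (n, 1)
    cols = jnp.arange(m)[None, :]  # PE column index j, shape (1, m)
    acc = jnp.zeros((n, m), dtype=jnp.float32)

    # The last PE (n-1, m-1) sees its last operand pair at t = n + m + k - 3.
    cycles = n + m + k - 2
    for t in range(cycles):
        # Inputs are skewed, so PE (i, j) sees inner index t - i - j at cycle t.
        kk = t - rows - cols
        active = (kk >= 0) & (kk < k)
        kk_safe = jnp.clip(kk, 0, k - 1)
        # Every active PE performs one multiply-accumulate in this cycle.
        acc = acc + jnp.where(active, a[rows, kk_safe] * b[kk_safe, cols], 0.0)
    return acc, cycles


a = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
b = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)
c, cycles = systolic_matmul(a, b)
print(c)
print("cycles:", cycles)
```

The point of the sketch: the grid needs only n + m + k - 2 steps for n * m * k multiply-accumulates, because up to n * m of them happen in parallel and each operand is passed between neighbours rather than re-read from memory.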

Certain computations, such as sorting, perform poorly on TPUs
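
One rough way to check this claim on whatever backend is available is to time a jitted sort against a jitted matmul. This is only a micro-benchmark sketch: the helper `time_fn`, the array sizes, and the repeat count are arbitrary choices for illustration, and the numbers will differ between CPU, GPU, and TPU.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, repeats=5):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (512, 512), dtype=jnp.float32)

matmul_s = time_fn(jax.jit(jnp.matmul), x, x)
sort_s = time_fn(jax.jit(jnp.sort), x.ravel())

print(f"backend: {jax.default_backend()}")
print(f"matmul 512x512:       {matmul_s * 1e6:.1f} us")
print(f"sort of 262144 items: {sort_s * 1e6:.1f} us")
```

The two operations do very different amounts of arithmetic, so compare the ratio across backends rather than reading either number on its own.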

| | TPUv1 | TPUv2 | TPUv3 | TPUv4[17][19] | TPUv5e[20] | TPUv5p[21][22] | v6e (Trillium)[23][24] |
|---|---|---|---|---|---|---|---|
| Date introduced | 2015 | 2017 | 2018 | 2021 | 2023 | 2023 | |
| Process node | 28 nm | 16 nm | 16 nm | 7 nm | Unstated | Unstated | |
| Die size (mm²) | 331 | < 625 | < 700 | < 400 | 300-350 | Unstated | |
| On-chip memory (MiB) | 28 | 32 | 32 | 32 | 48 | 112 | |
| Clock speed (MHz) | 700 | 700 | 940 | 1050 | Unstated | 1750 | |
| Memory | 8 GiB DDR3 | 16 GiB HBM | 32 GiB HBM | 32 GiB HBM | 16 GB HBM | 95 GB HBM | |
| Memory bandwidth | 34 GB/s | 600 GB/s | 900 GB/s | 1200 GB/s | 819 GB/s | 2765 GB/s | |
| TDP (W) | 75 | 280 | 220 | 170 | Not Listed | Not Listed | |
| TOPS (tera operations per second) | 23 | 45 | 123 | 275 | 197 (bf16) / 393 (int8) | 459 (bf16) / 918 (int8) | |
| TOPS/W | 0.31 | 0.16 | 0.56 | 1.62 | Not Listed | Not Listed | |

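
The TOPS/W row is just TOPS divided by TDP, and dividing TOPS by memory bandwidth gives a rough roofline "ridge point": how many operations a kernel must perform per byte read from memory before compute, rather than bandwidth, becomes the limit. The sketch below recomputes both from the table values. The dictionary layout and names are made up for these notes, the bf16 figure is used where two precisions are listed, and TOPS figures for different generations may refer to different precisions, so the ops/byte column is only an approximate comparison.

```python
# Values copied from the table: (TOPS, TDP in watts or None, bandwidth in GB/s).
chips = {
    "TPUv1": (23, 75, 34),
    "TPUv2": (45, 280, 600),
    "TPUv3": (123, 220, 900),
    "TPUv4": (275, 170, 1200),
    "TPUv5e": (197, None, 819),   # bf16 figure; TDP not listed
    "TPUv5p": (459, None, 2765),  # bf16 figure; TDP not listed
}

derived = {}
for name, (tops, tdp_w, bw_gbs) in chips.items():
    # Efficiency: tera-operations per second per watt.
    tops_per_watt = round(tops / tdp_w, 2) if tdp_w is not None else None
    # Ridge point: (tops * 1e12 ops/s) / (bw_gbs * 1e9 bytes/s).
    ops_per_byte = tops * 1000 / bw_gbs
    derived[name] = (tops_per_watt, ops_per_byte)

for name, (tpw, opb) in derived.items():
    tpw_text = "n/a" if tpw is None else f"{tpw:.2f}"
    print(f"{name:7s} TOPS/W = {tpw_text:>5s}   ops per byte = {opb:7.1f}")
```

A large matmul performs many operations per byte loaded, so it can sit above the ridge point, whereas an operation that touches each element only a few times stays limited by memory bandwidth.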
Questions
Try
jax.profiler.save_device_memory_profile(): the resulting profile is very confusing to interpret
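
A minimal sketch of producing a device memory profile, plus a coarser cross-check that is easier to read: summing the sizes of the arrays JAX still holds via `jax.live_arrays()`. The output path and array sizes are arbitrary choices for illustration.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Allocate something recognizable so it shows up in the profile.
x = jnp.ones((1024, 1024), dtype=jnp.float32)  # 4 MiB
y = (x @ x).block_until_ready()                # another 4 MiB

# Hypothetical output location; any writable path works.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
jax.profiler.save_device_memory_profile(profile_path)

# Cross-check: total bytes of arrays that are still alive on the device.
live_bytes = sum(arr.nbytes for arr in jax.live_arrays())

print(f"profile written to {profile_path}")
print(f"live arrays: {live_bytes / 2**20:.1f} MiB")
```

The saved file is in pprof format, so it can be opened with the `pprof` tool (for example `pprof --web memory.prof`). Comparing its totals against the live-array sum is one way to sanity-check what the profile is showing.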