GPU vs TPU Architectures
Prerequisite: SIMD & Vectorization | How Computers Execute Programs
The reason large neural networks became tractable was not algorithmic - backpropagation was understood in the 1980s. It was hardware. The shift from training on CPUs to GPUs cut training times by 10-50x. Custom ASICs like Google’s TPUs cut them again. Understanding why requires looking at what these chips actually do and how their designs reflect different tradeoffs between latency, throughput, and flexibility.
CPU vs GPU Design Philosophy
A modern server CPU has 64–128 cores, each with large private L1 and L2 caches plus a shared L3, out-of-order execution, branch prediction, and speculative execution. Nearly every design decision optimizes for latency: finishing a single task as fast as possible. A CPU core can execute complex, irregular code with many branches and data-dependent decisions with minimal stalls.
A GPU has thousands of far simpler cores. An NVIDIA A100 has 6912 CUDA cores. Each core is dumb by CPU standards - no branch prediction, no out-of-order execution. The design optimizes for throughput: completing as many operations per second as possible across the entire chip. Threads are scheduled in groups of 32 called warps, and all threads in a warp execute the same instruction at the same time (Single Instruction, Multiple Thread - SIMT). If threads in a warp diverge (e.g., some take an if branch and some don’t), the warp runs both paths one after the other, masking off the threads that did not take each path. Irregular control flow is the enemy of GPU efficiency.
This makes GPUs ideal for matrix multiplication: thousands of multiply-accumulate operations on independent data, all the same instruction, no branching.
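To make the cost of divergence concrete, here is a toy Python model, not real GPU code. It assumes each side of an if/else costs a fixed, hypothetical number of cycles and ignores details such as reconvergence and independent thread scheduling on newer GPUs. The only rule it encodes is that a warp pays for every path that at least one of its threads takes.

```python
# Toy cost model of SIMT execution: a warp issues one instruction stream
# for all 32 lanes. Lanes that skip a path are masked off, but the warp
# still spends the cycles on that path.
WARP_SIZE = 32

def warp_cycles(takes_if, if_cost, else_cost):
    """Cycles a warp spends on an if/else, given each lane's branch decision.

    takes_if: list of 32 booleans, one per thread in the warp.
    if_cost / else_cost: hypothetical cycle counts for each path.
    """
    assert len(takes_if) == WARP_SIZE
    cycles = 0
    if any(takes_if):        # at least one lane needs the if-path
        cycles += if_cost
    if not all(takes_if):    # at least one lane needs the else-path
        cycles += else_cost
    return cycles

uniform = [True] * WARP_SIZE                          # all lanes agree
divergent = [i % 2 == 0 for i in range(WARP_SIZE)]    # even/odd lanes split

print(warp_cycles(uniform, if_cost=10, else_cost=10))    # 10
print(warp_cycles(divergent, if_cost=10, else_cost=10))  # 20
```

A single odd-one-out lane is enough to double the cost here, which is why GPU code tries to make branch decisions uniform across each group of 32 threads.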
GPU Memory Hierarchy
Memory on a GPU is organized in tiers with very different characteristics:
- HBM (High Bandwidth Memory): The main GPU memory, built from stacked DRAM dies that sit on the same package as the GPU die. The A100 has 80 GB of HBM at roughly 2 TB/s bandwidth. It’s fast by DRAM standards but slow compared to on-chip memory.
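- Shared memory / L1 cache (SRAM): Each Streaming Multiprocessor (SM) has 192 KB of SRAM, split between L1 cache and shared memory for its threads. Its aggregate bandwidth is roughly an order of magnitude higher than HBM (estimated at about 19 TB/s on the A100) and its latency is far lower, but there’s very little of it - an A100 has 108 SMs, so about 20 MB total.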
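- Registers: Private to each thread, with the lowest latency of all. Capacity is tight: the A100 has a 256 KB register file per SM, and a single thread can use at most 255 32-bit registers.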
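The capacity gap between these tiers is easy to quantify. This sketch uses the A100 figures quoted in the list (108 SMs, 192 KB of SRAM per SM, 80 GB of HBM):

```python
# A100 figures from the memory hierarchy list
num_sms = 108
sram_per_sm_bytes = 192 * 1024   # combined L1 / shared memory per SM
hbm_bytes = 80e9                 # 80 GB of HBM

total_sram_bytes = num_sms * sram_per_sm_bytes
print(f"Total on-chip SRAM: {total_sram_bytes / 1e6:.1f} MB")  # 21.2 MB
print(f"HBM capacity / SRAM capacity: {hbm_bytes / total_sram_bytes:,.0f}x")
```

With only about 20 MB on chip against 80 GB off chip, a kernel can hold only a small working set close to the cores at any moment, so it has to work on small tiles of its data.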
Efficient GPU code minimizes HBM traffic by staging data through shared memory and reusing it many times before fetching more. FlashAttention, for example, gets nearly all of its speedup by computing attention in tiles that stay in SRAM, so the full attention matrix is never written to or read back from HBM.
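The payoff from staging data through shared memory can be sketched with a blocked matrix multiply in NumPy. This is a CPU model of the access pattern, not a GPU kernel: each `tile x tile` block of A and B is counted as one load from "HBM" into "SRAM", and each loaded block is then reused for a whole tile of multiply-adds. `tiled_matmul` is a hypothetical helper written for illustration.

```python
import numpy as np

def tiled_matmul(A, B, tile=64):
    """Blocked matrix multiply that mimics staging tiles through shared memory.

    Returns the product and the number of input elements loaded from "HBM".
    Output writes are not counted.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    loads = 0
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            # Accumulator for one output tile; on a GPU it would live in registers.
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), dtype=C.dtype)
            for k in range(0, K, tile):
                a_tile = A[i:i + tile, k:k + tile]  # "load" a block of A into SRAM
                b_tile = B[k:k + tile, j:j + tile]  # "load" a block of B into SRAM
                loads += a_tile.size + b_tile.size
                acc += a_tile @ b_tile  # each loaded element is reused across the tile
            C[i:i + tile, j:j + tile] = acc
    return C, loads

rng = np.random.default_rng(0)
A = rng.standard_normal((256, 256))
B = rng.standard_normal((256, 256))
for tile in (16, 64, 256):
    C, loads = tiled_matmul(A, B, tile)
    assert np.allclose(C, A @ B)
    print(f"tile={tile:>3}: {loads:,} elements loaded")
```

For N x N inputs and a tile size T that divides N, the count is 2N³/T elements, so each 4x increase in tile size cuts HBM reads by 4x; an unblocked loop that fetches both operands for every multiply-add corresponds to T = 1. How large T can grow is bounded by the SRAM available per SM.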
Tensor Cores and Precision
CUDA cores handle scalar FP32 multiply-adds. Tensor Cores are dedicated hardware units for small matrix multiply-accumulate operations, mostly at lower precision. The A100’s Tensor Cores support FP16, BF16, TF32, INT8, and INT4, plus FP64. In FP16 mode they deliver 312 TFLOPS (dense), roughly 16x the 19.5 TFLOPS of FP32 math on the CUDA cores.
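BF16 has become the standard training precision for LLMs. It keeps FP32’s 8-bit exponent, so its dynamic range is essentially the same, and pays for this with only 7 mantissa bits. As a result, overflow and underflow are rare, a major practical advantage over FP16, whose 5-bit exponent tops out at 65,504 and usually needs loss scaling to keep small gradients from underflowing.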
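NumPy has no native bfloat16 type, so the sketch below emulates it by keeping the upper 16 bits of each float32 (sign, 8-bit exponent, 7 mantissa bits). That truncates rather than rounding to nearest even, as hardware conversions usually do, so treat the exact low-order digits as approximate. The point is the range: values that overflow or underflow in FP16 survive in BF16.

```python
import numpy as np

def to_bf16(x):
    """Emulate bfloat16 by keeping only the top 16 bits of each float32.

    Those bits are the sign, the full 8-bit exponent, and 7 mantissa bits.
    This truncates (rounds toward zero); hardware usually rounds to nearest even.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

values = np.array([1e-10, 3.14159265, 70000.0, 1e30], dtype=np.float32)

with np.errstate(over="ignore"):  # silence the expected FP16 overflow warning
    fp16 = values.astype(np.float16)
bf16 = to_bf16(values)

for v, h, b in zip(values, fp16, bf16):
    print(f"{float(v):>10.4g}  fp16={float(h):>10.4g}  bf16={float(b):>10.4g}")
```

FP16 flushes 1e-10 to zero and turns 70,000 and 1e30 into infinity, while the BF16 emulation keeps all four values finite at the cost of precision (π becomes 3.140625).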
The Memory Bandwidth Bottleneck and Arithmetic Intensity
The A100 delivers 312 TFLOPS (FP16) but only 2 TB/s of HBM bandwidth. If an operation moves X bytes between HBM and the chip (reads plus writes) and performs Y FLOPs, its arithmetic intensity is Y/X (FLOPs per byte). The A100’s ridge point, where the compute and bandwidth limits meet, is 312e12 / 2e12 = 156 FLOPs/byte.
An operation with arithmetic intensity below 156 is memory-bound: the chip sits idle waiting for data. An operation above 156 is compute-bound: the memory subsystem keeps up and the cores are the bottleneck.
Matrix multiplication of large matrices has high arithmetic intensity - it’s compute-bound and Tensor Cores shine. A simple elementwise operation like relu(x) reads each element, does one operation, and writes the result back: in FP16 that is 1 FLOP per 4 bytes moved, an arithmetic intensity of about 0.25 FLOP/byte, deeply memory-bound. This is why kernel fusion (combining multiple elementwise ops into one kernel) is so effective - the data crosses HBM once for many FLOPs instead of once per op.
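The kernel-fusion argument is simple bookkeeping. Assuming FP16 elements, one FLOP per op, and that every unfused op round-trips its data through HBM, a hypothetical helper can compute the arithmetic intensity of a chain of elementwise ops with and without fusion:

```python
BYTES_PER_ELEM = 2  # FP16

def elementwise_intensity(num_ops, fused):
    """FLOP/byte for a chain of `num_ops` elementwise ops, one FLOP each.

    Unfused: each op reads its input from HBM and writes its output back.
    Fused: the whole chain reads the input once and writes the result once.
    """
    flops_per_elem = num_ops
    round_trips = 1 if fused else num_ops
    bytes_per_elem = round_trips * 2 * BYTES_PER_ELEM  # one read + one write each
    return flops_per_elem / bytes_per_elem

for n in (1, 4, 16):
    unfused = elementwise_intensity(n, fused=False)
    fused = elementwise_intensity(n, fused=True)
    print(f"{n:>2} ops: unfused {unfused:.2f}, fused {fused:.2f} FLOP/byte")
```

Even 16 fused ops reach only 4 FLOP/byte, still far below the A100’s 156 FLOP/byte ridge point, so the fused kernel remains memory-bound. The win is that it moves 16x fewer bytes, and for a memory-bound kernel that translates almost directly into runtime.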
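The Roofline model plots attainable performance against arithmetic intensity: attainable FLOP/s = min(peak compute, memory bandwidth × arithmetic intensity). To the left of the ridge point (156 FLOP/byte on the A100), performance rises linearly with intensity along the bandwidth slope; to the right of it, performance is flat at the compute ceiling.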
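The roofline itself is one line of arithmetic. The sketch below uses the A100 FP16 numbers from this section; `attainable_tflops` is an illustrative helper, and real kernels land below the roof because of latency, occupancy, and other overheads the model ignores.

```python
def attainable_tflops(intensity, peak_tflops=312.0, bandwidth_tb_s=2.0):
    """Roofline: performance is capped by compute or by memory bandwidth.

    intensity is in FLOP/byte; bandwidth (TB/s) * intensity gives TFLOPS.
    Defaults are the A100 FP16 Tensor Core and HBM figures.
    """
    return min(peak_tflops, bandwidth_tb_s * intensity)

ridge = 312.0 / 2.0  # 156 FLOP/byte
for ai in (0.25, 10, 156, 1000):
    bound = "memory" if ai < ridge else "compute"
    print(f"AI={ai:>7}: {attainable_tflops(ai):6.1f} TFLOPS ({bound}-bound)")
```

At 0.25 FLOP/byte, the relu case, the A100 can sustain at most about 0.5 TFLOPS, under 0.2% of its FP16 peak.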
TPU Architecture
Google’s TPUs (Tensor Processing Units) are ASICs - application-specific integrated circuits built primarily for one job: the matrix multiplications at the heart of neural networks. The core of a TPU is a systolic array: a grid of multiply-accumulate units that pass operands from cell to cell in a wave. Each value fetched from memory is reused by every cell it passes through, so the array sustains very high matrix-multiply throughput with little memory traffic during the computation itself.
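The dataflow is easier to see in a simulation. The NumPy sketch below models an output-stationary systolic array, where each cell keeps one output sum while operands march right and down one cell per cycle. That variant was chosen for readability: the TPU’s matrix units are generally described as weight-stationary, holding weights in place while activations flow through, so treat this as an illustration of the idea rather than the TPU’s exact design.

```python
import numpy as np

def systolic_matmul(A, B):
    """Cycle-by-cycle simulation of an output-stationary systolic array.

    Processing element (PE) (i, j) owns C[i, j]. Rows of A enter from the
    left and columns of B from the top, each skewed by one cycle per row or
    column, so A[i, k] and B[k, j] meet at PE (i, j) on cycle i + j + k.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    acc = np.zeros((M, N), dtype=np.result_type(A, B))  # stationary partial sums
    a_reg = np.zeros((M, N), dtype=A.dtype)  # A value currently held by each PE
    b_reg = np.zeros((M, N), dtype=B.dtype)  # B value currently held by each PE
    for t in range(M + N + K - 2):  # enough cycles for the last operands to meet
        # Every PE passes its A value to the right and its B value downward.
        new_a = np.zeros_like(a_reg)
        new_b = np.zeros_like(b_reg)
        new_a[:, 1:] = a_reg[:, :-1]
        new_b[1:, :] = b_reg[:-1, :]
        # Feed skewed inputs at the left and top edges (zeros outside the range).
        for i in range(M):
            k = t - i
            new_a[i, 0] = A[i, k] if 0 <= k < K else 0
        for j in range(N):
            k = t - j
            new_b[0, j] = B[k, j] if 0 <= k < K else 0
        a_reg, b_reg = new_a, new_b
        acc += a_reg * b_reg  # every PE does one multiply-accumulate per cycle
    return acc

rng = np.random.default_rng(0)
A = rng.integers(-3, 4, size=(4, 3))
B = rng.integers(-3, 4, size=(3, 5))
assert np.array_equal(systolic_matmul(A, B), A @ B)
```

Each element of A enters the array once and is then used by all N cells in its row, and each element of B is used by all M cells in its column, which is where the low memory traffic comes from.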
TPU software is built around XLA (Accelerated Linear Algebra), Google’s compiler for tensor computations. XLA performs aggressive operator fusion, memory layout optimization, and hardware-aware scheduling. Code that goes through XLA (JAX and TensorFlow) benefits automatically. PyTorch support via torch_xla is more limited.
TPU v4 pods connect up to 4096 chips via a high-bandwidth custom interconnect. The inter-chip bandwidth within a pod is significantly higher than the InfiniBand fabric used to connect GPU nodes, making AllReduce operations faster.
Multi-GPU Interconnects: NVLink vs PCIe
Moving data between GPUs is a critical bottleneck for large model training. On a standard server, GPUs connect to the host CPU over PCIe x16 at roughly 16–64 GB/s per direction, depending on the PCIe generation (3.0 through 5.0). NVLink is NVIDIA’s direct GPU-to-GPU interconnect: in an H100 DGX node, each GPU gets 900 GB/s of total bidirectional NVLink bandwidth (450 GB/s each way). This is why gradient synchronization within a node is far faster than across nodes.
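For cross-node communication, InfiniBand (200 Gb/s HDR or 400 Gb/s NDR per port) is the standard. Even with a dedicated 400 Gb/s link per GPU, that is 50 GB/s each way, roughly 9x less than the H100’s 450 GB/s over NVLink, which is why parallelism strategies keep the most communication-heavy traffic (such as tensor parallelism) inside a node and minimize the data that must cross node boundaries.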
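A rough sense of why this gap matters comes from the standard bandwidth-only cost model for a ring AllReduce, in which each GPU sends and receives about 2(p - 1)/p times the buffer size. The inputs below are assumptions for illustration: a hypothetical 7B-parameter model with FP16 gradients, an 8-GPU ring, 450 GB/s per direction for H100 NVLink, and 50 GB/s if every hop instead crossed one 400 Gb/s link. Latency, topology, and overlap with compute are ignored.

```python
def ring_allreduce_seconds(num_bytes, num_gpus, link_gb_per_s):
    """Bandwidth-only estimate of a ring AllReduce.

    Each GPU sends (and receives) 2 * (p - 1) / p of the buffer over its link.
    Latency, topology effects, and overlap with compute are ignored.
    """
    p = num_gpus
    bytes_per_gpu = 2 * (p - 1) / p * num_bytes
    return bytes_per_gpu / (link_gb_per_s * 1e9)

grad_bytes = 7e9 * 2  # hypothetical 7B-parameter model with FP16 gradients

nvlink = ring_allreduce_seconds(grad_bytes, num_gpus=8, link_gb_per_s=450)  # H100 NVLink, per direction
ib = ring_allreduce_seconds(grad_bytes, num_gpus=8, link_gb_per_s=50)       # one 400 Gb/s link per GPU
print(f"NVLink: {nvlink * 1e3:.0f} ms   InfiniBand NDR: {ib * 1e3:.0f} ms")
```

Under these assumptions one gradient synchronization takes about 54 ms over NVLink versus about 490 ms over InfiniBand, and that cost recurs on every training step unless it is overlapped with computation.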
Inference vs Training Hardware
Training requires high-throughput matrix multiply, large HBM for storing activations and optimizer state, and fast interconnects for gradient synchronization. GPUs and TPUs both work well.
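Inference has different constraints. The model weights are fixed (no gradients or optimizer state to store), batch sizes are often kept small to limit latency, and the priority is usually low latency rather than peak throughput. Small batches make token-by-token decoding memory-bandwidth-bound, because every generated token requires reading all of the weights from memory. This has pushed hardware in several directions: NVIDIA’s H100 adds a Transformer Engine with FP8 support, Groq’s LPU uses deterministic execution and keeps model weights in on-chip SRAM instead of HBM, and Cerebras’s wafer-scale engine puts tens of gigabytes of SRAM on a single wafer-sized chip.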
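The small-batch problem follows directly from the arithmetic intensity definition. When a weight matrix multiplies a batch of activation vectors during decoding, the weights dominate memory traffic, so intensity grows roughly in proportion to batch size. The 8192 x 8192 FP16 layer below is a hypothetical size, and `layer_intensity` is an illustrative helper:

```python
def layer_intensity(d_in, d_out, batch, bytes_per_elem=2):
    """FLOP/byte for a (batch x d_in) @ (d_in x d_out) matmul in FP16.

    Counts the weight matrix plus input and output activations, each moved once.
    """
    flops = 2 * batch * d_in * d_out
    bytes_moved = bytes_per_elem * (d_in * d_out + batch * d_in + batch * d_out)
    return flops / bytes_moved

RIDGE = 312e12 / 2e12  # A100 FP16 ridge point, ~156 FLOP/byte
for batch in (1, 8, 64, 256):
    ai = layer_intensity(8192, 8192, batch)  # hypothetical 8192 x 8192 layer
    print(f"batch={batch:>3}: {ai:6.1f} FLOP/byte ({'compute' if ai > RIDGE else 'memory'}-bound)")
```

At batch 1 the layer does about 1 FLOP per byte, two orders of magnitude below the A100’s ridge point; it takes a batch of well over 100 before the same layer becomes compute-bound.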
Examples
Roofline Analysis for Matrix Multiply
```python
# A100 specs
hbm_bandwidth_bytes_per_sec = 2e12  # 2 TB/s
peak_flops = 312e12                 # 312 TFLOPS FP16 (Tensor Cores, dense)
ridge_point = peak_flops / hbm_bandwidth_bytes_per_sec  # ~156 FLOP/byte

# Matrix multiply: M=N=K=4096, FP16
M = N = K = 4096
bytes_per_elem = 2                  # FP16
flops = 2 * M * N * K               # one multiply + one add per term: ~137 GFLOPs
# Ideal traffic: each matrix crosses HBM exactly once
bytes_read = (M * K + K * N) * bytes_per_elem  # input matrices A and B
bytes_write = M * N * bytes_per_elem           # output matrix C
total_bytes = bytes_read + bytes_write
arithmetic_intensity = flops / total_bytes     # ~1365 FLOP/byte

print(f"Arithmetic intensity: {arithmetic_intensity:.0f} FLOP/byte")
print(f"Ridge point: {ridge_point:.0f} FLOP/byte")
print("=> COMPUTE BOUND" if arithmetic_intensity > ridge_point else "=> MEMORY BOUND")
```
When to Use TPU vs GPU
| Criterion | GPU (A100/H100) | TPU v4 |
|---|---|---|
| Framework | PyTorch (first-class) | JAX, TensorFlow |
| Custom kernels | CUDA, Triton | Limited (XLA, Pallas) |
| Availability | AWS, Azure, GCP | GCP only |
| LLM training at scale | Good | Excellent (pod bandwidth) |
| Research/prototyping | Excellent | Harder to debug |
Memory Arithmetic for LLM Serving
```python
import math

# How much GPU memory does a 70B-parameter model need?
params = 70e9
bytes_per_param_fp16 = 2
model_size_gb = params * bytes_per_param_fp16 / 1e9
print(f"Model weights: {model_size_gb:.0f} GB")  # 140 GB

# An A100 80GB can't hold it on one GPU, so the weights must be split
# across at least 2 x A100 80GB (e.g. with tensor parallelism)
gpus_needed = math.ceil(model_size_gb / 80)
print(f"Minimum A100 80GB GPUs needed: {gpus_needed}")  # 2
# In practice, use 4 or 8 to leave headroom for the KV cache and activations
```
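The KV cache behind that last comment can be estimated the same way. The formula counts one key and one value vector per layer, per KV head, per token. The configuration is an assumption, roughly the shape of a 70B model that uses grouped-query attention (80 layers, 8 KV heads, head dimension 128), and `kv_cache_bytes` is an illustrative helper:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    """Memory for cached keys and values across all layers (the leading 2 is K + V)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Assumed config, roughly a 70B model with grouped-query attention (FP16 cache)
cfg = dict(num_layers=80, num_kv_heads=8, head_dim=128)

per_token = kv_cache_bytes(**cfg, seq_len=1, batch=1)
total = kv_cache_bytes(**cfg, seq_len=4096, batch=8)
print(f"KV cache per token: {per_token / 1e6:.2f} MB")     # 0.33 MB
print(f"4096 tokens x 8 sequences: {total / 1e9:.1f} GB")  # 10.7 GB
```

If each of an assumed 64 attention heads kept its own keys and values instead of sharing 8 KV heads, the cache would be 8x larger, which is why KV cache capacity often decides how many GPUs a deployment needs.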
The hardware determines what is possible. Understanding the architecture tells you whether your bottleneck is compute, memory bandwidth, or interconnect - and that tells you where to invest your optimization effort.
Read Next: Parallelism at Scale | Inside JAX & XLA