

GPT-3 has 175 billion parameters. In FP16 (2 bytes each), that's 350 GB - about five A100-80GB GPUs just to store the weights. Training requires gradients and optimizer states on top of that: roughly 2 - 3 TB total. No single machine has this memory.

The solution isn’t better hardware. A100s already push the limits of what’s physically possible with current chip technology. The solution is distributed computation: split the model across many GPUs, and coordinate them so they collectively behave as a single model.

Building and training these “sharded” transformers is one of the hardest engineering problems in machine learning. The mathematics involves linear algebra (how to split matrices), communication theory (how to synchronize across machines), and numerical methods (how to do it without losing precision). This post covers all three.


The Memory Problem, Precisely

Before discussing solutions, let’s quantify the problem.

For a model with $N$ parameters, the memory required for training in full precision is:

| Component | Memory | Why |
|---|---|---|
| Model weights | $4N$ bytes | FP32, 4 bytes each |
| Gradients | $4N$ bytes | One FP32 gradient per parameter |
| Adam first moment $m$ | $4N$ bytes | Exponential moving average of gradients |
| Adam second moment $v$ | $4N$ bytes | Exponential moving average of squared gradients |
| Total | $16N$ bytes | |

For GPT-3 with $N = 175 \times 10^9$: $16 \times 175 \times 10^9 = 2.8$ TB.

Training in mixed precision (FP16/BF16 for activations and gradients, FP32 for optimizer states):

| Component | Memory |
|---|---|
| FP16 weights | $2N$ bytes |
| FP16 gradients | $2N$ bytes |
| FP32 master weights | $4N$ bytes |
| FP32 Adam states | $8N$ bytes |
| Total | $16N$ bytes |

Same total. Mixed precision saves memory on activations (not shown) but not on optimizer states.

For 175B parameters in mixed precision, that is still 2.8 TB, so even a perfect split needs at least 35 A100-80GB GPUs (2.8 TB / 80 GB), before accounting for activations.
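Both tables reduce to per-parameter arithmetic. A minimal sketch, assuming decimal units (1 GB = $10^9$ bytes, matching the 350 GB figure) and ignoring activations; the dictionary and function names are illustrative only:

```python
import math

# Bytes per parameter, taken from the two tables above.
FP32_RECIPE = {"weights": 4, "grads": 4, "adam_m": 4, "adam_v": 4}
MIXED_RECIPE = {"fp16_weights": 2, "fp16_grads": 2, "fp32_master": 4, "fp32_adam_m_v": 8}

def training_state_bytes(n_params: float, recipe: dict) -> float:
    """Weights + gradients + optimizer state for n_params; activations excluded."""
    return n_params * sum(recipe.values())

def min_gpus(total_bytes: float, gpu_bytes: float = 80e9) -> int:
    """Lower bound on GPU count if the state could be split perfectly."""
    return math.ceil(total_bytes / gpu_bytes)

N = 175e9
fp32 = training_state_bytes(N, FP32_RECIPE)
mixed = training_state_bytes(N, MIXED_RECIPE)
print(f"FP32: {fp32 / 1e12:.1f} TB, mixed: {mixed / 1e12:.1f} TB, "
      f"A100-80GB lower bound: {min_gpus(mixed)}")
```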


Data Parallelism

The simplest form of distributed training: each GPU holds a complete copy of the model, but different GPUs process different batches of data.

Procedure:

  1. Split the global batch of data into $P$ mini-batches.
  2. GPU $i$ processes mini-batch $i$ using its local model copy.
  3. Each GPU computes its local gradient $g_i$.
  4. All-reduce the gradients: $\bar{g} = \frac{1}{P}\sum_{i=1}^P g_i$. Every GPU receives the averaged gradient.
  5. Each GPU updates its model: $\theta \leftarrow \theta - \eta\bar{g}$.

After the all-reduce, all GPUs have identical parameters. The effect is mathematically equivalent to training on a batch $P$ times larger than a single GPU’s batch.
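A toy check of that equivalence, assuming a scalar linear model with mean-squared-error loss and equal-size shards (with unequal shards, the plain average of local means no longer equals the global mean). All names are hypothetical; no framework API is involved:

```python
def grad_mse(w: float, xs, ys) -> float:
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.0, 7.2, 7.9]
w, lr, P = 0.3, 0.01, 4

shard = len(xs) // P                       # equal-size mini-batch per "GPU"
local = [grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
         for i in range(P)]                # step 3: local gradients g_i
g_bar = sum(local) / P                     # step 4: all-reduce (sum), then divide by P
w_new = w - lr * g_bar                     # step 5: identical update on every replica

g_full = grad_mse(w, xs, ys)               # gradient of one big batch on one device
print(g_bar, g_full)
```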

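Limitation: the model must fit on a single GPU. With $16N$ bytes of training state, an A100-80GB holds at most about $N \approx 5 \times 10^9$ parameters before activations, so even a 7B model (112 GB) needs the memory-saving techniques covered below. Pure DP doesn't help for 175B.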

Communication cost: in FP32 the gradients occupy $4N$ bytes. A ring all-reduce sends $2 \times \frac{N(P-1)}{P} \approx 2N$ values per GPU per step (for large $P$), i.e. about $8N$ bytes in FP32. For $N = 7 \times 10^9$: 28 GB of gradients, so each GPU sends roughly 56 GB per step (half that if gradients are communicated in BF16). This is why NVLink (within a node, 600 GB/s on A100) and InfiniBand (across nodes, 200 - 400 Gb/s, i.e. 25 - 50 GB/s, per adapter) matter.
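The same traffic estimate as a helper, assuming a ring all-reduce and decimal gigabytes. `link_gb_per_s` is a made-up effective per-GPU bandwidth, used only to turn bytes into a rough time:

```python
def ring_allreduce_bytes_per_gpu(n_values: float, bytes_per_value: int, p: int) -> float:
    """Bytes each GPU sends in a ring all-reduce of n_values elements over p GPUs."""
    return 2 * n_values * (p - 1) / p * bytes_per_value

N = 7e9                      # 7B parameters
grad_bytes = N * 4           # FP32 gradients: 28 GB
link_gb_per_s = 50.0         # hypothetical effective bandwidth, illustration only

for p in (8, 64, 1024):
    sent = ring_allreduce_bytes_per_gpu(N, 4, p)
    print(f"P={p:5d}: {sent / 1e9:5.1f} GB sent per GPU, "
          f"~{sent / 1e9 / link_gb_per_s:.2f} s at {link_gb_per_s} GB/s")
```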


Tensor Parallelism

When the model is too large for one GPU, you split individual layers across multiple GPUs. This is tensor parallelism (TP), a form of intra-layer model parallelism popularized by Megatron-LM (Shoeybi et al., 2019).

Consider a linear layer: $y = xW$ where $x \in \mathbb{R}^{b \times d}$ (batch × hidden) and $W \in \mathbb{R}^{d \times k}$.

Column splitting: partition $W$ along its columns across $P$ GPUs:

$$W = [W_1 \mid W_2 \mid \cdots \mid W_P], \quad W_i \in \mathbb{R}^{d \times (k/P)}.$$

Each GPU $i$ computes $y_i = xW_i \in \mathbb{R}^{b \times (k/P)}$. To get the full output $y = [y_1 \mid y_2 \mid \cdots \mid y_P]$, all-gather across GPUs.

Row splitting: partition $W$ along its rows:

$$W = \begin{pmatrix} W_1 \\ W_2 \\ \vdots \\ W_P \end{pmatrix}, \quad W_i \in \mathbb{R}^{(d/P) \times k}.$$

Each GPU needs the corresponding shard of $x$, namely $x_i \in \mathbb{R}^{b \times (d/P)}$, and computes $y_i = x_i W_i \in \mathbb{R}^{b \times k}$. The full output is $y = \sum_{i=1}^P y_i$ - an all-reduce.
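Both splits can be checked numerically. A small pure-Python sketch with toy integer matrices and $P = 2$ (the helper names are made up) confirming that concatenating column-split outputs and summing row-split outputs both reproduce $y = xW$:

```python
def matmul(a, b):
    """Plain-Python matrix product; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def col_block(m, start, stop):
    """Columns start..stop-1 of m."""
    return [row[start:stop] for row in m]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]                 # b=2, d=4
W = [[1, 0, 2, -1, 3, 1],
     [0, 1, 1, 2, -2, 0],
     [1, 1, 0, 0, 1, 2],
     [2, -1, 1, 1, 0, 1]]          # d=4, k=6
P = 2
y = matmul(x, W)

# Column split: GPU i holds W_i (k/P columns) and the full x; outputs are concatenated (all-gather).
cols_per_gpu = len(W[0]) // P
y_parts = [matmul(x, col_block(W, i * cols_per_gpu, (i + 1) * cols_per_gpu)) for i in range(P)]
y_col = [sum((part[r] for part in y_parts), []) for r in range(len(x))]

# Row split: GPU i holds W_i (d/P rows) and x_i (d/P columns of x); outputs are summed (all-reduce).
rows_per_gpu = len(W) // P
partials = [matmul(col_block(x, i * rows_per_gpu, (i + 1) * rows_per_gpu),
                   W[i * rows_per_gpu:(i + 1) * rows_per_gpu]) for i in range(P)]
y_row = [[sum(p[r][c] for p in partials) for c in range(len(W[0]))] for r in range(len(x))]

print(y_col == y, y_row == y)
```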

For attention layers: split by heads. If the model has $H$ heads and $P$ GPUs, each GPU handles $H/P$ heads. This is a natural split because heads are independent: head $i$’s computation doesn’t depend on head $j$’s.

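Communication pattern: TP requires all-reduces inside every transformer layer. In the Megatron-LM layout, the attention block and the MLP block each pair a column-split matmul with a row-split matmul, so the forward pass needs two all-reduces per layer ($2L$ for $L$ layers) and the backward pass another two per layer, each transmitting $b \times d$ activations (with $b$ counting all tokens in the batch). This per-layer communication means TP works best within a single node where GPUs are connected by NVLink ($\sim 600$ GB/s), not across nodes where bandwidth is limited.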
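Megatron-LM (Shoeybi et al., 2019) avoids communicating between the two matmuls of an MLP block by splitting the first weight matrix by columns and the second by rows: the elementwise nonlinearity can then be applied to each column shard locally, and one sum (the all-reduce) finishes the block. A toy sketch with ReLU standing in for GeLU; any elementwise function works, whereas splitting the first matrix by rows would not, because the nonlinearity of a sum is not the sum of nonlinearities:

```python
def matmul(a, b):
    """Plain-Python matrix product; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    """Elementwise nonlinearity standing in for GeLU."""
    return [[max(0, v) for v in row] for row in m]

x = [[1, -2, 3],
     [0, 4, -1]]                         # b=2, d=3
A = [[1, -1, 2, 0],
     [0, 1, -1, 2],
     [1, 0, 1, -1]]                      # d=3 -> ffn=4
B = [[1, 2, 0],
     [-1, 0, 1],
     [2, 1, 1],
     [0, -1, 3]]                         # ffn=4 -> d=3
P = 2
reference = matmul(relu(matmul(x, A)), B)

f = len(A[0]) // P                       # ffn columns per GPU
partials = []
for i in range(P):
    A_i = [row[i * f:(i + 1) * f] for row in A]   # column shard of A
    B_i = B[i * f:(i + 1) * f]                    # matching row shard of B
    h_i = relu(matmul(x, A_i))                    # local, no communication needed
    partials.append(matmul(h_i, B_i))

# The block's only communication: one all-reduce (sum) of the partial outputs.
out = [[sum(p[r][c] for p in partials) for c in range(len(B[0]))] for r in range(len(x))]
print(out == reference)
```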

In practice, a common choice is TP = 8 within a single 8-GPU node, so every TP all-reduce stays on NVLink.


Pipeline Parallelism

Instead of splitting within layers (TP), split across layers (PP). Assign a consecutive block of transformer layers to each GPU:

  • GPU 0: layers $1, \ldots, L/P$
  • GPU 1: layers $L/P + 1, \ldots, 2L/P$
  • $\vdots$
  • GPU $P-1$: layers $(P-1)L/P + 1, \ldots, L$

Each GPU processes its block and sends activations to the next GPU via the “pipeline.”

The bubble problem: with a single micro-batch, GPU 0 computes, then GPU 1, then GPU 2, and so on. During GPU 1’s forward pass, GPU 0 is idle. During GPU 0’s backward pass, GPU $P-1$ is idle. The total idle time is the “bubble.”

Bubble fraction (fraction of pipeline time wasted):

$$\text{bubble} = \frac{P-1}{m + P - 1},$$

where $m$ is the number of micro-batches the global batch is split into. For $m = 1$: bubble = $(P-1)/P$ - as $P$ increases, almost all time is wasted. For large $m$: bubble $\approx (P-1)/m \to 0$.
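The formula in code, using exact fractions; $P = 8$ stages is an arbitrary illustrative choice:

```python
from fractions import Fraction

def bubble_fraction(p: int, m: int) -> Fraction:
    """Idle fraction (P-1)/(m+P-1) for P pipeline stages and m micro-batches."""
    return Fraction(p - 1, m + p - 1)

for m in (1, 4, 16, 64, 256):
    print(f"P=8, m={m:3d}: bubble = {float(bubble_fraction(8, m)):.3f}")
```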

GPipe (Huang et al., 2019) introduced the micro-batch pipelining strategy: split the batch into $m$ micro-batches. Feed them into the pipeline sequentially. Overlap the computation of later micro-batches with the idle time of earlier ones:

$$\text{Time} \approx (m + P - 1) \times t_{\text{micro}},$$

where $t_{\text{micro}}$ is the time per micro-batch per stage. For $m \gg P$, efficiency approaches 100%.
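A forward-only GPipe-style schedule can be laid out as a grid of time slots by stages. A minimal sketch, assuming every stage takes exactly one slot of $t_{\text{micro}}$ per micro-batch and ignoring the backward pass; counting idle cells reproduces both the $m + P - 1$ slot count and the bubble fraction:

```python
def gpipe_forward_schedule(p: int, m: int):
    """grid[t][s] = micro-batch run by stage s in time slot t, or None if idle.

    Micro-batch j enters stage 0 at slot j and reaches stage s at slot j + s.
    """
    grid = [[None] * p for _ in range(m + p - 1)]
    for j in range(m):
        for s in range(p):
            grid[j + s][s] = j
    return grid

p, m = 4, 6
grid = gpipe_forward_schedule(p, m)
for t, row in enumerate(grid):
    print(f"t={t}:", " ".join("." if mb is None else str(mb) for mb in row))

idle = sum(mb is None for row in grid for mb in row)
print("slots:", len(grid), "| idle fraction:", idle / (len(grid) * p))
```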

Interleaved scheduling (Narayanan et al., 2021): each GPU holds multiple non-consecutive layer blocks ("virtual stages") instead of one contiguous block. With $v$ blocks per GPU, the bubble shrinks by roughly a factor of $v$, at the cost of $v$ times more point-to-point communication and a more complex schedule.

Communication cost: only activations at stage boundaries - $b \times d$ values between consecutive GPUs per micro-batch in the forward pass, plus the matching gradients in the backward pass. This is far less traffic than TP's per-layer all-reduces, making PP suitable for communication across nodes.


ZeRO: Zero Redundancy Optimizer

In standard data parallelism, every GPU stores the full optimizer state. For $P$ GPUs, you’re replicating the 16N bytes of per-parameter state $P$ times. ZeRO (Rajbhandari et al., 2020) eliminates this redundancy by partitioning across GPUs.

ZeRO Stage 1: Partition optimizer states across $P$ GPUs. Each GPU stores $1/P$ of the Adam states (first and second moments). Memory per GPU:

$$\text{Weights}: 4N \quad \text{Grads}: 4N \quad \text{Optimizer}: 8N/P \quad \text{Total}: 8N + 8N/P.$$

At $P = 64$: roughly halved.

ZeRO Stage 2: Also partition gradients. Memory per GPU:

$$\text{Weights}: 4N \quad \text{Grads}: 4N/P \quad \text{Optimizer}: 8N/P \quad \text{Total}: 4N + 12N/P.$$

At large $P$: dominated by $4N$ - just the weights.

ZeRO Stage 3: Also partition the weights. Memory per GPU:

$$\text{Weights}: 4N/P \quad \text{Grads}: 4N/P \quad \text{Optimizer}: 8N/P \quad \text{Total}: 16N/P.$$

At $P = 64$ GPUs, ZeRO-3 reduces memory per GPU by $64\times$ compared to standard DP. For 175B parameters in FP32: $2.8$ TB total becomes about $44$ GB per GPU, which fits on an A100-80GB (before activations).
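The four memory formulas (standard DP plus the three ZeRO stages) as one function. A sketch using the same FP32 accounting as above ($4N + 4N + 8N$ bytes), where `stage=0` means plain data parallelism; the function name is illustrative:

```python
def zero_bytes_per_gpu(n_params: float, p: int, stage: int) -> float:
    """Per-GPU training state (bytes) under ZeRO stage 0-3, FP32 accounting, no activations."""
    weights, grads, optimizer = 4 * n_params, 4 * n_params, 8 * n_params
    if stage >= 1:
        optimizer /= p   # ZeRO-1: shard Adam m and v
    if stage >= 2:
        grads /= p       # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= p     # ZeRO-3: also shard the weights
    return weights + grads + optimizer

for stage in range(4):
    gb = zero_bytes_per_gpu(175e9, 64, stage) / 1e9
    print(f"175B, P=64, ZeRO-{stage}: {gb:,.1f} GB per GPU")
```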

Communication overhead of ZeRO-3: during the forward pass, each GPU needs the full weights of the current layer even though it stores only $1/P$ of them. So ZeRO-3 all-gathers each layer's weight shards just before use in the forward and backward passes, then discards the full copy. The ZeRO paper puts the total communication volume at about $1.5\times$ that of standard DP; when overlapped with compute, the overhead is typically 10 - 30%.

DeepSpeed (Microsoft), from the ZeRO authors, is the reference implementation of ZeRO. FSDP (Fully Sharded Data Parallel) is PyTorch's native equivalent of ZeRO-3.


3D Parallelism

For the largest models, you combine all three strategies:

$$\text{GPUs} = P_{\text{TP}} \times P_{\text{PP}} \times P_{\text{DP}}.$$

An illustrative configuration at GPT-3 scale:

  • $P_{\text{TP}} = 8$: tensor parallelism within each 8-GPU node (NVLink)
  • $P_{\text{PP}} = 16$: pipeline parallelism across 16 nodes
  • $P_{\text{DP}} = 64$: data parallelism across 64 model replicas

Total: $8 \times 16 \times 64 = 8{,}192$ GPUs.
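Each GPU needs to know which TP group, pipeline stage, and DP replica it belongs to. One possible mapping, assuming (as is common, but not guaranteed) that consecutive global ranks share a node, so making TP the fastest-varying index keeps each TP group of 8 on one NVLink-connected node. The function names are hypothetical:

```python
TP, PP, DP = 8, 16, 64   # sizes from the example configuration

def coords(rank: int):
    """Global rank -> (dp, pp, tp), with tp varying fastest (a layout assumed here)."""
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return dp, pp, tp

def rank_of(dp: int, pp: int, tp: int) -> int:
    """Inverse of coords()."""
    return (dp * PP + pp) * TP + tp

world = TP * PP * DP
tp_group = [rank_of(0, 3, t) for t in range(TP)]   # one layer block's TP shards: ranks 24..31
pp_group = [rank_of(0, s, 0) for s in range(PP)]   # the 16 stages of replica 0 (TP rank 0)
dp_group = [rank_of(d, 3, 0) for d in range(DP)]   # the same shard in all 64 replicas
print(world, tp_group, pp_group[:4], dp_group[:4])
```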

Each "layer block" in the pipeline is split across 8 GPUs via tensor parallelism. Multiple complete model copies run in parallel for data parallelism. The three strategies split along different axes, so they compose: each GPU belongs to exactly one TP group, one pipeline stage, and one DP replica.

Choosing the right configuration is empirical. The optimal split depends on model architecture (number of layers, hidden dimension), cluster topology (inter/intra-node bandwidth), batch size, and sequence length.


Activation Checkpointing

Training requires storing intermediate activations from the forward pass for use during the backward pass. For a transformer with batch size $b$, sequence length $n$, hidden dimension $d$, and $L$ layers, this is $O(b \times n \times d \times L)$ - which can easily exceed the model weights in memory for long sequences or large batches.

Activation checkpointing (also called gradient checkpointing): don’t store all intermediate activations. Instead, only store activations at “checkpoint” boundaries (e.g., every $k$ layers). During the backward pass, recompute the non-stored activations by re-running the corresponding forward pass segment.

Memory-compute tradeoff: with $L$ layers and checkpoints every $\sqrt{L}$ layers:

  • Memory: $O(\sqrt{L} \times n \times d)$ instead of $O(L \times n \times d)$
  • Extra compute: one extra forward pass per step (each segment is re-run once during backward); since the backward pass costs roughly twice the forward, that is about $33\%$ more total compute
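A back-of-the-envelope version of that tradeoff. The sketch counts memory in units of "one layer's activations" and assumes one stored input per segment plus one segment re-materialized at a time during backward; the 2× backward-to-forward cost ratio is a common rule of thumb, not a measurement:

```python
import math

def peak_stored_layers(num_layers: int, k: int) -> int:
    """Peak activations held with a checkpoint every k layers, in units of one layer's activations.

    Assumes one saved segment input per segment, plus one segment (k layers)
    re-materialized at a time during the backward pass.
    """
    return math.ceil(num_layers / k) + k

L = 80
k = round(math.sqrt(L))  # checkpoint every ~sqrt(L) layers
print("no checkpointing:", L, "layers' worth")
print(f"checkpoint every {k} layers:", peak_stored_layers(L, k), "layers' worth")

# Compute in forward-pass units: forward = 1, backward ~ 2 (rule of thumb), recompute = +1.
baseline, with_recompute = 1 + 2, 1 + 2 + 1
print(f"extra compute: {with_recompute / baseline - 1:.0%}")
```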

For a 70B model with 80 layers and 4096 context: activation memory without checkpointing can exceed 100 GB. With checkpointing: under 20 GB. The extra 33% compute is almost always worth it at this scale.


Mixed Precision Training

Training in FP32 is safe but memory-intensive. FP16 saves memory but has a narrow range: values below $\sim 6 \times 10^{-5}$ (the smallest normal FP16 number) lose precision, values below $\sim 6 \times 10^{-8}$ underflow to zero, and values above $65{,}504$ overflow to infinity.
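Python's `struct` module can pack IEEE 754 half-precision values (format `'e'`), which makes these limits easy to observe on a CPU. A small sketch; mapping `OverflowError` to infinity is an assumption made here to mimic a round-to-nearest hardware cast:

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision using struct's 'e' format."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond the FP16 range; treat them as overflow to +/-inf
        return math.copysign(math.inf, x)

for v in (65504.0, 1e5, 6.1e-5, 1e-5, 1e-8):
    print(f"{v:>10g} -> {to_fp16(v)!r}")
```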

Mixed precision training (Micikevicius et al., 2018) keeps the benefits of both:

  1. Maintain a master copy of weights in FP32 for the optimizer update.
  2. At each forward pass, cast weights to FP16. Run forward and backward passes in FP16.
  3. Scale the loss by a large factor $S$ before backward. This prevents FP16 gradients from underflowing.
  4. Convert the gradients to FP32 and divide them by $S$ before the optimizer step.
  5. Update the FP32 master weights with the FP32 gradients.

Loss scaling: if typical gradient values are around $10^{-4}$, multiply the loss by $S = 2^{16}$, making gradients $\sim 6.5$ - well within FP16 range. After the backward pass, divide by $2^{16}$ to recover true gradients. Dynamic loss scaling: if overflow is detected (NaN/Inf gradients), halve $S$ and skip the update. If there has been no overflow for a fixed number of steps (e.g. 2,000), double $S$.
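The loss-scaling loop, sketched with Python's `struct` module emulating FP16 casts (format `'e'`). `DynamicLossScaler` is a hypothetical class, not a library API; the starting scale of $2^{16}$ follows the text, and the overflow check, halving, and periodic doubling implement the dynamic rule just described. Backward on the scaled loss is emulated by multiplying known "true" gradients by $S$ before the FP16 cast:

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Emulate an FP16 cast: IEEE half precision via struct 'e'; overflow -> +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class DynamicLossScaler:
    """Hypothetical minimal dynamic loss scaler (not a library API)."""

    def __init__(self, scale: float = 2.0**16, growth_interval: int = 2000):
        self.scale = scale
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def unscale(self, fp16_grads):
        """Return unscaled gradients (g / S), or None if the step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in fp16_grads):
            self.scale /= 2          # overflow: halve S and skip this update
            self.clean_steps = 0
            return None
        grads = [g / self.scale for g in fp16_grads]  # Python floats (FP64) stand in for FP32
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= 2          # no overflow for a while: try a larger S
            self.clean_steps = 0
        return grads

true_grads = [1e-8, 3e-6, 2e-3]
print("unscaled FP16:", [to_fp16(g) for g in true_grads])   # 1e-8 is lost

scaler = DynamicLossScaler()
fp16_grads = [to_fp16(g * scaler.scale) for g in true_grads]  # emulates backward on S * loss
recovered = scaler.unscale(fp16_grads)
print("recovered:", recovered)

skipped = scaler.unscale([to_fp16(10.0 * scaler.scale)])      # 655,360 overflows FP16
print("overflow step:", skipped, "| new scale:", scaler.scale)
```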

BF16 (Brain Float 16): 1 sign bit, 8 exponent bits, 7 mantissa bits (vs. FP16's 1 sign, 5 exponent, 10 mantissa). BF16 has the same dynamic range as FP32 (8 exponent bits), so it largely avoids the overflow and underflow problems that plague FP16. It has less mantissa precision than FP16, but empirically this barely matters for training. Originally developed for Google's TPUs; supported on NVIDIA GPUs since Ampere (A100).
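To see the range-versus-precision trade, BF16 rounding can be emulated by keeping the top 16 bits of an FP32 bit pattern. A minimal sketch with round-to-nearest-even; it does not handle NaN inputs or values outside FP32 range:

```python
import struct

def to_bf16(x: float) -> float:
    """Round x to bfloat16 by keeping the top 16 bits of its FP32 pattern.

    Round-to-nearest-even; NaN and out-of-FP32-range inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round half to even on bit 16
    bits &= 0xFFFF0000                     # sign + 8 exponent bits + 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for v in (1e5, 1e-30, 1.0 + 2**-9, 1.0 + 2**-7):
    print(f"{v!r:>24} -> {to_bf16(v)!r}")
```

Here $10^5$, which overflows FP16, survives as 99,840, and $10^{-30}$ stays nonzero, while $1 + 2^{-9}$ collapses to 1.0: FP32's range, but only about two to three significant decimal digits.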

Modern large model training commonly uses BF16 + FP32 master weights, without dynamic loss scaling.


Communication Primitives

Distributed training is built on a small set of collective communication operations.

All-reduce: each GPU $i$ has a tensor $x_i$. After all-reduce, every GPU has $\sum_i x_i$. Total data transmitted per GPU: $2N(P-1)/P \approx 2N$ values. Used in data parallelism (gradient averaging) and tensor parallelism (summing partial activations).

All-gather: each GPU $i$ has a shard $x_i$ of some partitioned tensor. After all-gather, every GPU has the full tensor $[x_1, x_2, \ldots, x_P]$. Total data transmitted per GPU: $N(P-1)/P \approx N$ values. Used in ZeRO-3 (gathering weight shards for computation).

Reduce-scatter: each GPU $i$ has $x_i$. The result: GPU $i$ receives $(\sum_j x_j)_i$ - the $i$-th shard of the summed tensor. Total data: $N(P-1)/P \approx N$ values. Used in ZeRO-3 (accumulating gradient shards) and as one half of a ring all-reduce.

Ring all-reduce = reduce-scatter followed by all-gather. Each phase takes $P-1$ steps in which every GPU sends $N/P$ values, so $2(P-1)$ steps in all. Very bandwidth-efficient: each GPU transmits $2N(P-1)/P < 2N$ values in total, nearly independent of $P$.
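A pure-Python simulation of the ring algorithm, assuming $N$ divisible by $P$ and synchronous steps (all messages of a step are collected before any is applied). It checks that every rank ends with the full sum and counts the values each rank sends:

```python
def ring_allreduce(data):
    """Simulate a ring all-reduce (reduce-scatter, then all-gather) over P ranks.

    data[r][c] is chunk c (a list of numbers) of rank r's vector; there are P chunks.
    Returns (buffers after the all-reduce, values sent by each rank).
    """
    P = len(data)
    buf = [[list(chunk) for chunk in rank] for rank in data]
    sent = [0] * P

    # Reduce-scatter: P-1 steps; rank r sends chunk (r - s) % P to rank r+1, which adds it.
    for s in range(P - 1):
        msgs = [(r, (r - s) % P, list(buf[r][(r - s) % P])) for r in range(P)]
        for r, c, payload in msgs:
            dst = (r + 1) % P
            buf[dst][c] = [a + b for a, b in zip(buf[dst][c], payload)]
            sent[r] += len(payload)
    # Rank r now holds the fully reduced chunk (r + 1) % P.

    # All-gather: P-1 steps; rank r forwards chunk (r + 1 - s) % P, which rank r+1 overwrites.
    for s in range(P - 1):
        msgs = [(r, (r + 1 - s) % P, list(buf[r][(r + 1 - s) % P])) for r in range(P)]
        for r, c, payload in msgs:
            buf[(r + 1) % P][c] = payload
            sent[r] += len(payload)
    return buf, sent

P, chunk = 4, 3                      # N = P * chunk = 12 values per rank
data = [[[100 * r + 10 * c + i for i in range(chunk)] for c in range(P)] for r in range(P)]
result, sent = ring_allreduce(data)
expected = [[sum(data[r][c][i] for r in range(P)) for i in range(chunk)] for c in range(P)]
print("all ranks hold the sum:", all(rank == expected for rank in result))
print("values sent per rank:", sent, "| 2N(P-1)/P =", 2 * P * chunk * (P - 1) // P)
```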

Bandwidth vs. latency: all-reduce is also latency-sensitive. No GPU can proceed until all have synchronized, and a ring all-reduce takes $2(P-1)$ sequential steps, so its latency term grows with $P$. At the scale of thousands of GPUs, even microseconds of latency per step accumulate. High-bandwidth, low-latency interconnects (NVLink within a node, InfiniBand between nodes) are essential.

Discomfort check. These parallelism strategies interact in complex ways that theory doesn’t fully capture. TP requires synchronization at every layer; PP requires careful pipeline scheduling; ZeRO-3 adds communication during weight gather. The “correct” configuration for a given model and cluster is highly empirical - determined by profiling, not analysis. At frontier scale, the engineering of the training run - parallelism strategy, micro-batch size, gradient checkpointing, communication overlap - can make a 3 - 5× difference in training throughput. This is one reason that training frontier models is so hard to replicate even with equivalent hardware.


The Software Stack

You don’t need to implement any of this from scratch. The standard libraries:

  • PyTorch DDP (torch.nn.parallel.DistributedDataParallel): data parallelism. Well-tested, easy to use.
  • FSDP (torch.distributed.fsdp): PyTorch’s native ZeRO-3. As of PyTorch 2.x, the recommended tool for large model training.
  • DeepSpeed: Microsoft’s library implementing ZeRO-1/2/3, with many optimizations. Standard for training models up to 100B parameters on many clusters.
  • Megatron-LM: NVIDIA’s framework for TP + PP + DP on large models. Used to train the Megatron-Turing NLG model and many others.
  • Alpa: automatic parallelism - searches, with a cost model, for an efficient TP/PP/DP split for a given model and cluster configuration.

For most practitioners: use FSDP or DeepSpeed ZeRO-3 for models that don’t fit on a single GPU. Add pipeline parallelism (Megatron-style) for models in the 70B+ range across multiple nodes.


Summary

| Strategy | What’s Split | Communication | Best For |
|---|---|---|---|
| Data Parallelism | Data (not model) | All-reduce gradients (~$2N$ values per GPU per step) | Model fits on 1 GPU |
| Tensor Parallelism | Individual layers (columns/rows) | All-reduces inside every layer | Within a node (NVLink) |
| Pipeline Parallelism | Layers (consecutive blocks) | Activations at stage boundaries | Across nodes; latency-tolerant |
| ZeRO-1 | Optimizer states | Reduce-scatter gradients, all-gather updated weights (same volume as DP) | Removing redundant optimizer state in DP |
| ZeRO-3 | Weights + grads + optimizer | All-gather weights on the fly | Very large models |
| 3D Parallelism | All of the above | Combined | Frontier-scale training |

| Technique | Purpose | Tradeoff |
|---|---|---|
| Mixed precision (BF16) | Halve activation and working-copy memory, keep FP32 range | Slight mantissa precision loss |
| Loss scaling | Prevent FP16 gradient underflow | Extra bookkeeping |
| Activation checkpointing | Reduce activation memory (e.g. from $O(L)$ to $O(\sqrt{L})$ layers' worth) | $\sim 33\%$ extra compute |
| Continuous batching | GPU utilization in inference | Scheduling complexity |
| PagedAttention | KV cache memory efficiency | Implementation complexity |

The barrier to training large language models is not primarily mathematical - the attention mechanism, the transformer architecture, the optimization algorithms are all well-understood. The barrier is engineering: managing memory, minimizing communication, overlapping compute and I/O, handling hardware failures at scale.

Understanding the concepts in this post - why you need tensor parallelism, what ZeRO-3 actually does, why pipeline bubbles cost you throughput - puts you in a position to read the systems literature directly. The papers (Megatron, DeepSpeed, vLLM, FSDP) are dense but assume exactly this background.


You have now seen the full ML systems stack from fine-tuning (LoRA) through alignment (RLHF, DPO) through inference (KV caching) to distributed training (sharding). The primary literature is directly accessible from here: Megatron-LM, DeepSpeed ZeRO, vLLM, and the original LoRA, InstructGPT, and DPO papers are all self-contained given this background.

