Building Sharded Transformers
A single GPU cannot hold, let alone train, the largest modern language models. GPT-3 has 175B parameters, so in fp16 its weights alone occupy 350 GB. That is more than four A100-80GB cards can hold, before accounting for activations, gradients, and optimiser states. Distributed training across many devices therefore requires careful partitioning of computation and data, and the choice of strategy determines whether training is communication-bound, memory-bound, or compute-bound.
Why Single-GPU Training Fails at Scale
The memory cost of training a model with $N$ parameters is roughly:
| Component | Bytes (fp16 params and gradients, fp32 optimiser states) |
|---|---|
| Parameters (fp16) | $2N$ |
| Gradients (fp16) | $2N$ |
| Master copy of parameters (fp32) | $4N$ |
| Adam $m_t$ (fp32) | $4N$ |
| Adam $v_t$ (fp32) | $4N$ |
| Total | $16N$ |
For $N = 175 \times 10^9$, this is approximately 2.8 TB, far beyond any single accelerator. On top of these model states, activations for a large batch can add hundreds of gigabytes more.
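The accounting can be sketched in a few lines of Python, which also count how many 80 GB devices the model state would need if it could be split perfectly. This assumes the standard mixed-precision recipe, in which an fp32 master copy of the weights is kept alongside Adam's two fp32 moments (that copy is part of the $16N$ total). It also treats 80 GB as $80 \times 10^9$ bytes. The names `model_state_bytes` and `min_devices` are illustrative.

```python
# Sketch: model-state memory for mixed-precision Adam training.
# Assumptions: 2-byte fp16 weights/gradients, 4-byte fp32 master weights and
# Adam moments, and "80 GB" taken as 80 * 10**9 bytes.

BYTES_PER_PARAM = {
    "fp16 parameters": 2,
    "fp16 gradients": 2,
    "fp32 master parameters": 4,
    "fp32 Adam m_t": 4,
    "fp32 Adam v_t": 4,
}


def model_state_bytes(n_params: int) -> dict:
    """Bytes used by each component for a model with n_params parameters."""
    return {name: b * n_params for name, b in BYTES_PER_PARAM.items()}


def min_devices(total_bytes: int, device_bytes: int = 80 * 10**9) -> int:
    """Smallest device count whose combined memory covers total_bytes."""
    return -(-total_bytes // device_bytes)  # integer ceiling division


if __name__ == "__main__":
    n = 175 * 10**9  # GPT-3 scale
    states = model_state_bytes(n)
    total = sum(states.values())
    print(f"total model state: {total / 1e12:.1f} TB")
    print("80 GB cards for fp16 weights alone:", min_devices(states["fp16 parameters"]))
    print("80 GB cards for all model state:", min_devices(total))
```

For $N = 175 \times 10^9$ this gives 2.8 TB of model state: five cards for the fp16 weights alone and 35 for the full state, before any activations.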
Data Parallelism
The simplest strategy is to replicate the full model on $k$ GPUs and split each batch of size $B$ into $k$ shards of size $B/k$. Each device runs forward and backward passes independently on its shard, and then the gradients are averaged across devices so that every replica applies the same update.
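The gradient aggregation is typically implemented with Ring-AllReduce. Each device splits its gradient into $k$ equal chunks, and the devices form a ring. The first phase is a reduce-scatter lasting $k-1$ rounds. In each round, every device sends one chunk to the next device in the ring, which adds it to its own copy of that chunk. After these rounds, each device holds the fully summed gradient for one chunk. The second phase is an all-gather lasting another $k-1$ rounds, in which the summed chunks are passed around the ring until every device has all of them. Each round moves $N/k$ elements per device, so the total communication per device is: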
$$C = 2 \cdot \frac{k-1}{k} \cdot N \cdot \text{bytes\_per\_param} \approx 2N \cdot \text{bytes\_per\_param}$$
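The two phases can be made concrete with a single-process Python simulation of Ring-AllReduce over plain lists. This sketch makes simplifying assumptions: the gradient length is divisible by $k$, all sends within a round happen simultaneously, and the result is the sum (dividing by $k$ afterwards gives the average used in data parallelism). It also counts the elements each device sends, which should match $2(k-1)N/k$. The function name `ring_allreduce` is illustrative.

```python
def ring_allreduce(grads):
    """Simulate Ring-AllReduce (reduce-scatter, then all-gather) on k lists.

    Returns every device's final buffer and the number of elements each sent.
    """
    k, n = len(grads), len(grads[0])
    assert n % k == 0, "sketch assumes the gradient length is divisible by k"
    size = n // k
    bufs = [list(g) for g in grads]  # each device's local copy
    sent = [0] * k

    def idx(c):  # element indices belonging to chunk c
        return range(c * size, (c + 1) * size)

    def run_round(chunk_of, accumulate):
        # Snapshot all outgoing messages first: sends in a round are simultaneous.
        msgs = [(i, chunk_of(i), [bufs[i][j] for j in idx(chunk_of(i))]) for i in range(k)]
        for i, c, data in msgs:
            dst = (i + 1) % k  # next device in the ring
            for j, v in zip(idx(c), data):
                bufs[dst][j] = bufs[dst][j] + v if accumulate else v
            sent[i] += size

    # Phase 1 (reduce-scatter): afterwards device i holds the full sum of chunk (i + 1) % k.
    for r in range(k - 1):
        run_round(lambda i: (i - r) % k, accumulate=True)
    # Phase 2 (all-gather): forward the summed chunks until every device has all of them.
    for r in range(k - 1):
        run_round(lambda i: (i + 1 - r) % k, accumulate=False)
    return bufs, sent


if __name__ == "__main__":
    k, n = 4, 8
    grads = [[10 * d + j for j in range(n)] for d in range(k)]
    bufs, sent = ring_allreduce(grads)
    expected = [sum(g[j] for g in grads) for j in range(n)]
    assert all(b == expected for b in bufs)
    print("elements sent per device:", sent, "formula:", 2 * (k - 1) * n // k)
```

Every device sends $2(k-1)$ chunks of $N/k$ elements, regardless of its position in the ring. This is why the per-device volume stays close to $2N$ elements as $k$ grows.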
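Data parallelism does not reduce per-device memory for parameters, gradients, or optimiser states. Only activation memory shrinks, because each device processes $B/k$ examples. It therefore scales well only when the model's full training state fits on one device.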
Tensor Parallelism
Tensor parallelism (Narayanan et al., Megatron-LM, 2021) shards the model’s weight matrices across devices.
Attention. With $h$ heads across $k$ GPUs, each GPU handles $h/k$ heads completely: the corresponding columns of $W_Q$, $W_K$, and $W_V$, the attention computation for those heads, and the corresponding rows of $W_O$. Each GPU thus produces a full-width partial output, and a single AllReduce sums these partial outputs across GPUs.
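The following pure-Python sketch of the head split covers the forward pass only. It assumes a single sequence (no batch dimension), no masking, no biases, and contiguous blocks of heads per GPU. The loop over `g` stands in for the $k$ GPUs, and the final elementwise sum stands in for the AllReduce. Function names are illustrative.

```python
import math
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(v - m) for v in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out


def heads_concat(X, Wq, Wk, Wv, heads, dh):
    """Concatenated outputs of the given heads, before the W_O projection."""
    outs = []
    for h in heads:
        lo, hi = h * dh, (h + 1) * dh  # this head's column block of W_Q, W_K, W_V
        Q, K, V = (matmul(X, [row[lo:hi] for row in W]) for W in (Wq, Wk, Wv))
        scores = [[v / math.sqrt(dh) for v in row] for row in matmul(Q, list(zip(*K)))]
        outs.append(matmul(softmax_rows(scores), V))
    return [sum((o[t] for o in outs), []) for t in range(len(X))]


def attention_tp(X, Wq, Wk, Wv, Wo, n_heads, k):
    """Each of k simulated GPUs owns n_heads // k heads and the matching rows of W_O."""
    dh, per = len(Wq[0]) // n_heads, n_heads // k
    partials = []
    for g in range(k):
        H = heads_concat(X, Wq, Wk, Wv, range(g * per, (g + 1) * per), dh)
        partials.append(matmul(H, Wo[g * per * dh:(g + 1) * per * dh]))
    # Stand-in for the AllReduce: elementwise sum of the k full-width partial outputs.
    return [[sum(p[t][j] for p in partials) for j in range(len(Wo[0]))] for t in range(len(X))]


if __name__ == "__main__":
    rng = random.Random(0)

    def rand(r, c):
        return [[rng.uniform(-1, 1) for _ in range(c)] for _ in range(r)]

    s, d, n_heads, k = 3, 8, 4, 2
    X, Wq, Wk, Wv, Wo = rand(s, d), rand(d, d), rand(d, d), rand(d, d), rand(d, d)
    full = matmul(heads_concat(X, Wq, Wk, Wv, range(n_heads), d // n_heads), Wo)
    sharded = attention_tp(X, Wq, Wk, Wv, Wo, n_heads, k)
    print(max(abs(a - b) for ra, rb in zip(full, sharded) for a, b in zip(ra, rb)))
```

The printed difference should be at the level of floating-point rounding. The sharded and unsharded computations differ only in the order in which the final sums are taken.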
FFN. A two-layer FFN with weights $W_1 \in \mathbb{R}^{d \times 4d}$ and $W_2 \in \mathbb{R}^{4d \times d}$ is split column-wise for $W_1$ (each GPU holds $4d/k$ columns) and row-wise for $W_2$ (each GPU holds the matching $4d/k$ rows). Because GeLU is applied elementwise, each GPU can compute its partial output $\text{GeLU}(x W_1^{(i)})\, W_2^{(i)}$ independently, where $x \in \mathbb{R}^{1 \times d}$ is a row vector. An AllReduce then sums the $k$ partial outputs. Each transformer sublayer (attention or FFN) therefore needs exactly one AllReduce in the forward pass.
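The column/row split can also be sketched in pure Python. This version assumes a single sequence of row-vector tokens, no biases, and the exact erf form of GeLU. The loop over `g` plays the role of the $k$ GPUs, and the final sum plays the role of the AllReduce. Function names are illustrative.

```python
import math
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def gelu(x):
    # Exact (erf-based) GeLU, applied elementwise.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def ffn(X, W1, W2):
    """GeLU(X W1) W2 for row-vector inputs X (one row per token)."""
    return matmul([[gelu(v) for v in row] for row in matmul(X, W1)], W2)


def ffn_tensor_parallel(X, W1, W2, k):
    """Split W1 by columns and W2 by rows across k simulated GPUs, then sum."""
    block = len(W2) // k  # 4d / k inner units per GPU
    partials = []
    for g in range(k):
        lo, hi = g * block, (g + 1) * block
        W1_g = [row[lo:hi] for row in W1]  # column block of W1
        W2_g = W2[lo:hi]                   # matching row block of W2
        # GeLU acts on each inner unit separately, so no communication is needed here.
        partials.append(ffn(X, W1_g, W2_g))
    # Stand-in for the single AllReduce: elementwise sum of the partial outputs.
    return [[sum(p[t][j] for p in partials) for j in range(len(W2[0]))] for t in range(len(X))]


if __name__ == "__main__":
    rng = random.Random(1)

    def rand(r, c):
        return [[rng.uniform(-1, 1) for _ in range(c)] for _ in range(r)]

    d, k = 4, 2
    X, W1, W2 = rand(3, d), rand(d, 4 * d), rand(4 * d, d)
    full, tp = ffn(X, W1, W2), ffn_tensor_parallel(X, W1, W2, k)
    print(max(abs(a - b) for ra, rb in zip(full, tp) for a, b in zip(ra, rb)))
```

The split works because the column split of $W_1$ never mixes inner units before the nonlinearity. If $W_1$ were split by rows instead, each GPU would hold only part of every pre-activation sum, and an AllReduce would be needed before GeLU.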
The communication volume per layer is $O(Bsd)$ for a batch of size $B$, sequence length $s$, and hidden dimension $d$, while the computation per layer is $O(Bsd^2)$. The ratio of computation to communication therefore grows linearly with $d$, which makes tensor parallelism most efficient in the large-hidden-dimension regime.
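A back-of-the-envelope Python sketch makes the scaling explicit. It assumes the commonly used estimate of about $24Bsd^2$ forward FLOPs per layer for the dense matrix multiplications (attention-score FLOPs ignored), split evenly over $k$ GPUs. It also assumes two ring AllReduces per layer (one per sublayer), each over a $B \times s \times d$ tensor in fp16. The function name is illustrative.

```python
def tp_flops_per_byte(B, s, d, k, bytes_per_elem=2):
    """Forward-pass FLOPs per byte sent, per GPU, for one tensor-parallel layer.

    Assumptions (illustrative): ~24*B*s*d**2 dense-matmul FLOPs per layer, split
    over k GPUs; two ring AllReduces per layer over a B x s x d activation tensor.
    """
    flops_per_gpu = 24 * B * s * d**2 / k
    elems_sent_per_gpu = 2 * (2 * (k - 1) / k) * (B * s * d)  # 2 sublayers x ring volume
    return flops_per_gpu / (elems_sent_per_gpu * bytes_per_elem)


if __name__ == "__main__":
    for d in (1024, 4096, 12288):
        print(d, round(tp_flops_per_byte(B=1, s=2048, d=d, k=8)))
```

Under these assumptions the ratio simplifies to $3d/(k-1)$ FLOPs per byte. $B$ and $s$ cancel out, and doubling $d$ doubles the arithmetic available to hide each byte of communication.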
Pipeline Parallelism
Pipeline parallelism assigns different layers to different GPUs. GPU 0 holds layers 1 through $L/k$, GPU 1 holds layers $L/k + 1$ through $2L/k$, and so on. A forward pass flows through GPUs in sequence, with each GPU passing its output activations to the next.
GPipe
GPipe (Huang et al., 2019) splits the batch into $m$ micro-batches and processes them in a strict schedule: all $m$ micro-batches complete forward passes before any backward pass begins. The bubble - time wasted while some GPUs wait for work - is:
$$\text{bubble fraction} = \frac{k-1}{k + m - 1}$$
As $m$ grows large relative to $k$, the bubble fraction approaches $(k-1)/m$ and tends to zero. The cost is memory. Every stage must keep the activations of all $m$ micro-batches until their backward passes begin, so each GPU stores activations for $m$ micro-batches, $O(mk)$ in total across the pipeline.
1F1B Schedule
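The 1-Forward-1-Backward (1F1B) schedule interleaves forward and backward passes. In a warm-up phase, stage $i$ (counting from 0) runs $k-i-1$ forward passes. It then alternates one forward pass with one backward pass, and a final cool-down phase drains the remaining backward passes. Each micro-batch's backward pass starts as soon as its forward pass has cleared the last stage. As a result, no GPU ever holds activations for more than $k$ micro-batches at once, compared with $m$ under GPipe.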
In the non-interleaved 1F1B variant, the bubble is the same as GPipe's, $O((k-1)/m)$. The benefit is bounded activation memory, not less idle time. In the interleaved variant (Narayanan et al., 2021), each GPU holds $v$ non-contiguous pipeline stages (chunks of layers). Communication volume increases by a factor of $v$, but the bubble shrinks to $O((k-1)/(vm))$. This $v\times$ reduction matters most when $m$ cannot be made much larger than $k$.
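A small discrete-time Python simulation can check both claims. It makes simplifying assumptions: every forward or backward step on every stage takes one time unit, communication is free, and the 1F1B order uses $k-s-1$ warm-up forwards on stage $s$. It reports the bubble fraction (the idle share of total GPU time) and the peak number of micro-batches whose activations stage 0 must hold. All function names are illustrative.

```python
def gpipe_order(stage, k, m):
    # All forwards, then all backwards.
    return [("F", j) for j in range(m)] + [("B", j) for j in range(m)]


def one_f_one_b_order(stage, k, m):
    # Warm-up forwards, then alternate one F with one B, then cool-down backwards.
    warmup = min(k - stage - 1, m)
    order = [("F", j) for j in range(warmup)]
    for i in range(m - warmup):
        order += [("F", warmup + i), ("B", i)]
    return order + [("B", j) for j in range(m - warmup, m)]


def simulate(order_fn, k, m):
    """Return (bubble fraction, peak micro-batches stashed on stage 0)."""
    orders = [order_fn(s, k, m) for s in range(k)]
    finish, free_at, pos = {}, [0] * k, [0] * k
    while any(pos[s] < len(orders[s]) for s in range(k)):
        progressed = False
        for s in range(k):
            if pos[s] == len(orders[s]):
                continue
            kind, j = orders[s][pos[s]]
            if kind == "F":
                dep = ("F", s - 1, j) if s > 0 else None  # activation from previous stage
            else:
                dep = ("B", s + 1, j) if s < k - 1 else ("F", s, j)  # gradient from next stage
            if dep is not None and dep not in finish:
                continue  # dependency not finished yet; retry on the next sweep
            start = max(free_at[s], 0 if dep is None else finish[dep])
            finish[(kind, s, j)] = free_at[s] = start + 1
            pos[s] += 1
            progressed = True
        if not progressed:
            raise RuntimeError("schedule deadlocked")
    makespan = max(free_at)
    bubble = 1 - (2 * m) / makespan  # each stage is busy for 2m time units
    live = peak = 0
    for kind, _ in orders[0]:  # micro-batches whose activations stage 0 holds
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return bubble, peak


if __name__ == "__main__":
    for name, fn in (("GPipe", gpipe_order), ("1F1B", one_f_one_b_order)):
        print(name, simulate(fn, k=3, m=4))
```

For $k = 3$ stages and $m = 4$ micro-batches, both schedules give a bubble fraction of $1/3$, matching $(k-1)/(k+m-1)$. However, stage 0 holds activations for 4 micro-batches under GPipe and only 3 under 1F1B.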
ZeRO: Zero Redundancy Optimiser
ZeRO (Rajbhandari et al., DeepSpeed, 2020) eliminates the redundancy in data parallelism where every GPU stores identical copies of optimiser states, gradients, and parameters.
Stage 1 - Partition Optimiser States. Each GPU stores only $1/k$ of the fp32 optimiser states: the master parameters and the Adam $m_t$ and $v_t$ tensors, $12N$ bytes in total. After a backward pass, each GPU AllReduces gradients (as in standard data parallelism) and applies the optimiser update to its own shard. It then AllGathers the updated parameter shards so that every replica again has the full parameters. Memory for optimiser states drops by $k\times$. Per-device memory: $4N + 12N/k$ bytes (optimiser states sharded; parameters and gradients replicated).
Stage 2 - Partition Gradients. Each GPU additionally stores only $1/k$ of the gradients. Gradients are reduce-scattered rather than AllReduced: each GPU receives only the summed gradient shard for which it owns the optimiser state. Memory for gradients drops by $k\times$. Per-device memory: $2N + 14N/k$ bytes (fp16 parameters replicated; gradients and optimiser states sharded).
Stage 3 - Partition Parameters. Each GPU stores only $1/k$ of the parameters themselves. The parameters of each layer are fetched via AllGather just before that layer's forward or backward computation and discarded afterwards. Per-device memory: $16N/k$ bytes, a factor of $k$ reduction vs standard data parallelism.
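The communication overhead of Stage 3 is two AllGathers per layer per training step (one in the forward pass and one in the backward pass, since parameters are discarded after use) plus one ReduceScatter. This is about $1.5\times$ the communication volume of standard data parallelism or Stage 1, whose single AllReduce costs as much as one ReduceScatter plus one AllGather. In practice this overhead is acceptable when bandwidth is sufficient and the memory reduction enables larger batch sizes.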
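The per-device accounting for each stage can be sketched in Python. The sketch uses the same 2/2/12 bytes-per-parameter split as the memory table and ignores activations, communication buffers, and fragmentation. The function name and the `stage=0` convention for plain data parallelism are illustrative choices.

```python
def zero_bytes_per_device(n_params, k, stage):
    """Per-device model-state bytes under ZeRO with k data-parallel GPUs.

    stage 0 means plain data parallelism (everything replicated).
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if stage >= 1:
        optim /= k   # shard fp32 master weights and Adam m_t, v_t
    if stage >= 2:
        grads /= k   # shard gradients
    if stage >= 3:
        params /= k  # shard fp16 parameters
    return params + grads + optim


if __name__ == "__main__":
    n, k = 13 * 10**9, 8
    for stage in range(4):
        print(f"stage {stage}: {zero_bytes_per_device(n, k, stage) / 1e9:.2f} GB")
```

For a 13B-parameter model on 8 GPUs, this gives 208, 71.5, 48.75, and 26 GB for plain data parallelism and Stages 1 to 3 respectively.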
Activation Checkpointing
All parallelism strategies must contend with the cost of storing activations for the backward pass. Each transformer layer stores $O(Bsd)$ activations, so $L$ layers store $O(BsdL)$ in total.
Activation checkpointing (gradient checkpointing) reduces this by recomputing activations during the backward pass rather than storing them. Only a subset of activations - the checkpoints - are saved; the remaining activations are recomputed from the nearest checkpoint when needed.
Placing a checkpoint every $\sqrt{n}$ layers in a chain of $n$ layers reduces activation memory from $O(n)$ to $O(\sqrt{n})$. At most $\sqrt{n}$ checkpoints are stored, plus the activations of one recomputed segment of $\sqrt{n}$ layers at a time. The cost is roughly one additional forward pass in total, since each segment is recomputed once during the backward pass. In practice, checkpointing at every transformer layer (the standard in large-scale training) stores only each layer's input. This reduces activation memory by a factor of the number of sub-operations per layer (typically 5–10$\times$). The recomputation adds one extra forward pass, and since a backward pass costs about twice a forward pass, this amounts to roughly 33% more compute.
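A scalar toy version of $\sqrt{n}$ checkpointing in Python illustrates the trade-off. The simplifications are significant: a hypothetical chain of $n$ layers in which layer $i$ computes $\tanh(w_i x)$, and activation memory measured as a count of stored values rather than bytes. The checkpointed version stores only every $c$-th activation, recomputes one segment at a time during the backward pass, and reports the peak number of live values and the total number of layer evaluations. All names are illustrative.

```python
import math


def make_weights(n):
    # Hypothetical toy network: layer i computes tanh(w_i * x) on a scalar.
    return [0.5 + 0.01 * i for i in range(n)]


def grad_no_checkpointing(ws, x):
    """d(output)/d(input), storing every activation (n + 1 values)."""
    acts = [x]
    for w in ws:
        acts.append(math.tanh(w * acts[-1]))
    g = 1.0
    for i in reversed(range(len(ws))):
        y = acts[i + 1]
        g *= ws[i] * (1.0 - y * y)  # d/dx tanh(w x) = w (1 - tanh(w x)^2)
    return g, len(acts)


def grad_with_checkpointing(ws, x, c):
    """Same gradient, keeping only every c-th activation and recomputing segments."""
    n = len(ws)
    evals, peak = 0, 0
    ckpts, h = {}, x
    for i in range(n):  # forward pass: keep only the input of each c-layer segment
        if i % c == 0:
            ckpts[i] = h
        h = math.tanh(ws[i] * h)
        evals += 1
    g = 1.0
    for start in sorted(ckpts, reverse=True):  # backward pass, last segment first
        seg = [ckpts.pop(start)]
        for i in range(start, min(start + c, n)):  # recompute this segment
            seg.append(math.tanh(ws[i] * seg[-1]))
            evals += 1
        peak = max(peak, len(ckpts) + len(seg))  # values alive at this moment
        for i in reversed(range(start, start + len(seg) - 1)):
            y = seg[i - start + 1]
            g *= ws[i] * (1.0 - y * y)
    return g, peak, evals


if __name__ == "__main__":
    ws = make_weights(16)
    print(grad_no_checkpointing(ws, 0.3))
    print(grad_with_checkpointing(ws, 0.3, c=4))
```

With $n = 16$ and $c = \sqrt{n} = 4$, the checkpointed version keeps at most 8 values alive (about $n/c + c$) instead of 17 and returns the same gradient. It performs 32 layer evaluations instead of 16, which is exactly one extra forward pass.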
3D Parallelism
At the scale of GPT-3 or larger, the most effective strategy combines all three forms of parallelism:
- Tensor parallelism within a node (NVLink provides the high bandwidth needed for frequent AllReduces).
- Pipeline parallelism across nodes (inter-node bandwidth is lower, and pipeline stages exchange only boundary activations and gradients point-to-point with their neighbours, which tolerates low bandwidth far better than frequent tensor-parallel AllReduces).
- Data parallelism across pipeline replicas, scaling to total GPU count.
For a cluster with $p$ nodes of $n$ GPUs each, a common configuration is $n$-way tensor parallelism, $k$-way pipeline parallelism, and $(p/k)$-way data parallelism, which together use all $n \cdot k \cdot (p/k) = np$ GPUs. ZeRO Stage 1 or 2 is applied within the data-parallel group.
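One way to realise this layout is to give every GPU a global rank and derive its (data, pipeline, tensor) coordinates from it. The Python sketch below assumes ranks are numbered node by node and that the tensor-parallel index varies fastest, so each tensor-parallel group lands on a single node. This ordering is an illustrative assumption, not a description of any particular framework's rank mapping, and the function names are hypothetical.

```python
def parallel_coords(p, n, k):
    """Map each global rank to (dp, pp, tp) for p nodes of n GPUs each."""
    if p % k:
        raise ValueError("pipeline degree k must divide the node count p")
    coords = {}
    for rank in range(p * n):
        tp = rank % n          # tensor-parallel index: fastest-varying, within a node
        pp = (rank // n) % k   # pipeline stage: one node per stage in each replica
        dp = rank // (n * k)   # data-parallel replica index, 0 .. p/k - 1
        coords[rank] = (dp, pp, tp)
    return coords


def groups(coords, axis):
    """Ranks that differ only along `axis` (0=dp, 1=pp, 2=tp) form one group."""
    out = {}
    for rank, c in coords.items():
        key = c[:axis] + c[axis + 1:]
        out.setdefault(key, []).append(rank)
    return list(out.values())


if __name__ == "__main__":
    coords = parallel_coords(p=4, n=8, k=2)
    print("TP groups:", groups(coords, 2))
    print("DP groups:", groups(coords, 0))
```

With 4 nodes of 8 GPUs and $k = 2$, this yields 4 tensor-parallel groups of 8 (each confined to one node), 16 pipeline groups of 2, and 16 data-parallel groups of $p/k = 2$.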
Examples
Megatron-LM parallelism strategy for GPT-3 scale. Training the 530B-parameter MT-NLG model (Smith et al., 2022) with DeepSpeed and Megatron-LM used 8-way tensor parallelism within a node, 35-way pipeline parallelism across nodes, and 8-way data parallelism, for a total of $8 \times 35 \times 8 = 2240$ A100 GPUs. Model FLOPs utilisation (MFU) reached approximately 30%, a significant achievement given the communication overhead.
Memory vs speed trade-offs. Enabling ZeRO Stage 3 on a 13B parameter model across 8 A100-80GB GPUs reduces per-GPU memory for model states from 208 GB ($16N$, infeasible) to 26 GB ($16N/8$, feasible with room for activations). The cost is an approximately 15% throughput reduction due to AllGather communication. Activation checkpointing on top of this reduces activation memory by a further 5–8$\times$ at a 33% compute overhead, enabling batch sizes large enough to maintain high GPU utilisation.