BatchNorm & LayerNorm - Keeping Activations From Exploding or Vanishing
Helpful context:
- Bias, Variance & Overfitting - The Three-Way Tradeoff You Can’t Escape
- Gradients & Partial Derivatives - Slopes in Every Direction at Once
In 2014, training a 20-layer neural network was a research achievement. You needed hand-tuned learning rates and carefully chosen weight initialization schemes, and even then, training could take weeks and fail unpredictably. Batch normalization, introduced by Ioffe and Szegedy in 2015, changed that quickly. In their paper, a batch-normalized version of a state-of-the-art ImageNet classifier reached the same accuracy with 14 times fewer training steps, and learning rates that would previously have diverged became usable. Batch normalization was not a small improvement - it helped make a qualitatively different class of much deeper models trainable.
LayerNorm came shortly after, solving the specific problems batch normalization introduced for sequential models. Between them, these two techniques underlie essentially every modern deep learning system you will encounter.
The Problem: Internal Covariate Shift
To understand why normalization helps, you first need to understand what makes deep networks hard to train.
Consider a 10-layer network. Each layer takes input from the previous layer, applies weights, and passes the result forward. Now start training: at every step, the optimizer updates the weights in all ten layers at once. But each update to layer $k$ changes the distribution of activations that layer $k+1$ (and every layer after it) receives as input. Layer 5's job today is not the same as layer 5's job tomorrow, because layers 1-4 have changed.
This phenomenon is called internal covariate shift. Each layer must continuously re-adapt to a moving target distribution. The deeper the network, the worse this gets: small changes in early layers produce compounding distributional shifts in later layers. Training becomes slow and unstable.
Before batch normalization, practitioners combated this with two tools:
Weight initialization: Xavier (Glorot) initialization sets initial weights to have variance $2/(n_{\text{in}} + n_{\text{out}})$, often simplified to $1/n_{\text{in}}$, keeping activations and gradients well-scaled at the start of training. He initialization uses $2/n_{\text{in}}$ for ReLU networks, compensating for ReLU zeroing out roughly half of its inputs. Both are designed to prevent activation distributions from collapsing to zero or exploding at initialization, but they only help at the start.
Small learning rates: With well-initialized weights, a small learning rate ensures each update is small enough that later layers can adapt. But small learning rates mean slow training - you need many more gradient steps to converge.
Both tools address the symptom (unstable activations) rather than the cause (the absence of any mechanism to stabilize distributions during training). Batch normalization addresses the cause directly.
Batch Normalization
The idea: after each layer’s linear transformation, normalize the activations so that they have zero mean and unit variance. Do this within each mini-batch.
Formally: consider a mini-batch of $B$ samples, and a specific feature (neuron) $j$ in some layer. The batch statistics for this feature are:
$$\mu_j = \frac{1}{B}\sum_{i=1}^{B} x_{ij}, \qquad \sigma_j^2 = \frac{1}{B}\sum_{i=1}^{B}(x_{ij} - \mu_j)^2.$$
Normalize each activation:
$$\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \varepsilon}},$$
where $\varepsilon \approx 10^{-5}$ prevents division by zero. After this step, each feature in the batch has mean 0 and variance 1 (up to the negligible effect of $\varepsilon$).
Then apply a learnable scale and shift:
$$y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j.$$
The parameters $\gamma_j$ (scale) and $\beta_j$ (shift) are learned jointly with the rest of the network by backpropagation. This step is critical. Without it, batch normalization would force every layer's output to have mean 0 and variance 1, which could be actively harmful: the network could not learn to use a different scale or mean even when that would improve performance. The learnable parameters give the network the option to undo the normalization entirely (setting $\gamma_j = \sqrt{\sigma_j^2 + \varepsilon}$ and $\beta_j = \mu_j$ recovers the original activations), or to use any other scale and shift it finds useful.
The full batch normalization layer applies this procedure to every feature independently.
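The training-mode computation can be sketched in plain Python, assuming a mini-batch stored as a list of $B$ rows with $d$ features each. The helper name `batch_norm_train` is hypothetical, and real frameworks vectorize this over tensors; the sketch only shows that statistics are taken down each column (across the batch), that $\gamma_j$ and $\beta_j$ are applied per feature, and that the right choice of $\gamma_j, \beta_j$ undoes the normalization.

```python
import math


def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Training-mode batch norm (hypothetical helper, not a library API).

    batch: list of B samples, each a list of d features, so batch[i][j] is x_ij.
    gamma, beta: learnable per-feature scale and shift, each of length d.
    Returns (outputs, means, variances) so the batch statistics can be inspected.
    """
    B, d = len(batch), len(batch[0])
    means, variances = [], []
    for j in range(d):
        column = [batch[i][j] for i in range(B)]        # feature j across the batch
        mu = sum(column) / B
        var = sum((v - mu) ** 2 for v in column) / B   # biased 1/B variance, as in the formula
        means.append(mu)
        variances.append(var)
    outputs = [
        [gamma[j] * (batch[i][j] - means[j]) / math.sqrt(variances[j] + eps) + beta[j]
         for j in range(d)]
        for i in range(B)
    ]
    return outputs, means, variances


batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
y, mu, var = batch_norm_train(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
# mu == [2.5, 25.0]; each column of y now has mean 0 and variance just under 1.

# Choosing gamma = sqrt(var + eps) and beta = mu undoes the normalization.
y_undo, _, _ = batch_norm_train(batch, [math.sqrt(v + 1e-5) for v in var], mu)
```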
At Inference
During training, the mean and variance are computed from the current mini-batch. At inference, you typically have a single example (or a batch with a different size), so you can’t use batch statistics. Instead, batch normalization maintains running statistics: an exponential moving average of the batch means and variances accumulated during training. At inference, these running statistics are used in place of batch statistics.
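This is the standard approach, but it introduces a subtle distinction between training and inference behavior that can cause bugs if you forget to switch the model to evaluation mode (in PyTorch, `model.eval()`): the model then silently keeps normalizing with the statistics of whatever batch it is given.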
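A sketch of how the running statistics might be maintained, again in plain Python. The class name `TinyBatchNorm`, its `training` flag, and its `eval()` method are illustrative and only loosely mirror the train/eval convention of frameworks such as PyTorch; this is not PyTorch code. The momentum of 0.1 (the weight given to the newest batch in the moving average) is an assumed value.

```python
import math


class TinyBatchNorm:
    """Minimal batch norm with running statistics (an illustrative sketch, not a library class)."""

    def __init__(self, d, momentum=0.1, eps=1e-5):
        self.gamma = [1.0] * d          # learnable scale (kept fixed in this sketch)
        self.beta = [0.0] * d           # learnable shift (kept fixed in this sketch)
        self.running_mean = [0.0] * d   # exponential moving averages used at inference
        self.running_var = [1.0] * d
        self.momentum = momentum        # weight of the newest batch in the moving average
        self.eps = eps
        self.training = True

    def eval(self):
        """Switch to inference behavior: use running statistics instead of batch statistics."""
        self.training = False

    def __call__(self, batch):
        d = len(self.gamma)
        if self.training:
            B = len(batch)
            mean = [sum(x[j] for x in batch) / B for j in range(d)]
            var = [sum((x[j] - mean[j]) ** 2 for x in batch) / B for j in range(d)]
            m = self.momentum
            self.running_mean = [(1 - m) * r + m * b for r, b in zip(self.running_mean, mean)]
            self.running_var = [(1 - m) * r + m * b for r, b in zip(self.running_var, var)]
        else:
            mean, var = self.running_mean, self.running_var
        return [
            [self.gamma[j] * (x[j] - mean[j]) / math.sqrt(var[j] + self.eps) + self.beta[j]
             for j in range(d)]
            for x in batch
        ]


bn = TinyBatchNorm(2)
batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
for _ in range(200):
    bn(batch)        # training mode: running stats move toward means (2.5, 25.0), variances (1.25, 125.0)
bn.eval()
single = bn([[2.5, 25.0]])   # a single example now works, and its output is close to zero
```

In eval mode the output for an example no longer depends on which other examples share its batch, which is exactly the behavior that changes if the mode switch is forgotten.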
Why BatchNorm Helps
The intuition is that by normalizing activations, you are standardizing the input distribution to each layer. Layer 5 always receives activations with approximately the same mean and variance, regardless of what layers 1-4 are doing. This is the reduction in internal covariate shift that Ioffe and Szegedy originally described.
But the benefits go further:
Reduced sensitivity to initialization. With batch normalization, the exact choice of weight initialization matters much less. The normalization step counteracts poorly scaled initial activations, so training is stable even with less-careful initialization.
Higher learning rates. With stable activation distributions, you can use much larger learning rates without causing gradient explosion. Higher learning rates mean faster convergence.
Implicit regularization. The batch mean and variance are computed from a subset of the data and are therefore noisy estimates of the true population statistics. This noise acts as regularization - similar in spirit to dropout. Models trained with batch normalization often require less dropout.
Smoother loss landscape. This is the theoretical explanation that has held up best. Batch normalization makes the loss, and its gradients, more Lipschitz-smooth: both change slowly as you move through parameter space. A smooth landscape can be traversed with larger steps, which is why larger learning rates work. The argument is due to Santurkar et al. (2018).
Discomfort check. Why does normalization help theoretically? The original explanation was internal covariate shift, and that is still a useful intuition. But Santurkar et al. showed that reducing internal covariate shift is not what drives the improvement: you can deliberately reintroduce covariate shift after batch normalization (by injecting noise into its outputs) and still get the training benefits. The actual mechanism involves the loss landscape becoming smoother in a precise mathematical sense (better Lipschitz constants for the loss and its gradient), which allows larger learning rates and more stable optimization. The benefit was discovered empirically, and the theory followed the observation. Many of the most important techniques in deep learning were found this way.
Problems with Batch Normalization
Batch normalization is powerful but has real limitations.
Batch size dependence. The statistics $\mu_j$ and $\sigma_j^2$ are estimated from a mini-batch. With batch size 1, this collapses: the mean is just the single value and the variance is zero, so $\hat{x}_{ij} = 0$ and the output is $\beta_j$ whatever the input was. With batch size 2, the estimates are unreliable. As a rule of thumb, batch normalization needs moderately large batches (typically $\geq 16$, preferably $\geq 32$) to work well.
Train-inference discrepancy. Batch statistics during training and running statistics during inference are different quantities. If the running statistics are poorly estimated (e.g., due to non-stationary training data), inference behavior diverges from training behavior. This is a source of subtle, hard-to-debug failures.
Sequences. For sequential data, batch normalization’s behavior is problematic. A batch of sequences may have different lengths; padding and masking interact badly with batch statistics. More fundamentally, normalizing across the batch dimension means the normalization for token $t$ in sequence 1 is influenced by what happens at position $t$ in every other sequence in the batch. This cross-contamination is semantically wrong for language models.
Online learning and very small datasets. If you are processing one example at a time (online learning), or if your dataset is so small that you can’t form meaningful batches, batch normalization doesn’t work.
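The batch-size problem is easy to see numerically. This sketch normalizes a single feature across the batch (the helper `batch_norm_feature` is hypothetical): with one example the output is always $\beta$, and with two examples the outputs are pinned near $-1$ and $+1$ however far apart the inputs are.

```python
import math


def batch_norm_feature(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm for one feature j across a batch of len(values) examples."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [gamma * (v - mu) / math.sqrt(var + eps) + beta for v in values]


# Batch size 1: x - mu is exactly 0, so the input is erased and only beta survives.
print(batch_norm_feature([3.7], beta=0.5))     # [0.5]
print(batch_norm_feature([-120.0], beta=0.5))  # [0.5]

# Batch size 2: the outputs are roughly -1 and +1, whatever the gap between the inputs.
print(batch_norm_feature([1.0, 100.0]))
print(batch_norm_feature([1.0, 1.5]))
```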
These limitations directly motivated layer normalization.
Layer Normalization
Layer normalization (Ba et al. 2016) makes a different normalization choice: instead of normalizing across the batch for each feature, normalize across features within each sample.
For a single sample $x \in \mathbb{R}^d$ (e.g., a single layer’s activation vector), compute:
$$\mu = \frac{1}{d}\sum_{j=1}^{d} x_j, \qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d}(x_j - \mu)^2.$$
Normalize:
$$\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \varepsilon}}.$$
Scale and shift (with learnable $\gamma, \beta \in \mathbb{R}^d$):
$$y_j = \gamma_j \hat{x}_j + \beta_j.$$
The key difference: normalization is computed within a single sample, using all features. This means:
- Batch size 1 works perfectly. Statistics are computed within the sample, not across the batch.
- Same behavior at train and inference. There are no running statistics, no train-inference discrepancy.
- Works for sequences. Each token in a sequence is normalized independently using its own feature statistics.
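A minimal per-sample sketch (plain Python; the `layer_norm` helper is hypothetical) makes the contrast with batch normalization concrete: statistics come from the sample's own $d$ features, so a token's output is identical whether it is processed alone or alongside very different tokens.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer norm for one sample or one token: statistics over its own d features."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


d = 4
gamma, beta = [1.0] * d, [0.0] * d
token = [1.0, 2.0, 3.0, 4.0]

alone = layer_norm(token, gamma, beta)                 # "batch size 1"
batch = [token, [100.0, -50.0, 0.0, 7.0]]              # the same token next to a very different one
in_batch = [layer_norm(t, gamma, beta) for t in batch]
# alone == in_batch[0]: no batch statistics, no running averages, same code at train and inference.
```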
Layer normalization is the standard normalization in transformers. Transformer language models such as GPT and BERT use layer normalization, and Llama and many of its successors use the RMSNorm variant described below; none of them use batch normalization.
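The cost: for convolutional networks on images, batch normalization is typically better. Layer normalization pools all channels and spatial positions of a sample into a single mean and variance, even though different channels in a CNN often have very different statistics; batch normalization gives each channel its own statistics.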
RMSNorm
RMSNorm is a simplified variant of layer normalization that omits mean centering entirely. Instead of normalizing by the standard deviation, normalize by the root-mean-square:
$$\hat{x}_j = \frac{x_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \varepsilon}}.$$
Then apply a learnable scale $\gamma_j$ (no bias term).
Why drop the mean centering? Empirically, it often doesn’t hurt performance, and it reduces compute and memory. It’s also slightly simpler to implement correctly. RMSNorm is used in Llama, PaLM, and many other recent models. As the field moves toward extreme computational efficiency, simpler normalization schemes become attractive.
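RMSNorm differs from the layer norm computation only in skipping the mean. The sketch below (hypothetical helper names, plain Python, layer norm shown with $\gamma = 1$, $\beta = 0$) shows that on a zero-mean input the two produce the same output, while on a shifted input RMSNorm keeps the offset that layer norm removes.

```python
import math


def layer_norm(x, eps=1e-5):
    """Layer norm with gamma = 1 and beta = 0, for comparison."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm: divide by the root-mean-square of x; no mean subtraction and no bias term."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gamma)]


ones = [1.0] * 4
centered = [-3.0, -1.0, 1.0, 3.0]
shifted = [v + 10.0 for v in centered]

same = (rms_norm(centered, ones), layer_norm(centered))  # identical when the mean is already 0
differ = (rms_norm(shifted, ones), layer_norm(shifted))  # RMSNorm keeps the positive offset
```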
Training Stability Techniques: Logit Softcapping, Z-Loss, QK-Norm
At frontier scale, attention logits and output logits can grow unboundedly during training, causing loss spikes and instability. Three techniques address this at different points in the forward pass.
Z-loss. Add a regularization term to the standard cross-entropy loss that penalizes large logit magnitudes:
$$\mathcal{L}_{\text{z}} = \alpha \cdot \mathbb{E}\left[\log^2 Z\right]$$
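where $Z = \sum_i e^{z_i}$ is the softmax partition function and $\alpha$ is a small coefficient (typically $10^{-4}$ to $10^{-5}$). The penalty is zero only when $\log Z = 0$, so it pulls the log-partition function toward zero; when logits drift to large values, $\log Z$ grows and $\mathcal{L}_z$ grows quadratically with it. This matters because cross-entropy alone is unchanged when every logit is shifted by the same constant, so nothing else stops that drift. The added gradient keeps logits from drifting to large values, stabilizing the softmax distribution. It has negligible impact on task performance when $\alpha$ is small.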
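A scalar sketch of cross-entropy plus z-loss, computing $\log Z$ with the usual max-subtraction trick for numerical stability. The function names and the default $\alpha = 10^{-4}$ are assumptions for illustration; real implementations operate on tensors. The usage lines check the property that motivates z-loss: shifting every logit by a constant leaves cross-entropy unchanged but increases the z-loss term.

```python
import math


def log_partition(logits):
    """log Z = log(sum_i exp(z_i)), computed stably by factoring out the largest logit."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def loss_with_z_loss(batch_logits, targets, alpha=1e-4):
    """Mean cross-entropy plus alpha * mean(log^2 Z) over a batch of examples."""
    ce_total, z_total = 0.0, 0.0
    for logits, t in zip(batch_logits, targets):
        log_z = log_partition(logits)
        ce_total += log_z - logits[t]   # cross-entropy: -log softmax(z)[t]
        z_total += log_z ** 2           # z-loss term for this example
    n = len(targets)
    return ce_total / n + alpha * z_total / n


base = [[2.0, 0.5, -1.0]]
drifted = [[z + 50.0 for z in base[0]]]   # same softmax, much larger logits

ce_base = loss_with_z_loss(base, [0], alpha=0.0)
ce_drifted = loss_with_z_loss(drifted, [0], alpha=0.0)   # equal to ce_base
total_base = loss_with_z_loss(base, [0])
total_drifted = loss_with_z_loss(drifted, [0])           # larger, because log Z grew by 50
```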
Logit softcapping. Rather than penalizing large logits via the loss, softcapping bounds them directly in the forward pass using a smooth, differentiable compression:
$$z_{\text{capped}} = s \cdot \tanh\left(\frac{z}{s}\right)$$
where $s$ is the cap threshold (e.g., $s = 50$ for attention logits, $s = 30$ for the final language model head). Unlike hard clipping (which has zero gradient for any value beyond the threshold), $\tanh$ has nonzero gradient everywhere, so training signal still flows through the cap. Values with $|z| \ll s$ pass through nearly unchanged, since $\tanh u \approx u$ for small $u$; larger values are smoothly compressed toward $\pm s$, and the output never exceeds $s$ in magnitude.
Gemma 2 applies softcapping to both attention logits (pre-softmax) and the final LM head to reduce loss spikes; Gemma 3 replaced the attention-logit softcapping with QK-norm. One caveat: when Gemma 2 was released, FlashAttention's fused kernels did not support softcapping, so training code fell back to eager attention, while inference could still use FlashAttention with little quality difference. Later FlashAttention releases added a softcap option.
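The cap and its gradient are one line each. This sketch (hypothetical helper names, plain Python) compares the derivative of $s \tanh(z/s)$, which is $1 - \tanh^2(z/s)$, with the derivative of hard clipping, using $s = 30$ as in the final-head example.

```python
import math


def softcap(z, s):
    """Smoothly bound z to the open interval (-s, s)."""
    return s * math.tanh(z / s)


def softcap_grad(z, s):
    """Derivative of s * tanh(z / s): 1 - tanh(z / s)^2, mathematically positive for every finite z."""
    return 1.0 - math.tanh(z / s) ** 2


def hard_clip_grad(z, s):
    """Derivative of clip(z, -s, s): 1 inside the range, 0 beyond it."""
    return 1.0 if -s < z < s else 0.0


s = 30.0
for z in (1.0, 20.0, 60.0, 200.0):
    print(f"z={z:6.1f}  capped={softcap(z, s):7.3f}  "
          f"softcap grad={softcap_grad(z, s):.2e}  hard-clip grad={hard_clip_grad(z, s)}")
```

Beyond the cap, hard clipping passes back no gradient at all, while the softcap gradient shrinks smoothly but stays nonzero.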
QK-Norm. Apply LayerNorm to query and key vectors before computing attention:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{\text{LN}(Q)\text{LN}(K)^T}{\sqrt{d_k}}\right)V$$
By normalizing $Q$ and $K$, the dot product $QK^T$ is bounded (each normalized vector has a fixed scale set by $\gamma$, so the product cannot grow with the raw activations), preventing attention logit explosion without any additional loss term. The drawback: LayerNorm strips the magnitude from $Q$ and $K$, and that magnitude can carry useful information (a large magnitude can signal high confidence). For long-context tasks, QK-norm has been reported to hurt performance by de-emphasizing relevant tokens - the normalization removes exactly the signal that distinguishes attended-to from ignored positions.
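A small numerical check of the bounding argument, in plain Python with hypothetical helpers. With $\gamma = 1$ and $\beta = 0$ (an assumption; a learnable $\gamma$ rescales the bound), each normalized vector has squared norm at most $d_k$, so by the Cauchy-Schwarz inequality the scaled logit is at most $\sqrt{d_k}$ in magnitude. Scaling $q$ up makes the raw logit explode while the QK-normed logit barely moves, which also illustrates the drawback: the magnitude information is gone.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1 and beta = 0."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def attention_logit(q, k, qk_norm=False):
    """Scaled dot-product logit q . k / sqrt(d_k), optionally normalizing q and k first."""
    if qk_norm:
        q, k = layer_norm(q), layer_norm(k)
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))


q = [1.0, -2.0, 0.5, 3.0]
k = [2.0, -1.0, 0.0, 4.0]
for scale in (1.0, 100.0, 10000.0):
    big_q = [scale * v for v in q]
    print(scale, attention_logit(big_q, k), attention_logit(big_q, k, qk_norm=True))
```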
Which to use. There is no settled consensus. Logit softcapping (Gemma 2-style) is effective, adds no loss term, and is well-understood. Z-loss is a simpler alternative that leaves the forward pass untouched. QK-norm bounds attention logits with a small change to the attention computation and no extra loss term, but given the concern above, evaluate it carefully before relying on it for long-context tasks.
Pre-Norm vs. Post-Norm
Where you apply normalization in the network architecture matters.
The original transformer (Vaswani et al. 2017) used post-norm: apply layer normalization after the residual addition.
$$x \to \text{LayerNorm}(x + \text{Sublayer}(x)).$$
This worked for the original transformer but was observed to cause instability when scaling to very deep models.
Pre-norm applies layer normalization before the sublayer (before attention or feedforward), inside the residual branch:
$$x \to x + \text{Sublayer}(\text{LayerNorm}(x)).$$
Pre-norm has become the standard in modern architectures (GPT-2 and later) because it produces more stable gradients when training very deep networks. The intuition: in post-norm, the residual stream itself passes through a LayerNorm after every block, so a gradient traveling from the output back to an early layer must pass through one normalization per layer, which can shrink or distort it. In pre-norm, the residual stream is never normalized, so its identity path gives gradients a direct route from the output to every layer.
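The two placements differ only in where the normalization sits relative to the residual addition. In this sketch (plain Python; `toy_sublayer` is a hypothetical stand-in for attention or the feedforward network, and the learnable LayerNorm parameters are omitted), the post-norm output is always re-standardized, while the pre-norm output is the untouched input plus an update.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1 and beta = 0."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def toy_sublayer(h):
    """Hypothetical stand-in for attention or the FFN: a fixed elementwise affine map."""
    return [0.5 * v + 0.1 for v in h]


def post_norm_block(x, sublayer):
    """Original transformer: x -> LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x, sublayer):
    """GPT-2 style: x -> x + Sublayer(LayerNorm(x)); the residual stream is never normalized."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


x = [1.0, 2.0, 3.0, 4.0]
post = post_norm_block(x, toy_sublayer)  # re-standardized: mean 0, variance about 1
pre = pre_norm_block(x, toy_sublayer)    # x plus an update; x itself passes straight through
```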
Group Norm and Instance Norm
For completeness: there are other normalization variants for specific use cases.
Group normalization (Wu & He 2018) divides channels into groups and normalizes within each group, over the group's channels and all spatial positions. Like layer norm, it is independent of batch size, but because each group gets its own statistics, it does not force every channel to share a single mean and variance, which suits image models better. It works well for object detection and segmentation tasks where batch sizes must be small due to memory constraints.
Instance normalization normalizes each channel of each sample independently, over that channel's spatial positions. It is used in style transfer networks, where removing instance-specific statistics (such as an image's contrast) helps transfer artistic styles between images.
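Group norm contains the other per-sample schemes as special cases, which this plain-Python sketch makes explicit (the helper `group_norm` is hypothetical and omits the learnable $\gamma$, $\beta$). A sample is stored as $C$ channels, each a flattened list of its $H \times W$ spatial values; one group gives layer norm over the whole sample, and $C$ groups give instance norm.

```python
import math


def group_norm(x, num_groups, eps=1e-5):
    """Group norm for ONE sample.

    x: list of C channels, each a flat list of H*W spatial values.
    Channels are split into num_groups contiguous groups; each group is normalized
    with the mean and variance of all values in its channels.
    num_groups = 1 -> layer norm over (C, H, W); num_groups = C -> instance norm.
    """
    C = len(x)
    if C % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    per_group = C // num_groups
    out = []
    for g in range(num_groups):
        channels = x[g * per_group:(g + 1) * per_group]
        values = [v for ch in channels for v in ch]
        mu = sum(values) / len(values)
        var = sum((v - mu) ** 2 for v in values) / len(values)
        denom = math.sqrt(var + eps)
        out.extend([[(v - mu) / denom for v in ch] for ch in channels])
    return out


# One sample: 4 channels of a 2x2 feature map, with very different scales per channel.
sample = [
    [1.0, 2.0, 3.0, 4.0],
    [10.0, 20.0, 30.0, 40.0],
    [0.1, 0.2, 0.1, 0.2],
    [-5.0, 5.0, -5.0, 5.0],
]
instance = group_norm(sample, num_groups=4)  # every channel gets its own statistics
layer = group_norm(sample, num_groups=1)     # all channels share one mean and variance
grouped = group_norm(sample, num_groups=2)   # in between: two groups of two channels
```

With a single group, the small-scale third channel is squashed to nearly a constant by the large second channel; per-channel or per-group statistics avoid that.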
Connection to Transformer Architecture
Every modern transformer uses either layer normalization or RMSNorm. Understanding normalization is not optional for reading the architecture literature.
The standard transformer block looks like this with pre-norm:
$$x' = x + \text{Attention}(\text{LayerNorm}(x)),$$ $$x'' = x' + \text{FFN}(\text{LayerNorm}(x')).$$
The layer norms ensure that the inputs to attention and feedforward layers are well-scaled, regardless of how large the residual stream activations have grown. Without normalization, each sublayer would receive inputs whose scale tracks the growing residual stream, and that growth would compound through the depth of the network.
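A toy illustration of that point, assuming a deliberately simple hypothetical sublayer whose output has mean 1, so every residual addition pushes the stream further from zero. After 100 pre-norm blocks the residual stream has drifted far from unit scale, yet the input each sublayer receives (its LayerNorm output) is still standardized.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1 and beta = 0."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def toy_sublayer(h):
    """Hypothetical stand-in for Attention or FFN; its output has mean 1 when h has mean 0."""
    return [2.0 * v + 1.0 for v in h]


x = [0.3, -1.2, 0.8, 2.0]
for _ in range(100):
    h = layer_norm(x)                                  # what the sublayer actually receives
    x = [a + b for a, b in zip(x, toy_sublayer(h))]    # pre-norm residual update

h = layer_norm(x)
print("residual stream:", [round(v, 1) for v in x])
print("sublayer input: ", [round(v, 3) for v in h])
```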
Training a 100-layer transformer without normalization is very difficult with standard optimizers and initialization. With pre-norm layer normalization, it is routine.
Summary
| Method | Normalizes Over | Works at Batch Size 1? | Same at Train and Inference? | Typical Use |
|---|---|---|---|---|
| BatchNorm | Batch (per feature or channel) | No | No (running stats) | CNNs, image models |
| LayerNorm | Features (per sample) | Yes | Yes | Transformers, LLMs |
| RMSNorm | Features (per sample), no mean centering | Yes | Yes | Llama, modern LLMs |
| GroupNorm | Groups of channels (per sample) | Yes | Yes | Detection, segmentation with small batches |
| InstanceNorm | Spatial positions (per channel, per sample) | Yes | Yes | Style transfer |
Batch normalization made training deep networks practical in 2015. Layer normalization made training deep sequence models practical from 2016 onward. RMSNorm is making training slightly more efficient now. The underlying problem - keeping activations well-scaled as signals flow through many layers - has not changed. The solutions are still evolving.