The Internal Covariate Shift Problem

Deep networks suffer from a fundamental training instability: as the parameters of early layers update, the distribution of inputs to later layers continuously shifts. Ioffe and Szegedy (2015) called this internal covariate shift: each layer must constantly re-adapt to a moving input distribution, which slows learning and forces very low learning rates. Normalization methods address this by standardizing layer inputs, stabilizing the distributions that each layer receives.

Batch Normalization

Given a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$ of activations at a given layer, batch normalization standardizes over the batch dimension. For each feature index $j$:

$$\mu_j^{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^m x_{ij}, \qquad (\sigma_j^{\mathcal{B}})^2 = \frac{1}{m}\sum_{i=1}^m (x_{ij} - \mu_j^{\mathcal{B}})^2$$

The normalized activation is:

$$\hat{x}_{ij} = \frac{x_{ij} - \mu_j^{\mathcal{B}}}{\sqrt{(\sigma_j^{\mathcal{B}})^2 + \epsilon}}$$

where $\epsilon$ is a small constant for numerical stability. This would fix activations to zero mean and unit variance, which can hurt representational power - so two learnable parameters $\gamma_j$ and $\beta_j$ rescale and shift:

$$y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j$$

In particular, with $\gamma_j = \sigma_j^{\mathcal{B}}$ and $\beta_j = \mu_j^{\mathcal{B}}$ the network can recover the original activations, so it can learn to undo the normalization whenever that is beneficial - BatchNorm incurs no representational cost.
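A minimal NumPy sketch of the training-mode computation above (the function and variable names here are illustrative, not from any particular library):

```python
import numpy as np

def batchnorm_train(x, gamma, beta, eps=1e-5):
    # x: (m, d) mini-batch; gamma, beta: (d,) learnable scale and shift
    mu = x.mean(axis=0)                    # per-feature batch mean mu_j
    var = x.var(axis=0)                    # per-feature batch variance sigma_j^2
    x_hat = (x - mu) / np.sqrt(var + eps)  # standardize over the batch dimension
    y = gamma * x_hat + beta               # learnable rescale and shift
    return y, mu, var                      # stats are returned for the running averages

# e.g. a batch of 32 examples with 8 features
y, mu, var = batchnorm_train(np.random.randn(32, 8), np.ones(8), np.zeros(8))
```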

Inference Behavior

At test time, single examples cannot form a batch statistic. BatchNorm maintains running estimates of the mean and variance using an exponential moving average over training mini-batches:

$$\mu_j^{\text{run}} \leftarrow (1 - \alpha)\,\mu_j^{\text{run}} + \alpha\,\mu_j^{\mathcal{B}}, \qquad (\sigma_j^{\text{run}})^2 \leftarrow (1 - \alpha)\,(\sigma_j^{\text{run}})^2 + \alpha\,(\sigma_j^{\mathcal{B}})^2$$

where $\alpha$ is the momentum of the moving average.

These running statistics are frozen at inference, making the transformation a simple affine map per feature - efficient and deterministic.
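Continuing the sketch above, the running-statistic update and the inference-time affine map might look like this (again, names are illustrative):

```python
def update_running_stats(run_mu, run_var, mu_b, var_b, alpha=0.1):
    # EMA update from the equation above, applied after each training batch
    return (1 - alpha) * run_mu + alpha * mu_b, (1 - alpha) * run_var + alpha * var_b

def batchnorm_eval(x, gamma, beta, run_mu, run_var, eps=1e-5):
    # With frozen statistics, BatchNorm collapses to one affine map per feature:
    # y = scale * x + bias, where scale and bias can be precomputed once.
    scale = gamma / np.sqrt(run_var + eps)
    return scale * x + (beta - scale * run_mu)
```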

Limitations

BatchNorm’s dependence on the batch mean and variance is its Achilles' heel. For small batch sizes, the batch statistics are noisy estimates of the true statistics, destabilizing training. It also behaves differently at train and test time, which can cause subtle bugs. It is ineffective for recurrent networks (statistics differ at each time step) and for online learning or reinforcement learning where batch sizes of 1 are common.

Layer Normalization

Layer normalization (Ba et al., 2016) normalizes over the feature dimension rather than the batch dimension. For a single example $x \in \mathbb{R}^d$:

$$\mu = \frac{1}{d}\sum_{j=1}^d x_j, \qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^d (x_j - \mu)^2$$

$$\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y_j = \gamma_j \hat{x}_j + \beta_j$$
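A minimal sketch of this computation for a single example, assuming NumPy and illustrative names:

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    # x: (d,) single example; statistics are computed over its features
    mu = x.mean()                          # scalar mean over the d features
    var = x.var()                          # scalar variance over the d features
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

y = layernorm(np.random.randn(16), np.ones(16), np.zeros(16))
```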

Each sample is normalized independently - there is no cross-sample dependence. This makes LayerNorm:

  • Identical at train and test time (no running statistics needed)
  • Effective for batch size 1
  • The standard choice for Transformers, where sequences vary in length and batch statistics would be noisy

Group Normalization and Instance Normalization

Group normalization (Wu and He, 2018) is a compromise between BatchNorm and LayerNorm. It divides the $d$ features into $G$ groups and normalizes within each group independently. When $G = 1$ it becomes LayerNorm; when $G = d$ it becomes Instance Normalization.

Instance normalization normalizes over the spatial dimensions $(H, W)$ for each sample and each channel independently - commonly used in style transfer, where style statistics should be decoupled per sample.
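A sketch of GroupNorm for an $(N, C, H, W)$ tensor, assuming NumPy, per-channel affine parameters, and that $C$ is divisible by $G$ (names are illustrative):

```python
def groupnorm(x, gamma, beta, G, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (C,); G must divide C
    N, C, H, W = x.shape
    xg = x.reshape(N, G, C // G, H, W)
    mu = xg.mean(axis=(2, 3, 4), keepdims=True)   # one mean per (sample, group)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((xg - mu) / np.sqrt(var + eps)).reshape(N, C, H, W)
    return gamma.reshape(1, C, 1, 1) * x_hat + beta.reshape(1, C, 1, 1)

# G=1 normalizes over (C, H, W) like LayerNorm; G=C normalizes over (H, W)
# per channel, like InstanceNorm.
```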

RMSNorm

Root Mean Square Layer Normalization (Zhang and Sennrich, 2019) simplifies LayerNorm by removing the mean subtraction step:

$$\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{j=1}^d x_j^2}$$

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma$$

This re-centering-free formulation is slightly cheaper to compute (one fewer sum and one fewer subtraction), and empirical evidence suggests that the mean-subtraction step in LayerNorm contributes little to training stability. RMSNorm is the normalization layer used in LLaMA and many recent large language models. Its parameter count is half that of LayerNorm since $\beta$ is dropped, though this is rarely the bottleneck.
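A minimal sketch, placing $\epsilon$ inside the square root as common implementations do (names are illustrative):

```python
def rmsnorm(x, gamma, eps=1e-6):
    # x: (d,) single example; no mean subtraction and no beta parameter
    rms = np.sqrt(np.mean(x ** 2) + eps)
    return (x / rms) * gamma
```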

Why Normalization Helps

Loss landscape smoothing. Normalization makes the loss surface significantly smoother - formally, it improves the Lipschitzness of the loss and its gradients (Santurkar et al., 2018), so the curvature is more uniform and the gradient remains a reliable descent direction over larger steps. This allows much larger learning rates without divergence.

Higher learning rates. Without normalization, large learning rates can cause activations to saturate or diverge due to the compounding of changes across layers. Normalization decouples the scale of weights from the scale of activations, letting the optimizer use aggressive step sizes.

Slight regularization effect. BatchNorm adds noise to gradients because batch statistics are stochastic estimates. This acts as a weak regularizer, similar in spirit to dropout. This effect vanishes for large batch sizes, explaining why models sometimes generalize worse with very large batches.

Pre-Norm vs Post-Norm in Transformers

The original “Attention Is All You Need” paper uses Post-Norm: LayerNorm is applied after the residual addition, i.e., $x \leftarrow \text{LayerNorm}(x + \text{Sublayer}(x))$.

Modern large language models overwhelmingly use Pre-Norm: LayerNorm is applied before the sublayer, i.e., $x \leftarrow x + \text{Sublayer}(\text{LayerNorm}(x))$.

The practical advantage of Pre-Norm is significant: gradients flow through the residual branch without passing through LayerNorm, giving a direct gradient path from the output to the input. This makes it possible to train much deeper networks without depending critically on warm-up schedules or gradient clipping. Post-Norm, by contrast, places the normalization in the gradient path, which can destabilize early training and requires careful learning rate warm-up.
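Schematically, the two orderings differ only in where the normalization sits relative to the residual connection (`sublayer` stands for attention or an MLP; the function names are illustrative):

```python
def post_norm_block(x, sublayer, ln):
    # original Transformer: normalize after the residual addition
    return ln(x + sublayer(x))

def pre_norm_block(x, sublayer, ln):
    # modern LLMs: the residual path x + ... bypasses the normalization entirely
    return x + sublayer(ln(x))
```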

Examples

Training stability without vs. with BatchNorm. Training a deep convolutional network (20+ layers) without normalization typically requires very small learning rates (e.g., $10^{-4}$) and careful weight initialization. With BatchNorm, the same network trains stably at learning rates 10-100x larger ($10^{-3}$ to $10^{-2}$), converges faster in wall-clock time, and is less sensitive to initialization. The loss curve without BatchNorm is erratic with occasional spikes; with BatchNorm it is smooth and steadily decreasing.

Why transformers switched to Pre-LN. The original BERT and GPT models used Post-Norm and required careful learning rate warm-up (typically 4,000-10,000 steps of linear warm-up) to avoid early divergence. When GPT-2 and later models adopted Pre-Norm, training stability improved substantially - models could be trained without warm-up or with very short warm-up, and scaling to hundreds of layers became feasible. The intuition is that Pre-Norm ensures the residual stream always has a well-conditioned gradient path, regardless of depth.

