Neural Networks as Function Compositions

A feedforward neural network with $L$ layers is a composition of functions:

$$f = f_L \circ f_{L-1} \circ \cdots \circ f_1$$

Each layer $f_l: \mathbb{R}^{n_{l-1}} \to \mathbb{R}^{n_l}$ typically takes the form $f_l(a) = \sigma(W^{(l)} a + b^{(l)})$, where $W^{(l)} \in \mathbb{R}^{n_l \times n_{l-1}}$ is the weight matrix, $b^{(l)} \in \mathbb{R}^{n_l}$ is the bias, and $\sigma$ is a nonlinearity applied elementwise.

The training objective for a single example $(x, y)$ is

$$\mathcal{L}(\theta) = \ell(f(x;\theta),\, y)$$

where $\ell$ is a loss function (cross-entropy, MSE, etc.) and $\theta = (W^{(1)}, b^{(1)}, \ldots, W^{(L)}, b^{(L)})$ collects all parameters. Training requires $\nabla_\theta \mathcal{L}$, and computing this efficiently is the entire point of backpropagation.
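
As a concrete sketch of the composition view, the following NumPy snippet builds a two-layer network as an explicit composition of layer functions. The layer sizes and the tanh activation are arbitrary choices for illustration, not anything prescribed above.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_layer(W, b, sigma=np.tanh):
    # f_l(a) = sigma(W a + b)
    return lambda a: sigma(W @ a + b)

# Illustrative sizes: input 4 -> hidden 8 -> output 3.
W1, b1 = rng.standard_normal((8, 4)), np.zeros(8)
W2, b2 = rng.standard_normal((3, 8)), np.zeros(3)

f1 = make_layer(W1, b1)
f2 = make_layer(W2, b2, sigma=lambda z: z)   # linear output layer

f = lambda x: f2(f1(x))                      # f = f_2 ∘ f_1

x = rng.standard_normal(4)
print(f(x).shape)                            # (3,)
```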


The Jacobian of a Composition

Definition. For a differentiable function $F: \mathbb{R}^m \to \mathbb{R}^k$, the Jacobian at $x \in \mathbb{R}^m$ is the $k \times m$ matrix

$$[J_F(x)]_{ij} = \frac{\partial F_i}{\partial x_j}(x)$$

Theorem (Chain Rule for Jacobians). If $g: \mathbb{R}^m \to \mathbb{R}^n$ and $f: \mathbb{R}^n \to \mathbb{R}^k$ are differentiable, then

$$J_{f \circ g}(x) = J_f(g(x)) \cdot J_g(x)$$

where the right-hand side is matrix multiplication of a $k \times n$ matrix with an $n \times m$ matrix.

For a network with layers $a^{(0)} = x,\; a^{(l)} = f_l(a^{(l-1)})$, the Jacobian of the full network output with respect to the input is

$$J_f(x) = J_{f_L}(a^{(L-1)}) \cdot J_{f_{L-1}}(a^{(L-2)}) \cdots J_{f_1}(a^{(0)})$$

a product of $L$ Jacobians. Computing $\nabla_\theta \mathcal{L}$ requires differentiating through this product with respect to the entries of each weight matrix.
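
The chain-rule identity can be checked numerically. A small JAX sketch, where g and f are arbitrary stand-in functions (not the network layers defined above):

```python
import jax
import jax.numpy as jnp

def g(x):                      # R^3 -> R^4
    return jnp.tanh(jnp.arange(12.0).reshape(4, 3) @ x)

def f(y):                      # R^4 -> R^2
    return jnp.sin(jnp.arange(8.0).reshape(2, 4) @ y)

x = jnp.array([0.1, -0.2, 0.3])

# Jacobian of the composition vs. product of per-function Jacobians.
J_comp = jax.jacfwd(lambda x: f(g(x)))(x)           # 2 x 3
J_chain = jax.jacfwd(f)(g(x)) @ jax.jacfwd(g)(x)    # (2 x 4) @ (4 x 3)

print(jnp.allclose(J_comp, J_chain, atol=1e-5))     # True
```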


Automatic Differentiation: Two Modes

There are two fundamental algorithms for computing Jacobians of compositions. Both are correct implementations of the chain rule; they differ only in the order of multiplications.

Forward Mode AD

Propagate a tangent vector $v \in \mathbb{R}^n$ (where $n$ is the input dimension) forward through the network. At each layer, compute

$$\dot{a}^{(l)} = J_{f_l}(a^{(l-1)})\,\dot{a}^{(l-1)}$$

where $\dot{a}^{(0)} = v$. The result $\dot{a}^{(L)}$ is the Jacobian-vector product $J_f(x)\,v$: the $j$-th column of $J_f$ if $v = e_j$. Computing the full Jacobian requires $n$ forward passes.

Cost: $O(n)$ passes, each $O(L \cdot d^2)$ operations. Total: $O(nLd^2)$, where $d$ is the typical layer width.
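
A brief JAX sketch of this column-by-column construction; the network f below is a toy stand-in, and jax.jvp returns the pair (f(x), J_f(x) v):

```python
import jax
import jax.numpy as jnp

def f(x):                                   # toy network, R^3 -> R^2
    h = jnp.tanh(jnp.arange(12.0).reshape(4, 3) @ x)
    return jnp.arange(8.0).reshape(2, 4) @ h

x = jnp.array([0.5, -1.0, 2.0])
n = x.shape[0]

# One JVP per input dimension: column j of the Jacobian is J_f(x) e_j.
cols = [jax.jvp(f, (x,), (jnp.eye(n)[j],))[1] for j in range(n)]
J = jnp.stack(cols, axis=1)                 # 2 x 3

print(jnp.allclose(J, jax.jacfwd(f)(x)))    # True
```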

Reverse Mode AD (Backpropagation)

Propagate a cotangent vector $u \in \mathbb{R}^k$ (where $k$ is the output dimension) backward. At each layer, compute

$$\bar{a}^{(l-1)} = (J_{f_l}(a^{(l-1)}))^T\,\bar{a}^{(l)}$$

where $\bar{a}^{(L)} = u$. The result $\bar{a}^{(0)}$ is the vector-Jacobian product $u^T J_f(x)$: one row of $J_f$ if $u = e_i$.

For $\mathcal{L}: \mathbb{R}^n \to \mathbb{R}$ (scalar loss), setting $u = 1$ gives $\nabla_x \mathcal{L} = J_\mathcal{L}(x)^T$ in a single backward pass.

Cost: $O(1)$ backward passes for the gradient of a scalar, each $O(L \cdot d^2)$ operations. Total: $O(Ld^2)$.
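
A minimal JAX sketch of the scalar case: one VJP with $u = 1$ recovers the full gradient in a single backward pass, and it agrees with jax.grad. The quadratic-style loss below is an arbitrary example:

```python
import jax
import jax.numpy as jnp

def loss(x):                                  # R^n -> R
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.arange(1.0, 6.0)                      # n = 5

value, vjp_fn = jax.vjp(loss, x)              # one forward pass, records the graph
(grad,) = vjp_fn(1.0)                         # one backward pass with u = 1

print(jnp.allclose(grad, jax.grad(loss)(x)))  # True
```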

The comparison is stark: when differentiating a scalar loss with respect to millions of parameters (so the relevant input dimension $n$ is in the millions), reverse mode is roughly $n$ times cheaper than forward mode. This is why backpropagation is the algorithm of choice for neural network training.


The Backpropagation Algorithm

Forward pass. For $l = 1, \ldots, L$:

$$z^{(l)} = W^{(l)} a^{(l-1)} + b^{(l)}, \qquad a^{(l)} = \sigma(z^{(l)})$$

Store the pre-activations $z^{(l)}$ and activations $a^{(l)}$ for use in the backward pass.

Output error. Compute the gradient of the loss with respect to the final pre-activation:

$$\delta^{(L)} = \nabla_{a^{(L)}} \ell \odot \sigma'(z^{(L)})$$

(Here $\odot$ denotes elementwise multiplication. If $\ell$ is computed directly from the logits $z^{(L)}$, as with softmax cross-entropy and a one-hot target $y$, the combined expression simplifies to $\delta^{(L)} = \mathrm{softmax}(z^{(L)}) - y$.)

Backward pass. For $l = L-1, \ldots, 1$:

$$\delta^{(l)} = \left((W^{(l+1)})^T \delta^{(l+1)}\right) \odot \sigma'(z^{(l)})$$

The term $(W^{(l+1)})^T \delta^{(l+1)}$ is the transpose-Jacobian product, propagating the error signal back through the linear part. The $\odot \sigma'(z^{(l)})$ accounts for the elementwise nonlinearity.

Parameter gradients:

$$\nabla_{W^{(l)}} \mathcal{L} = \delta^{(l)} (a^{(l-1)})^T, \qquad \nabla_{b^{(l)}} \mathcal{L} = \delta^{(l)}$$

Data flow diagram:

Forward pass:
a^(0) -> [W^(1), b^(1)] -> z^(1) -> sigma -> a^(1) -> ... -> a^(L) -> loss

Backward pass:
delta^(L) <- grad loss
delta^(l) <- (W^(l+1))^T delta^(l+1) * sigma'(z^(l))
grad W^(l) <- delta^(l) (a^(l-1))^T
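
The algorithm fits in a few dozen lines. A compact NumPy sketch for an MLP with tanh activations and a squared-error loss (the architecture and loss are illustrative choices; the derivation above is not tied to them):

```python
import numpy as np

def sigma(z):  return np.tanh(z)
def dsigma(z): return 1.0 - np.tanh(z) ** 2

def backprop(Ws, bs, x, y):
    """Return (dWs, dbs) for L(theta) = 0.5 * ||a_L - y||^2."""
    # Forward pass: store pre-activations z^(l) and activations a^(l).
    a, zs, activations = x, [], [x]
    for W, b in zip(Ws, bs):
        z = W @ a + b
        a = sigma(z)
        zs.append(z)
        activations.append(a)

    # Output error: grad of the loss w.r.t. a^(L) is (a_L - y).
    delta = (activations[-1] - y) * dsigma(zs[-1])

    dWs, dbs = [None] * len(Ws), [None] * len(bs)
    for l in reversed(range(len(Ws))):
        dWs[l] = np.outer(delta, activations[l])      # delta^(l) (a^(l-1))^T
        dbs[l] = delta
        if l > 0:                                     # propagate error to the previous layer
            delta = (Ws[l].T @ delta) * dsigma(zs[l - 1])
    return dWs, dbs

# Tiny example: 3 -> 4 -> 2 network.
rng = np.random.default_rng(0)
Ws = [rng.standard_normal((4, 3)), rng.standard_normal((2, 4))]
bs = [np.zeros(4), np.zeros(2)]
dWs, dbs = backprop(Ws, bs, rng.standard_normal(3), np.array([1.0, 0.0]))
print([dW.shape for dW in dWs])   # [(4, 3), (2, 4)]
```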

Theorem (Correctness of Reverse Mode AD). The backpropagation algorithm computes $\nabla_\theta \mathcal{L}$ exactly (in exact arithmetic). This follows directly from the chain rule for Jacobians applied in reverse order.


Gradient Flow and Vanishing/Exploding Gradients

The backward pass propagates $\delta^{(l)}$ through a product of layer Jacobians; unrolling the recursion gives

$$\delta^{(1)} = \operatorname{diag}(\sigma'(z^{(1)}))\,(W^{(2)})^T\,\operatorname{diag}(\sigma'(z^{(2)}))\,(W^{(3)})^T \cdots \operatorname{diag}(\sigma'(z^{(L-1)}))\,(W^{(L)})^T\,\delta^{(L)}$$

The behavior of this product depends on the singular values of each factor:

  • If $\|J_{f_l}\|_{\text{op}} < 1$ for all $l$, the product shrinks exponentially with depth: vanishing gradients. Parameters in early layers receive negligible gradient signal and fail to learn.
  • If the smallest singular value of every $J_{f_l}$ exceeds 1, the product grows exponentially with depth: exploding gradients. Training becomes numerically unstable.

Gradient norm across layers (illustration):

Layer: L  L-1  L-2  ...  2    1

Vanishing:
|grad|: 1  0.5  0.25 ...  ~0   ~0   <- early layers learn nothing

Exploding:
|grad|: 1   2    4   ...  2^L  2^L  <- instability, NaN
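
The effect is easy to reproduce numerically. A NumPy sketch with random Gaussian weights at two different scales (the depth, width, and tanh nonlinearity are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)
depth, width = 50, 64

def backward_norms(weight_scale):
    """Norm of delta after each backward step through tanh layers."""
    delta = rng.standard_normal(width)
    norms = [np.linalg.norm(delta)]
    for _ in range(depth):
        W = weight_scale * rng.standard_normal((width, width)) / np.sqrt(width)
        z = rng.standard_normal(width)                    # stand-in pre-activation
        delta = (W.T @ delta) * (1.0 - np.tanh(z) ** 2)   # one backward step
        norms.append(np.linalg.norm(delta))
    return norms

print(backward_norms(0.5)[-1])   # shrinks toward 0: vanishing
print(backward_norms(3.0)[-1])   # grows by orders of magnitude: exploding
```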

This analysis motivates:

  • ReLU activations: $\sigma'(z) = \mathbf{1}[z>0]$, which does not saturate for positive inputs.
  • Residual connections (ResNets): $f_l(a) = a + g_l(a)$, so the layer Jacobian is $I + J_{g_l}$, giving the gradient an identity path through every layer.
  • Gradient clipping: cap $\|\nabla_\theta \mathcal{L}\|$ (or the per-layer $\|\delta^{(l)}\|$) at a threshold before the update.
  • Careful initialization: Xavier/He initialization scales the initial weights so the layer Jacobians have singular values near 1.

Batch Gradients and Microbatching

In practice, the loss is computed over a minibatch $\{(x_i, y_i)\}_{i=1}^B$:

$$\mathcal{L}_{\text{batch}}(\theta) = \frac{1}{B}\sum_{i=1}^B \ell(f(x_i;\theta),\, y_i)$$

By linearity of differentiation, $\nabla_\theta \mathcal{L}_{\text{batch}} = \frac{1}{B}\sum_i \nabla_\theta \ell_i$. Each example contributes an independent backward pass; in practice these are computed simultaneously by treating the batch as a matrix and leveraging BLAS routines.
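
This linearity can be verified directly. A JAX sketch in which per-example gradients (via vmap) are averaged and compared against the gradient of the mean loss; the linear model and squared error below are placeholders:

```python
import jax
import jax.numpy as jnp

def per_example_loss(theta, x, y):
    return (jnp.dot(theta, x) - y) ** 2          # squared error for one example

def batch_loss(theta, X, Y):
    return jnp.mean(jax.vmap(per_example_loss, in_axes=(None, 0, 0))(theta, X, Y))

key = jax.random.PRNGKey(0)
theta = jnp.ones(5)
X = jax.random.normal(key, (8, 5))               # batch of B = 8 examples
Y = jnp.zeros(8)

# Mean of per-example gradients equals the gradient of the mean loss.
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(theta, X, Y)
print(jnp.allclose(per_example_grads.mean(axis=0), jax.grad(batch_loss)(theta, X, Y)))  # True
```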

Gradient accumulation (microbatching) processes sub-batches of size $B/k$ and accumulates gradients before applying an optimizer step. This achieves the effect of a large batch size with limited GPU memory, since only the activations for $B/k$ examples need to be stored at once.
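
A PyTorch-style sketch of the accumulation loop. The tiny linear model and random microbatches stand in for a real data pipeline; only the accumulate-then-step pattern is the point:

```python
import torch
from torch import nn

# Toy setup standing in for a real model and dataloader.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
microbatches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

accum_steps = 4                                  # k microbatches per optimizer step
optimizer.zero_grad()
for step, (x, y) in enumerate(microbatches):
    loss = loss_fn(model(x), y) / accum_steps    # scale so accumulated grads average correctly
    loss.backward()                              # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                         # one update per accum_steps microbatches
        optimizer.zero_grad()
```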


Examples

PyTorch Autograd. Every tensor operation records itself in a computation graph (a directed acyclic graph of Function nodes). Calling .backward() on a scalar traverses this graph in reverse topological order, applying the vector-Jacobian product at each node. This is reverse mode AD implemented dynamically (eagerly).
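
A minimal illustration of the dynamic graph (the squared-sum function is arbitrary):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()        # each op adds a node to the autograd graph
y.backward()              # reverse-mode sweep over that graph
print(x.grad)             # tensor([2., 4., 6.]) = d(sum x^2)/dx
```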

JAX jvp and vjp. JAX exposes forward mode as jax.jvp(f, primals, tangents) and reverse mode as jax.vjp(f, *primals). Composing these gives higher-order derivatives: jax.jacfwd(jax.jacrev(f)) gives the Hessian efficiently. JAX uses functional transformations: the AD rules are composable with vmap (vectorization), pmap (parallelization), and jit (XLA compilation).
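
For instance, composing the two modes yields a Hessian in forward-over-reverse style; the scalar function here is a toy example:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])
hess = jax.jacfwd(jax.jacrev(f))(x)   # forward-over-reverse Hessian
print(hess)                           # diag(6 * x) = [[6, 0], [0, 12]]
```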

XLA and Compiler Optimization. Under jit, JAX traces the computation to an XLA computation graph, which the XLA compiler can fuse, tile, and optimize for GPU/TPU hardware. The backward pass is generated symbolically by differentiating through the XLA graph, allowing cross-operation fusion that is impossible in eager mode.
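
A small sketch of the usual pattern; the least-squares loss is a placeholder:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))   # trace once, compile forward + backward with XLA
w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
print(grad_fn(w, x, y))
```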

Gradient Checkpointing. The backward pass requires the stored activations $a^{(l)}$ from the forward pass. For deep networks, storing all activations is memory-prohibitive. Gradient checkpointing recomputes activations during the backward pass (trading compute for memory), reducing memory from $O(L)$ to $O(\sqrt{L})$ at the cost of roughly one additional forward pass.
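
In JAX this is exposed as jax.checkpoint (a.k.a. jax.remat). A sketch wrapping each block of a toy network so its internal activations are recomputed rather than stored; the block itself is an arbitrary stand-in for an expensive segment:

```python
import jax
import jax.numpy as jnp

def block(x):
    # Stand-in for one expensive segment of the network.
    for _ in range(3):
        x = jnp.tanh(x)
    return x

def net(x):
    for _ in range(4):
        x = jax.checkpoint(block)(x)   # activations inside block are recomputed on the backward pass
    return jnp.sum(x)

x = jnp.ones(16)
print(jax.grad(net)(x).shape)          # (16,): same gradient, less stored memory
```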

