ResNets Are Discretized ODEs

A standard residual network (ResNet) updates its hidden state as

$$\mathbf{x}_{t+1} = \mathbf{x}_t + f(\mathbf{x}_t,\, \theta_t)$$

where $\mathbf{x}_t \in \mathbb{R}^d$ is the hidden state at layer $t$ and $f(\cdot, \theta_t)$ is the residual block at layer $t$. Compare this to Euler’s method applied to the ODE

$$\dot{\mathbf{x}}(t) = f(\mathbf{x}(t),\, \theta(t))$$

with step size $h = 1$:

$$\mathbf{x}(t + h) \approx \mathbf{x}(t) + h\,f(\mathbf{x}(t),\, \theta(t)).$$

A ResNet with $L$ layers corresponds to $L$ Euler steps on a vector field $f$. This observation - that depth in ResNets mimics time in a dynamical system - motivated the Neural ODE framework.
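Here is a minimal pure-Python sketch of the depth-as-time correspondence. It uses a hypothetical scalar residual block $f(x) = -x$ that has no learned parameters. Each residual update is scaled by $h = T/L$, so stacking more layers should converge to the ODE solution $x(T) = x(0)e^{-T}$. The $T/L$ scaling is an assumption made for illustration; a plain ResNet corresponds to $h = 1$.

```python
import math

def residual_block(x):
    # Hypothetical residual block f(x) = -x (no learned parameters).
    return -x

def resnet_forward(x0, num_layers, T=1.0):
    """Apply num_layers residual updates x <- x + h * f(x), with h = T / num_layers.

    Each layer is exactly one forward-Euler step on dx/dt = f(x).
    """
    h = T / num_layers
    x = x0
    for _ in range(num_layers):
        x = x + h * residual_block(x)
    return x

exact = math.exp(-1.0)  # solution of dx/dt = -x, x(0) = 1, at T = 1
for L in (1, 10, 100, 1000):
    approx = resnet_forward(1.0, L)
    print(L, approx, abs(approx - exact))  # error shrinks roughly like 1/L
```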


Neural ODEs

Chen et al. (2018) proposed replacing the discrete layer stack entirely with a continuous ODE:

$$\frac{d\mathbf{h}(t)}{dt} = f_\theta(\mathbf{h}(t),\, t), \quad \mathbf{h}(0) = \mathbf{x}$$

The output of the network is $\mathbf{h}(T)$ for some terminal time $T$. It is obtained by solving this initial value problem (IVP) with any black-box ODE solver. The number of “effective layers” is not fixed in advance: an adaptive solver chooses how many function evaluations to use.

This yields continuous-depth networks: instead of $L$ discrete transformations, we have a flow in $\mathbb{R}^d$ parameterized by a neural network $f_\theta$.

The Adjoint Method for Backpropagation

Training requires computing $\frac{\partial \mathcal{L}}{\partial \theta}$ where $\mathcal{L}$ is a scalar loss on $\mathbf{h}(T)$. Naively differentiating through all solver steps requires storing intermediate states - $O(L)$ memory for an $L$-step solver. The adjoint method avoids this.

Define the adjoint $\lambda(t) = \frac{\partial \mathcal{L}}{\partial \mathbf{h}(t)}$. By differentiating the ODE constraint, $\lambda$ satisfies its own ODE run backwards in time:

$$\frac{d\lambda}{dt} = -\lambda(t)^T \frac{\partial f}{\partial \mathbf{h}}(\mathbf{h}(t),\, \theta,\, t).$$

This is the adjoint ODE. The boundary condition is $\lambda(T) = \frac{\partial \mathcal{L}}{\partial \mathbf{h}(T)}$.

The gradient with respect to parameters is accumulated along the backward pass:

$$\frac{\partial \mathcal{L}}{\partial \theta} = -\int_T^0 \lambda(t)^T \frac{\partial f}{\partial \theta}(\mathbf{h}(t),\, \theta,\, t)\,dt.$$

The key insight is that a single backward pass solves three coupled ODEs from $T$ to $0$:

- one that recomputes $\mathbf{h}(t)$ backward in time, starting from $\mathbf{h}(T)$;
- one for $\lambda(t)$;
- one accumulating $\frac{\partial \mathcal{L}}{\partial \theta}$.

Memory cost is therefore $O(1)$ in the number of solver steps. Only the current augmented state is stored, not the entire trajectory.
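The sketch below runs this backward pass on a toy problem with a closed-form gradient: the scalar ODE $dh/dt = \theta h$ with loss $\mathcal{L} = h(T)$. Its solution is $h(T) = x e^{\theta T}$, so $\partial \mathcal{L}/\partial\theta = xTe^{\theta T}$.

- The forward pass keeps only $h(T)$.
- The backward pass integrates the augmented state $[h, \lambda, g]$ from $T$ down to $0$.
- The extra component $g$ starts at $g(T) = 0$ and accumulates the parameter-gradient integral.

Fixed-step RK4 and the helper names are assumptions for illustration. This is not how torchdiffeq implements the adjoint internally.

```python
import math

def rk4_step(F, t, y, dt):
    """One classical RK4 step for a system stored as a list of floats."""
    k1 = F(t, y)
    k2 = F(t + dt / 2, [yi + dt / 2 * ki for yi, ki in zip(y, k1)])
    k3 = F(t + dt / 2, [yi + dt / 2 * ki for yi, ki in zip(y, k2)])
    k4 = F(t + dt, [yi + dt * ki for yi, ki in zip(y, k3)])
    return [yi + dt / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]

def adjoint_gradient(x, theta, T=1.0, n_steps=100):
    dt = T / n_steps

    # Forward pass: integrate dh/dt = theta * h, keeping only the current state.
    def f(t, y):
        return [theta * y[0]]

    y = [x]
    for n in range(n_steps):
        y = rk4_step(f, n * dt, y, dt)
    h_T = y[0]

    # Backward pass on s = [h, lam, g], from t = T down to t = 0:
    #   dh/dt   = theta * h                      (recompute h(t) backward from h(T))
    #   dlam/dt = -lam * df/dh     = -lam * theta
    #   dg/dt   = -lam * df/dtheta = -lam * h    (g(T) = 0, so g(0) = dL/dtheta)
    def augmented(t, s):
        h, lam, g = s
        return [theta * h, -lam * theta, -lam * h]

    s = [h_T, 1.0, 0.0]  # lam(T) = dL/dh(T) = 1 because L = h(T)
    for n in range(n_steps):
        s = rk4_step(augmented, T - n * dt, s, -dt)  # negative step: backward in time
    h_0, lam_0, dL_dtheta = s
    return h_T, h_0, lam_0, dL_dtheta

h_T, h_0, lam_0, grad = adjoint_gradient(0.5, 0.8)
print(grad, 0.5 * 1.0 * math.exp(0.8))  # adjoint result vs closed form x*T*e^(theta*T)
```

The recovered $h(0)$ should match the initial value $x$. Comparing them is a cheap sanity check on the backward re-integration.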


ODE Solvers as Layers

Euler’s Method

$$\mathbf{y}_{n+1} = \mathbf{y}_n + h\,f(t_n,\, \mathbf{y}_n)$$

Theorem (Euler Error). The global truncation error of Euler’s method is $O(h)$. Suppose $f$ is Lipschitz in $\mathbf{y}$ and the exact solution has a bounded second derivative on $[t_0, T]$. Then $\|\mathbf{y}_n - \mathbf{y}(t_n)\| \leq C h$ for all $t_n \in [t_0, T]$, where $C$ depends on the problem and on $T$ but not on $h$.

One function evaluation per step. Simple but requires small $h$ for accuracy.
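The following minimal sketch implements Euler's method and checks the $O(h)$ theorem empirically. The test problem $y' = y$, $y(0) = 1$ (exact $y(1) = e$) is chosen here for illustration. Halving $h$ should roughly halve the global error.

```python
import math

def euler(f, t0, y0, T, n_steps):
    """Forward Euler: y_{n+1} = y_n + h f(t_n, y_n), one f-evaluation per step."""
    h = (T - t0) / n_steps
    y = y0
    for n in range(n_steps):
        y = y + h * f(t0 + n * h, y)
    return y

def f(t, y):
    return y  # dy/dt = y, exact solution e^t

err_h = abs(euler(f, 0.0, 1.0, 1.0, 100) - math.e)    # h = 0.01
err_h2 = abs(euler(f, 0.0, 1.0, 1.0, 200) - math.e)   # h = 0.005
print(err_h, err_h2, err_h / err_h2)  # ratio close to 2 for a first-order method
```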

Runge-Kutta 4 (RK4)

RK4 takes four function evaluations per step:

$$k_1 = f(t_n,\, \mathbf{y}_n)$$
$$k_2 = f\!\left(t_n + \tfrac{h}{2},\, \mathbf{y}_n + \tfrac{h}{2}k_1\right)$$
$$k_3 = f\!\left(t_n + \tfrac{h}{2},\, \mathbf{y}_n + \tfrac{h}{2}k_2\right)$$
$$k_4 = f\!\left(t_n + h,\, \mathbf{y}_n + h\,k_3\right)$$
$$\mathbf{y}_{n+1} = \mathbf{y}_n + \frac{h}{6}(k_1 + 2k_2 + 2k_3 + k_4)$$

Theorem (RK4 Error). RK4 is a fourth-order method. For sufficiently smooth $f$, the global truncation error satisfies $\|\mathbf{y}_n - \mathbf{y}(t_n)\| = O(h^4)$.

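Compare this with Euler's $O(h)$. Halving $h$ cuts the RK4 error by roughly $2^4 = 16$, but cuts the Euler error only by 2. When high accuracy is needed, this difference usually outweighs the cost of four function evaluations per step.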
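Running the same experiment with RK4 shows the fourth-order behavior. This is again a sketch on the illustrative problem $y' = y$, $y(0) = 1$. The error ratio from halving $h$ should approach $2^4 = 16$.

```python
import math

def rk4(f, t0, y0, T, n_steps):
    """Classical RK4 with four f-evaluations per step (scalar state)."""
    h = (T - t0) / n_steps
    y = y0
    for n in range(n_steps):
        t = t0 + n * h
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

def f(t, y):
    return y  # dy/dt = y, exact solution e^t

err_h = abs(rk4(f, 0.0, 1.0, 1.0, 10) - math.e)    # h = 0.1  (40 evaluations)
err_h2 = abs(rk4(f, 0.0, 1.0, 1.0, 20) - math.e)   # h = 0.05 (80 evaluations)
print(err_h, err_h2, err_h / err_h2)  # ratio a little above 15, approaching 16 as h shrinks
```

Even at $h = 0.1$, the RK4 error is around $10^{-6}$. Euler needs 200 steps on the same problem to reach an error near $10^{-2}$.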

Adaptive Solvers (Dormand-Prince / DOPRI5)

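Adaptive solvers estimate the local error by comparing two solutions of different orders (e.g., 4th and 5th order) computed from the same function evaluations. They then adjust the step size $h$ automatically to keep that estimate below a user-specified tolerance. This is how torchdiffeq’s dopri5 works, and how scipy.integrate.solve_ivp works with its default method='RK45', which is a Dormand-Prince 5(4) pair.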

Fixed-step:   [t0]----[t1]----[t2]----[t3]--> T
               h        h       h       h

Adaptive:     [t0]--[t1]--------[t2]-[t3]--> T
              small h  large h   small h
              (steep) (smooth)  (steep)

Step size chosen to keep local error < tolerance.
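Dormand-Prince has a long coefficient table, so the sketch below substitutes the much simpler Heun-Euler 2(1) embedded pair to show the same mechanism. Each step works as follows:

1. Compute a first-order and a second-order solution.
2. Use their difference as the local error estimate.
3. Accept or reject the step, then grow or shrink $h$.

The controller constants are common textbook choices, assumed here: a safety factor of 0.9 and a change factor clamped to $[0.2, 2]$. On $y' = -y$ the solution is steep early and flat later, so accepted steps should grow over time.

```python
import math

def heun_euler_adaptive(f, t0, y0, T, tol=1e-6, h=1e-2):
    """Adaptive integration with the embedded Heun-Euler 2(1) pair (scalar state).

    Returns the final value and the list of accepted step sizes.
    """
    t, y = t0, y0
    accepted = []
    while t < T:
        h = min(h, T - t)                 # never step past T
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_low = y + h * k1                # Euler, order 1
        y_high = y + h * (k1 + k2) / 2    # Heun, order 2
        err = abs(y_high - y_low)         # local error estimate
        if err <= tol:                    # accept, advancing with the higher-order value
            t = T if h == T - t else t + h
            y = y_high
            accepted.append(h)
        # The estimate scales like h^2, hence the square root in the controller.
        factor = 0.9 * math.sqrt(tol / max(err, 1e-16))
        h *= min(2.0, max(0.2, factor))
    return y, accepted

y_T, steps = heun_euler_adaptive(lambda t, y: -y, 0.0, 1.0, 5.0)
print(abs(y_T - math.exp(-5.0)), len(steps), steps[0], steps[-2])  # early steps small, late steps larger
```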

Connection to Normalizing Flows: FFJORD

A normalizing flow transforms a simple base density $p_0(\mathbf{z}_0)$ into a complex one $p(\mathbf{x})$ via an invertible map $\mathbf{x} = g(\mathbf{z}_0)$. The change of variables formula gives

$$\log p(\mathbf{x}) = \log p_0(\mathbf{z}_0) - \log\left|\det \frac{\partial g}{\partial \mathbf{z}_0}\right|.$$
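Here is a quick numerical check of the formula. It uses an affine map $x = g(z_0) = a z_0 + b$ with a standard normal base density, both illustrative choices. The pushforward is exactly $\mathcal{N}(b, a^2)$, so the two sides can be compared directly.

```python
import math

def log_normal_pdf(x, mean=0.0, std=1.0):
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)

a, b = 2.5, -1.0                 # invertible affine map x = a * z0 + b
x = 0.3
z0 = (x - b) / a                 # invert g
log_det = math.log(abs(a))       # dg/dz0 = a
lhs = log_normal_pdf(z0) - log_det           # change of variables formula
rhs = log_normal_pdf(x, mean=b, std=abs(a))  # closed form of the pushforward density
print(lhs, rhs)
```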

FFJORD (Grathwohl et al., 2018) defines $g$ as the flow of a Neural ODE. The log-density evolves continuously via the instantaneous change of variables formula:

$$\frac{\partial \log p(\mathbf{z}(t),\, t)}{\partial t} = -\operatorname{tr}\!\left(\frac{\partial f}{\partial \mathbf{z}}\right).$$
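The instantaneous formula can be checked against the finite change of variables formula above. This minimal numpy sketch assumes a hypothetical one-layer vector field $f(\mathbf{z}) = \tanh(W\mathbf{z} + \mathbf{b})$ with random weights. It integrates three quantities together:

- the state $\mathbf{z}(t)$;
- its Jacobian $J(t) = \partial \mathbf{z}(t)/\partial \mathbf{z}_0$, which obeys $\dot J = (\partial f/\partial \mathbf{z})\,J$;
- the accumulated log-density change $\Delta(t) = -\int_0^t \operatorname{tr}(\partial f/\partial \mathbf{z})\,ds$.

Since $J(T)$ is the Jacobian of the flow map, $\Delta(T)$ should equal $-\log|\det J(T)|$.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
W = rng.standard_normal((d, d))   # hypothetical weights
b = rng.standard_normal(d)        # hypothetical bias

def f_and_jac(z):
    """Vector field f(z) = tanh(W z + b) and its Jacobian diag(1 - tanh^2) W."""
    a = np.tanh(W @ z + b)
    return a, (1.0 - a ** 2)[:, None] * W

def augmented(state):
    z, J, _ = state
    fz, Jf = f_and_jac(z)
    # dz/dt = f(z), dJ/dt = (df/dz) J, d(delta log p)/dt = -tr(df/dz)
    return (fz, Jf @ J, -np.trace(Jf))

def rk4(state, dt, n_steps):
    """Fixed-step RK4 on the tuple (z, J, delta_logp)."""
    for _ in range(n_steps):
        k1 = augmented(state)
        k2 = augmented(tuple(s + dt / 2 * k for s, k in zip(state, k1)))
        k3 = augmented(tuple(s + dt / 2 * k for s, k in zip(state, k2)))
        k4 = augmented(tuple(s + dt * k for s, k in zip(state, k3)))
        state = tuple(s + dt / 6 * (p1 + 2 * p2 + 2 * p3 + p4)
                      for s, p1, p2, p3, p4 in zip(state, k1, k2, k3, k4))
    return state

z0 = rng.standard_normal(d)
zT, JT, delta_logp = rk4((z0, np.eye(d), 0.0), dt=0.01, n_steps=100)  # T = 1
print(delta_logp, -np.log(abs(np.linalg.det(JT))))  # the two numbers should agree closely
```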

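Computing $\operatorname{tr}(\partial f / \partial \mathbf{z})$ exactly costs $O(d^2)$, because each of the $d$ diagonal entries needs its own derivative pass. FFJORD instead uses the unbiased Hutchinson trace estimator, $\operatorname{tr}(A) = \mathbb{E}_\epsilon[\epsilon^T A\epsilon]$ for $\epsilon \sim \mathcal{N}(0, I)$. Each sample needs only one vector-Jacobian product $\epsilon^T (\partial f/\partial \mathbf{z})$, which reduces the cost to $O(d)$ per step.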
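Here is a minimal numpy sketch of the Hutchinson estimator. An explicit random matrix stands in for the Jacobian $\partial f/\partial \mathbf{z}$. The estimator only touches the matrix through products $A\boldsymbol\epsilon$, which is what makes the vector-Jacobian-product trick possible. The sample count and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
A = rng.standard_normal((d, d))   # stand-in for df/dz at one point

def hutchinson_trace(matvec, d, n_samples, rng):
    """Unbiased estimate of tr(A) using only products A @ eps, with eps ~ N(0, I)."""
    total = 0.0
    for _ in range(n_samples):
        eps = rng.standard_normal(d)
        total += eps @ matvec(eps)   # one sample of eps^T A eps
    return total / n_samples

estimate = hutchinson_trace(lambda v: A @ v, d, 20000, rng)
print(estimate, np.trace(A))
```

In FFJORD, a few samples (often one) are used per function evaluation. The estimate is noisy but unbiased, which suffices for stochastic training.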


Hamiltonian Neural Networks

A Hamiltonian system conserves a scalar $H(\mathbf{q}, \mathbf{p})$ (the Hamiltonian, typically energy) through Hamilton’s equations:

$$\dot{q}_i = \frac{\partial H}{\partial p_i}, \quad \dot{p}_i = -\frac{\partial H}{\partial q_i}.$$

Hamiltonian Neural Networks (Greydanus et al., 2019) learn $H_\theta$ from data and define the dynamics via Hamilton’s equations. This guarantees that the learned continuous-time dynamics conserve $H_\theta$ exactly - something a generic neural ODE cannot ensure. A numerical integrator can still introduce a small drift in $H_\theta$.

The corresponding phase space trajectories are confined to level sets of $H$:

               p
               |
          .----+----.
       .'      |      '.
      /     .--+--.     \
     |     /   |   \     |
-----+----+----+----+----+-----> q
     |     \   |   /     |
      \     '--+--'     /
       '.      |      .'
          '----+----'
               |
Level curves of H(q,p) = E are invariant (inner curve H = E1, outer curve H = E2 > E1): each trajectory stays on the curve it starts on.
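This minimal sketch shows how an HNN turns a scalar $H$ into a vector field, $(\dot q, \dot p) = (\partial H/\partial p, -\partial H/\partial q)$. A known pendulum Hamiltonian, $H(q, p) = p^2/2 + (1 - \cos q)$, with an analytic gradient stands in for a learned $H_\theta$. In an actual HNN the gradient would come from automatic differentiation. With small RK4 steps, the energy along the trajectory stays nearly constant. Any residual drift comes from the integrator, not from the vector field.

```python
import math

def H(q, p):
    return 0.5 * p ** 2 + (1.0 - math.cos(q))   # pendulum energy (stand-in for H_theta)

def hamiltonian_field(q, p):
    dH_dq = math.sin(q)
    dH_dp = p
    return dH_dp, -dH_dq   # (dq/dt, dp/dt) from Hamilton's equations

def rk4_step(q, p, h):
    k1 = hamiltonian_field(q, p)
    k2 = hamiltonian_field(q + h / 2 * k1[0], p + h / 2 * k1[1])
    k3 = hamiltonian_field(q + h / 2 * k2[0], p + h / 2 * k2[1])
    k4 = hamiltonian_field(q + h * k3[0], p + h * k3[1])
    q += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
    p += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
    return q, p

q, p = 1.0, 0.0
E0 = H(q, p)
for _ in range(10000):          # t from 0 to 100 with h = 0.01
    q, p = rk4_step(q, p, 0.01)
print(E0, H(q, p), abs(H(q, p) - E0))  # energy drift stays tiny
```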

Lagrangian Neural Networks

Alternatively, Lagrangian Neural Networks (Cranmer et al., 2020) learn a Lagrangian $\mathcal{L}(\mathbf{q}, \dot{\mathbf{q}})$ and derive dynamics from the Euler-Lagrange equations. In this section $\mathcal{L}$ denotes the Lagrangian, not the training loss:

$$\frac{d}{dt}\frac{\partial \mathcal{L}}{\partial \dot{q}_i} - \frac{\partial \mathcal{L}}{\partial q_i} = 0.$$
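To get dynamics, the Euler-Lagrange equation is solved for the acceleration. For a scalar coordinate with no explicit time dependence, expanding $\frac{d}{dt}\frac{\partial \mathcal{L}}{\partial \dot q}$ by the chain rule gives

$$\ddot q = \frac{\partial_q \mathcal{L} - \partial_q\partial_{\dot q}\mathcal{L}\,\dot q}{\partial_{\dot q}^2 \mathcal{L}}.$$

The sketch below evaluates this with central finite differences, which stand in for the automatic differentiation an LNN would use. It tests two hand-written Lagrangians:

- a pendulum, for which $\ddot q = -\sin q$;
- the same pendulum plus a $q\dot q$ term, which is a total time derivative and so should leave the dynamics unchanged.

```python
import math

def acceleration(lagrangian, q, v, eps=1e-4):
    """q_ddot from the Euler-Lagrange equation; derivatives by central differences."""
    L = lagrangian
    dL_dq = (L(q + eps, v) - L(q - eps, v)) / (2 * eps)
    d2L_dv2 = (L(q, v + eps) - 2 * L(q, v) + L(q, v - eps)) / eps ** 2
    d2L_dqdv = (L(q + eps, v + eps) - L(q + eps, v - eps)
                - L(q - eps, v + eps) + L(q - eps, v - eps)) / (4 * eps ** 2)
    return (dL_dq - d2L_dqdv * v) / d2L_dv2

def pendulum(q, v):
    return 0.5 * v ** 2 + math.cos(q)   # kinetic minus potential energy, up to a constant

def pendulum_plus_gauge(q, v):
    return pendulum(q, v) + q * v       # q * v = d/dt (q^2 / 2): no effect on dynamics

q, v = 0.7, 0.3
print(acceleration(pendulum, q, v), acceleration(pendulum_plus_gauge, q, v), -math.sin(q))
```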

Both frameworks embed known physical structure - energy conservation or the principle of least action - into the network architecture, improving generalization and long-run simulation accuracy.


Stability and Stiffness

A differential equation is stiff when the solution contains components that decay at vastly different rates. The classic example is

$$y' = -1000y + 3000 - 2000e^{-t}, \quad y(0) = 0.$$

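The exact solution has a slow component $\sim e^{-t}$ and a fast transient $\sim e^{-1000t}$. Explicit methods are limited by stability rather than accuracy, even though the interesting behavior is on the scale $t \sim 1$:

- forward Euler needs $h \leq 2/1000 = 0.002$;
- RK4 needs $h \lesssim 2.785/1000 \approx 0.0028$.

A-stable implicit methods, such as backward Euler or Radau IIA (an implicit Runge-Kutta family), stay stable for any $h > 0$ on this problem. The price is solving an equation for $y_{n+1}$ at every step. For example, backward Euler on $y' = \lambda y$ gives $y_{n+1} = y_n / (1 - h\lambda)$, and $|1/(1 - h\lambda)| < 1$ whenever $\operatorname{Re}(\lambda) < 0$.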

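The stiffness ratio quantifies the severity. It is $\max_i |\operatorname{Re}\lambda_i| / \min_i |\operatorname{Re}\lambda_i|$, taken over the eigenvalues $\lambda_i$ of the Jacobian $\partial f / \partial \mathbf{y}$. Neural ODEs whose learned dynamics become stiff are expensive to solve with explicit adaptive solvers, which are forced into tiny steps and therefore many function evaluations.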

Theorem (Stability of Euler). Consider the test equation $y' = \lambda y$ with $\lambda \in \mathbb{C}$, for which Euler’s method gives $y_{n+1} = (1 + h\lambda)y_n$. The iterates stay bounded for every initial value if and only if $|1 + h\lambda| \leq 1$, i.e., $h\lambda$ lies in the closed unit disk centered at $-1$. For real $\lambda < 0$ this reduces to $h \leq 2/|\lambda|$.
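The stiff example and the stability theorem can both be checked numerically. This minimal sketch uses $y' = -1000y + 3000 - 2000e^{-t}$, $y(0) = 0$. Solving this linear ODE by hand gives the exact solution $y(t) = 3 - \tfrac{2000}{999}e^{-t} - \tfrac{997}{999}e^{-1000t}$. The sketch compares three runs:

- forward Euler with $h = 0.0025$, where $|1 + h\lambda| = 1.5 > 1$, so it blows up;
- forward Euler with $h = 0.00125$, where $|1 + h\lambda| = 0.25$, so it behaves;
- backward Euler with $h = 0.1$, far beyond the explicit limit, which stays accurate.

```python
import math

LAM = -1000.0  # stiff eigenvalue of the test problem

def f(t, y):
    return LAM * y + 3000.0 - 2000.0 * math.exp(-t)

def exact(t):
    return 3.0 - 2000.0 / 999.0 * math.exp(-t) - 997.0 / 999.0 * math.exp(-1000.0 * t)

def forward_euler(h, T=1.0):
    y = 0.0
    for n in range(round(T / h)):
        y = y + h * f(n * h, y)          # amplification factor 1 + h*LAM
    return y

def backward_euler(h, T=1.0):
    y = 0.0
    for n in range(1, round(T / h) + 1):
        t = n * h
        # y_new = y + h f(t, y_new) is linear in y_new here, so solve it in closed form.
        y = (y + h * (3000.0 - 2000.0 * math.exp(-t))) / (1.0 - h * LAM)
    return y

print("forward Euler,  h=0.0025 :", forward_euler(0.0025))   # diverges
print("forward Euler,  h=0.00125:", forward_euler(0.00125), exact(1.0))
print("backward Euler, h=0.1    :", backward_euler(0.1), exact(1.0))
```

For a nonlinear $f$, the implicit update would need a root-finder such as Newton's method at each step. That cost is why implicit solvers pay off only when the problem is actually stiff.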


Examples

torchdiffeq (Python/PyTorch) provides odeint and odeint_adjoint, supporting Euler, RK4, DOPRI5, and implicit solvers. Usage pattern:

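import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint

class ODEFunc(nn.Module):     # odeint_adjoint needs an nn.Module to find the parameters
    def __init__(self, net):
        super().__init__()
        self.net = net        # any nn.Module mapping (batch, d) -> (batch, d)

    def forward(self, t, h):  # t: scalar tensor, h: (batch, d)
        return self.net(h)    # time-independent; concatenate t to h if f depends on t

# h0: (batch, d); t_span: 1-D tensor of times, e.g. torch.tensor([0.0, 1.0])
h_T = odeint(ODEFunc(net), h0, t_span, method='dopri5')[-1]  # odeint returns h at every time in t_span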

diffrax (Python/JAX) provides JIT-compilable, differentiable ODE solvers through diffrax.diffeqsolve, with JAX-native vectorization and XLA compilation.

Solver selection guide:

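  • Smooth, non-stiff: dopri5 (adaptive Dormand-Prince 5(4)), a good default.
  • Stiff systems: an implicit Runge-Kutta method, e.g. Kvaerno3 in diffrax, or Radau through torchdiffeq's scipy_solver wrapper.
  • Fixed-step for speed: euler or rk4, with a step size known to be stable and accurate for the problem.
  • Normalizing flows (FFJORD): dopri5 for both passes, i.e. odeint_adjoint with method='dopri5' and adjoint_method='dopri5'.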

The adjoint method cuts memory from $O(N_{\text{steps}} \cdot d)$ to $O(d)$, making Neural ODEs practical even when the solver takes thousands of steps through a complex vector field.

