Neural Networks as ODEs
ResNets Are Discretized ODEs
A standard residual network (ResNet) updates its hidden state as
$$\mathbf{x}_{t+1} = \mathbf{x}_t + f(\mathbf{x}_t, \theta_t)$$
where $\mathbf{x}_t \in \mathbb{R}^d$ is the hidden state at layer $t$ and $f(\cdot, \theta_t)$ is the residual block at layer $t$. Compare this to Euler’s method applied to the ODE
$$\dot{\mathbf{x}}(t) = f(\mathbf{x}(t), \theta(t))$$
with step size $h = 1$:
$$\mathbf{x}(t + h) \approx \mathbf{x}(t) + h\,f(\mathbf{x}(t), \theta(t)).$$
A ResNet with $L$ layers corresponds to $L$ Euler steps on a vector field $f$. This observation - that depth in ResNets mimics time in a dynamical system - motivated the Neural ODE framework.
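To make the correspondence concrete, here is a minimal PyTorch sketch (the width 64 and the shared block are illustrative; a real ResNet has separate parameters $\theta_t$ per layer) showing that stacking residual blocks is the same loop as fixed-step Euler integration:

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 64))

def resnet_forward(x, n_layers):
    # x_{t+1} = x_t + f(x_t): one residual block per "time step"
    for _ in range(n_layers):
        x = x + block(x)
    return x

def euler_forward(x, n_steps, h=1.0):
    # x(t + h) = x(t) + h * f(x(t)): Euler with h = 1 recovers the ResNet update
    for _ in range(n_steps):
        x = x + h * block(x)
    return x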
Neural ODEs
Chen et al. (2018) proposed replacing the discrete layer stack entirely with a continuous ODE:
$$\frac{d\mathbf{h}(t)}{dt} = f_\theta(\mathbf{h}(t), t), \quad \mathbf{h}(0) = \mathbf{x}$$
The output of the network is $\mathbf{h}(T)$ for some terminal time $T$, obtained by solving this IVP with any black-box ODE solver. The number of “effective layers” is not fixed in advance - the solver adaptively chooses how many function evaluations to use.
This yields continuous-depth networks: instead of $L$ discrete transformations, we have a flow in $\mathbb{R}^d$ parameterized by a neural network $f_\theta$.
The Adjoint Method for Backpropagation
Training requires computing $\frac{\partial \mathcal{L}}{\partial \theta}$ where $\mathcal{L}$ is a scalar loss on $\mathbf{h}(T)$. Naively differentiating through all solver steps requires storing intermediate states - $O(L)$ memory for an $L$-step solver. The adjoint method avoids this.
Define the adjoint $\lambda(t) = \frac{\partial \mathcal{L}}{\partial \mathbf{h}(t)}$. By differentiating the ODE constraint, $\lambda$ satisfies its own ODE run backwards in time:
$$\frac{d\lambda}{dt} = -\lambda(t)^T \frac{\partial f}{\partial \mathbf{h}}(\mathbf{h}(t), \theta, t).$$
This is the adjoint ODE. The boundary condition is $\lambda(T) = \frac{\partial \mathcal{L}}{\partial \mathbf{h}(T)}$.
The gradient with respect to parameters is accumulated along the backward pass:
$$\frac{\partial \mathcal{L}}{\partial \theta} = -\int_T^0 \lambda(t)^T \frac{\partial f}{\partial \theta}(\mathbf{h}(t), \theta, t)\,dt.$$
The key insight: we solve three ODEs in a single backward pass - one for $\mathbf{h}(t)$ (re-integrated backward from $\mathbf{h}(T)$), one for $\lambda(t)$, and one accumulating $\frac{\partial \mathcal{L}}{\partial \theta}$. Memory cost is $O(1)$ in the number of solver steps (we only need the current state, not the entire trajectory).
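As a concrete illustration, here is a minimal fixed-step sketch of that backward pass in plain PyTorch. It uses Euler steps for all three ODEs and assumes f takes a single flat parameter tensor theta with requires_grad=True; torchdiffeq's actual implementation (used in the Examples section below) integrates an augmented state with an adaptive solver instead.

import torch

def adjoint_backward(f, h_T, lam_T, theta, T, n_steps):
    # Backward sweep from t = T to 0: state h, adjoint lam, and dL/dtheta.
    h, lam = h_T.clone(), lam_T.clone()
    dL_dtheta = torch.zeros_like(theta)
    dt = T / n_steps
    for i in range(n_steps):
        t = T - i * dt
        h = h.detach().requires_grad_(True)
        with torch.enable_grad():
            dh = f(t, h, theta)
            # vector-Jacobian products: lam^T df/dh and lam^T df/dtheta in one backward pass
            vjp_h, vjp_theta = torch.autograd.grad(dh, (h, theta), grad_outputs=lam)
        h = h - dt * dh.detach()                 # re-integrate the state backwards
        lam = lam + dt * vjp_h                   # d(lam)/dt = -lam^T df/dh, run in reverse time
        dL_dtheta = dL_dtheta + dt * vjp_theta   # accumulates the parameter-gradient integral
    return dL_dtheta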
ODE Solvers as Layers
Euler’s Method
$$\mathbf{y}_{n+1} = \mathbf{y}_n + h\,f(t_n, \mathbf{y}_n)$$
Theorem (Euler Error). The global truncation error of Euler’s method is $O(h)$: if $f$ is Lipschitz and the exact solution has bounded second derivative, then $|\mathbf{y}_n - \mathbf{y}(t_n)| \leq C h$ for a constant $C$ depending on the problem but not on $h$.
One function evaluation per step. Simple but requires small $h$ for accuracy.
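A fixed-step Euler integrator is only a few lines; this NumPy sketch (function and argument names are illustrative) is the whole method:

import numpy as np

def euler_solve(f, y0, t0, T, h):
    # Explicit Euler: y_{n+1} = y_n + h * f(t_n, y_n); one f-evaluation per step.
    y = np.asarray(y0, dtype=float)
    n_steps = int(round((T - t0) / h))
    for n in range(n_steps):
        y = y + h * f(t0 + n * h, y)
    return y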
Runge-Kutta 4 (RK4)
RK4 takes four function evaluations per step:
$$k_1 = f(t_n, \mathbf{y}_n)$$
$$k_2 = f\left(t_n + \tfrac{h}{2}, \mathbf{y}_n + \tfrac{h}{2}k_1\right)$$
$$k_3 = f\left(t_n + \tfrac{h}{2}, \mathbf{y}_n + \tfrac{h}{2}k_2\right)$$
$$k_4 = f\left(t_n + h, \mathbf{y}_n + h\,k_3\right)$$
$$\mathbf{y}_{n+1} = \mathbf{y}_n + \frac{h}{6}(k_1 + 2k_2 + 2k_3 + k_4)$$
Theorem (RK4 Error). RK4 is a fourth-order method: the global truncation error satisfies $|\mathbf{y}_n - \mathbf{y}(t_n)| = O(h^4)$.
Compared to Euler ($O(h)$), halving $h$ reduces the RK4 error by a factor of 16 vs. 2 for Euler - a dramatic difference.
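A single RK4 step in the same sketch style as the Euler code above (names are illustrative):

def rk4_step(f, t, y, h):
    # One classical RK4 step: four evaluations of f, global error O(h^4).
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)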
Adaptive Solvers (Dormand-Prince / DOPRI5)
Adaptive solvers estimate the local error by comparing solutions of different orders (e.g., 4th and 5th order) and automatically adjust the step size $h$ to maintain a user-specified tolerance. This is how scipy.integrate.solve_ivp and torchdiffeq’s dopri5 work.
Fixed-step:  [t0]----[t1]----[t2]----[t3]--> T
                 h       h       h       h

Adaptive:    [t0]--[t1]--------[t2]-[t3]--> T
              small h    large h    small h
              (steep)    (smooth)   (steep)

Step size chosen to keep local error < tolerance.
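The control logic can be sketched with a much cheaper embedded pair than DOPRI5's. The sketch below pairs Euler (1st order) with Heun (2nd order), which is not what dopri5 uses internally (that is a 4th/5th-order pair), but it exercises the same accept/reject and step-size update:

import numpy as np

def adaptive_step(f, t, y, h, tol):
    # One step of an Euler/Heun embedded pair with simple step-size control.
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    y_low = y + h * k1                      # 1st-order (Euler) estimate
    y_high = y + h / 2 * (k1 + k2)          # 2nd-order (Heun) estimate
    err = np.linalg.norm(y_high - y_low)    # local error estimate
    accepted = err <= tol
    if accepted:
        t, y = t + h, y_high                # keep the higher-order solution
    # aim the next local error at the tolerance (0.9 is a safety factor)
    h = 0.9 * h * (tol / max(err, 1e-16)) ** 0.5
    return t, y, h, accepted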
Connection to Normalizing Flows: FFJORD
A normalizing flow transforms a simple base density $p_0(\mathbf{z}_0)$ into a complex one $p(\mathbf{x})$ via an invertible map $\mathbf{x} = g(\mathbf{z}_0)$. The change of variables formula gives
$$\log p(\mathbf{x}) = \log p_0(\mathbf{z}_0) - \log\left|\det \frac{\partial g}{\partial \mathbf{z}_0}\right|.$$
FFJORD (Grathwohl et al., 2018) defines $g$ as the flow of a Neural ODE. The log-density evolves continuously via the instantaneous change of variables formula:
$$\frac{\partial \log p(\mathbf{z}(t), t)}{\partial t} = -\operatorname{tr}\left(\frac{\partial f}{\partial \mathbf{z}}\right).$$
Computing $\operatorname{tr}(\partial f / \partial \mathbf{z})$ exactly costs $O(d^2)$. FFJORD uses the Hutchinson trace estimator: $\operatorname{tr}(A) = \mathbb{E}_\epsilon[\epsilon^T A\epsilon]$ for $\epsilon \sim \mathcal{N}(0, I)$, reducing cost to $O(d)$ per step.
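A PyTorch sketch of the estimator (the function name is illustrative; one vector-Jacobian product per draw of $\epsilon$):

import torch

def hutchinson_trace(f, t, z):
    # Unbiased estimate of tr(df/dz) at z using a single VJP instead of a full Jacobian.
    z = z.requires_grad_(True)                 # z assumed to be a leaf tensor here
    eps = torch.randn_like(z)                  # eps ~ N(0, I)
    fz = f(t, z)
    # grad_outputs=eps yields eps^T (df/dz) in one backward pass
    vjp, = torch.autograd.grad(fz, z, grad_outputs=eps, create_graph=True)
    return (vjp * eps).sum(dim=-1)             # eps^T (df/dz) eps, one value per sample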
Hamiltonian Neural Networks
A Hamiltonian system conserves a scalar $H(\mathbf{q}, \mathbf{p})$ (the Hamiltonian, typically energy) through Hamilton’s equations:
$$\dot{q}_i = \frac{\partial H}{\partial p_i}, \quad \dot{p}_i = -\frac{\partial H}{\partial q_i}.$$
Hamiltonian Neural Networks (Greydanus et al., 2019) learn $H_\theta$ from data and define the dynamics via Hamilton’s equations. This guarantees that the learned system conserves $H_\theta$ exactly - something a generic neural ODE cannot ensure.
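A minimal sketch of the idea in PyTorch (class and argument names are illustrative, not from the paper's code): a scalar network $H_\theta(\mathbf{q}, \mathbf{p})$ is differentiated with autograd, and the resulting symplectic gradient is the vector field handed to the ODE solver.

import torch
import torch.nn as nn

class HamiltonianDynamics(nn.Module):
    def __init__(self, h_net):
        super().__init__()
        self.h_net = h_net                          # scalar network: (batch, 2n) -> (batch, 1)

    def forward(self, t, x):                        # x = [q, p], shape (batch, 2n)
        x = x.requires_grad_(True)
        H = self.h_net(x).sum()
        dH, = torch.autograd.grad(H, x, create_graph=True)
        dH_dq, dH_dp = dH.chunk(2, dim=-1)
        return torch.cat([dH_dp, -dH_dq], dim=-1)   # dq/dt = dH/dp, dp/dt = -dH/dq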
The corresponding phase space trajectories are confined to level sets of $H$:
        p
        |
        |        _____
        |      /       \          H = E2
        |     |   ___   |
        |     |  / E1\  |
   -----+-----+--+---+--+---------> q
        |     |  \___/  |
        |      \_______/
        |

   Level curves of H(q,p) = E are invariant.
Lagrangian Neural Networks
Alternatively, Lagrangian Neural Networks (Cranmer et al., 2020) learn a Lagrangian $\mathcal{L}(\mathbf{q}, \dot{\mathbf{q}})$ and derive dynamics from the Euler-Lagrange equations:
$$\frac{d}{dt}\frac{\partial \mathcal{L}}{\partial \dot{q}_i} - \frac{\partial \mathcal{L}}{\partial q_i} = 0.$$
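Expanding the time derivative and solving for the accelerations gives the explicit second-order form that is actually integrated (the standard rearrangement used by Cranmer et al.; it assumes the velocity Hessian of $\mathcal{L}$ is invertible):
$$\ddot{\mathbf{q}} = \left(\nabla_{\dot{\mathbf{q}}}\nabla_{\dot{\mathbf{q}}}^{\top}\mathcal{L}\right)^{-1}\left[\nabla_{\mathbf{q}}\mathcal{L} - \left(\nabla_{\dot{\mathbf{q}}}\nabla_{\mathbf{q}}^{\top}\mathcal{L}\right)\dot{\mathbf{q}}\right]$$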
Both frameworks embed known physical structure - energy conservation or the principle of least action - into the network architecture, improving generalization and long-run simulation accuracy.
Stability and Stiffness
A differential equation is stiff when the solution contains components that decay at vastly different rates. The classic example is
$$y' = -1000y + 3000 - 2000e^{-t}, \quad y(0) = 0.$$
The exact solution has a slow component $\sim e^{-t}$ and a fast transient $\sim e^{-1000t}$. Explicit Euler requires $h < 2/1000 = 0.002$ for stability, and RK4's limit is only slightly larger (even though the interesting behavior unfolds on the scale $t \sim 1$). Implicit methods (e.g., backward Euler, implicit Runge-Kutta) remain stable for any $h$ because they solve for $\mathbf{y}_{n+1}$ implicitly.
Stiffness ratio $= |\lambda_{\max}| / |\lambda_{\min}|$ of the Jacobian $\partial f / \partial \mathbf{y}$ quantifies severity. Neural ODE models with stiff dynamics incur huge computational cost from adaptive solvers taking tiny steps.
Formal Theorem (Stability of Euler). For the test equation $y' = \lambda y$ with $\lambda \in \mathbb{C}$, Euler’s method $y_{n+1} = (1 + h\lambda)y_n$ is stable if and only if $|1 + h\lambda| \leq 1$, i.e., $h\lambda$ lies in the unit disk centered at $-1$.
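The stability boundary is easy to verify numerically; a short Python sketch (names are illustrative) on the test equation with $\lambda = -1000$:

def euler_test_equation(lam, h, T=1.0, y0=1.0):
    # Explicit Euler on y' = lam * y: each step multiplies y by (1 + h*lam).
    y = y0
    for _ in range(int(T / h)):
        y = (1 + h * lam) * y
    return y

print(euler_test_equation(-1000.0, 0.0019))  # |1 + h*lam| = 0.9 -> decays, like the true solution
print(euler_test_equation(-1000.0, 0.0021))  # |1 + h*lam| = 1.1 -> blows up despite the tiny step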
Examples
torchdiffeq (Python/PyTorch) provides odeint and odeint_adjoint, supporting Euler, RK4, DOPRI5, and implicit solvers. Usage pattern:
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn

class ODEFunc(nn.Module):                   # odeint_adjoint expects an nn.Module so it can track the parameters
    def __init__(self, net):
        super().__init__(); self.net = net  # net: any nn.Module mapping (h, t) -> dh/dt
    def forward(self, t, h):                # h: (batch, d), t: scalar
        return self.net(h, t)

h_T = odeint(ODEFunc(net), h0, t_span, method='dopri5')
diffrax (Python/JAX) provides JIT-compiled, differentiable ODE solvers with similar interface but JAX-native vectorization and XLA compilation.
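A minimal diffrax sketch (the vector field and tolerances are placeholders; the argument names follow the diffrax documentation, so check them against your installed version):

import jax.numpy as jnp
import diffrax

def f(t, y, args):
    return -y                                   # toy vector field dy/dt = -y

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(f), diffrax.Dopri5(),
    t0=0.0, t1=1.0, dt0=0.1, y0=jnp.ones(3),
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-6),
)
print(sol.ys)                                   # solution at t1 (the default SaveAt)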
Solver selection guide:
- Smooth, non-stiff: dopri5 (adaptive RK4/5), a good default.
- Stiff systems: kvaerno3 or radau (implicit Runge-Kutta).
- Fixed-step for speed: euler or rk4 with a known good step size.
- Normalizing flows (FFJORD): dopri5 with adjoint_method='dopri5'.
The adjoint method cuts memory from $O(N_{\text{steps}} \cdot d)$ to $O(d)$, making Neural ODEs practical even when the solver takes thousands of steps through a complex vector field.