Inside JAX & XLA - Composable Transformations That Compile to Accelerators
NumPy is fast. But what if you could write NumPy code and have it compile automatically to optimized CUDA kernels or TPU programs - with automatic differentiation, batch vectorization, and multi-device parallelism, all composable with each other and expressible in pure Python?
That is what JAX offers. And it is not magic. It is a specific set of design choices, each with real consequences you need to understand before JAX stops surprising you and starts feeling inevitable.
Why JAX Exists
The background starts at Google Brain around 2017 - 2018. TensorFlow 1.x required building a static computation graph before running anything. That was powerful for compilation but deeply painful to debug. PyTorch, released in early 2017, had popularized dynamic (eager) execution, which made debugging natural but left performance on the table. Every operation launched a separate GPU kernel, Python overhead sat between every pair of operations, and no compiler ever saw the full computation.
JAX, open-sourced in 2018, chose a third path. Write code in a functional, NumPy-compatible style. Trace it to a computation graph automatically. Compile that graph with XLA (Accelerated Linear Algebra), the domain-specific compiler Google built for TensorFlow, which targets TPUs, GPUs, and CPUs. The result: the ease of eager execution during development, the performance of compiled execution in production.
The key that made this possible was functional purity: no in-place mutation, no global state, no side effects. This is not a philosophical preference - it is what makes the compiler sound. If a function has no side effects, the compiler can safely reorder, fuse, or replicate its operations. If arrays are immutable, there are no aliasing hazards. The constraints are load-bearing.
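Immutability shows up immediately in everyday code: indexed assignment is replaced by the `.at[...]` update syntax, which returns a new array. The minimal sketch below shows this. Under jit, XLA can often perform such updates in place when the old buffer is no longer needed, so the functional style need not cost a copy.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place assignment is rejected: JAX arrays are immutable
mutation_error = None
try:
    x[0] = 1.0
except TypeError as err:
    mutation_error = err

# .at[...] builds a new array and leaves x untouched
y = x.at[0].set(1.0)
z = y.at[1:3].add(2.0)
```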
The Tracing Model: How jit Actually Works
When you call jax.jit(f)(x), something unusual happens. JAX doesn't run f on your data directly. It traces f: it runs the Python function with abstract values. These are tracer objects that record shape, dtype, and the operations performed on them, but carry no concrete numbers.
The trace produces a jaxpr, JAX's own intermediate representation: a graph of matrix multiplies, elementwise operations, reductions, and control flow. JAX lowers the jaxpr to StableHLO, a portable form of XLA's HLO (High Level Operations) IR. XLA then compiles it to machine code for your hardware target and caches the executable. Subsequent calls with inputs of the same shape and dtype skip tracing and compilation entirely and dispatch straight to the cached executable.
```python
import jax
import jax.numpy as jnp

def softmax(x):
    x_max = jnp.max(x)
    e_x = jnp.exp(x - x_max)
    return e_x / jnp.sum(e_x)

softmax_jit = jax.jit(softmax)
x = jnp.array([1.0, 2.0, 3.0])
# First call: traces + compiles (~100ms warmup)
out = softmax_jit(x)
# Subsequent calls: compiled binary, microseconds
out = softmax_jit(x)
```
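To see the tracing machinery directly, here is a minimal sketch that reuses the softmax above and assumes a recent JAX release. It does four things:

- The `trace_log` list is an illustrative device, not a JAX API. It only grows when JAX traces.
- `jax.make_jaxpr` prints the jaxpr.
- `.lower(...).as_text()` shows the module handed to XLA.
- `block_until_ready` is needed for honest timings, because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

trace_log = []  # Python side effects run only while JAX is tracing

def softmax(x):
    trace_log.append(x.shape)  # records each trace, not each call
    x_max = jnp.max(x)
    e_x = jnp.exp(x - x_max)
    return e_x / jnp.sum(e_x)

softmax_jit = jax.jit(softmax)
softmax_jit(jnp.ones(3))  # traces and compiles for shape (3,)
softmax_jit(jnp.ones(3))  # cache hit: no trace
softmax_jit(jnp.ones(4))  # new shape: traces and compiles again
traces_from_calls = list(trace_log)  # snapshot before the introspection below

# The jaxpr: JAX's own IR, built from the operations recorded on tracers
jaxpr = jax.make_jaxpr(softmax)(jnp.ones(3))
print(jaxpr)

# The lowered module handed to XLA (StableHLO text in recent JAX versions)
lowered_text = softmax_jit.lower(jnp.ones(3)).as_text()

def timed(fn, arg):
    start = time.perf_counter()
    # Dispatch is asynchronous; block_until_ready waits for the actual result
    result = fn(arg).block_until_ready()
    return result, time.perf_counter() - start

big = jnp.linspace(0.0, 1.0, 1_000_000)
out_first, t_first = timed(softmax_jit, big)    # includes trace + compile
out_cached, t_cached = timed(softmax_jit, big)  # cached executable only
print(f"first: {t_first * 1e3:.1f} ms, cached: {t_cached * 1e3:.3f} ms")
```

Expect the first timing to include compilation. The exact numbers depend on your hardware.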
This is why JIT-compiled JAX is fast: the Python interpreter disappears, and XLA can see the entire computation and optimize it globally - fusing operations, choosing layouts, eliminating temporaries.
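The Tracing Gotcha: Python Control Flow Runs at Trace Time
Here is the discomfort you need to sit with.
When JAX traces your function, it runs your Python control flow with abstract values. An if statement whose condition depends on a traced array cannot be decided, because the tracer has no concrete value. Instead of picking a branch, JAX raises an error: a TracerBoolConversionError, which is a kind of ConcretizationTypeError. Control flow that depends only on static information, such as shapes, dtypes, or plain Python values, is evaluated once at trace time. The branch taken then is baked into the compiled executable.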
```python
import jax

# FAILS under jit: x is a tracer with no concrete value, so `x > 0` cannot
# be turned into a Python bool (JAX raises TracerBoolConversionError)
@jax.jit
def f(x):
    if x > 0:
        return x * 2
    return x * -1

# CORRECT: jax.lax.cond lowers to a conditional evaluated at runtime
@jax.jit
def f(x):
    return jax.lax.cond(x > 0, lambda x: x * 2, lambda x: x * -1, x)
```
jax.lax.cond, jax.lax.switch, jax.lax.fori_loop, jax.lax.while_loop, and jax.lax.scan are the JAX equivalents of Python control flow. They compile to HLO control flow operations that are evaluated at runtime, not trace time. Python for loops in traced functions are unrolled: they emit a copy of the loop body for every iteration. That is fine for small loops, but for large ones trace time, compile time, and program size all grow with the trip count.
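Here is a short sketch of the structured loop primitives. The function names (`sum_of_squares`, `halvings_until_below_one`, `python_loop`, `lax_loop`) are made up for this example. The last lines count equations in each jaxpr: a Python loop unrolls into one operation per iteration, while `fori_loop` stays a single loop operation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_of_squares(n):
    # fori_loop(lower, upper, body_fun, init_val); body_fun(i, carry) -> carry
    # n is traced here, so the loop bound is only known at runtime
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))

@jax.jit
def halvings_until_below_one(x):
    # while_loop(cond_fun, body_fun, init_val); the carry is (value, count)
    def cond_fun(carry):
        value, _ = carry
        return value >= 1.0

    def body_fun(carry):
        value, count = carry
        return value / 2.0, count + 1

    _, count = jax.lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return count

def python_loop(x):
    for _ in range(100):  # unrolled: one multiply per iteration in the trace
        x = x * 1.01
    return x

def lax_loop(x):
    return jax.lax.fori_loop(0, 100, lambda i, x: x * 1.01, x)

unrolled_eqns = len(jax.make_jaxpr(python_loop)(jnp.ones(3)).jaxpr.eqns)
looped_eqns = len(jax.make_jaxpr(lax_loop)(jnp.ones(3)).jaxpr.eqns)
```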
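Values the function reads from Python globals or closures behave differently: they are captured as constants at trace time. If such a value changes later, JAX won't notice, because it reuses the executable compiled for the old value. A Python int or bool passed as an ordinary argument, by contrast, becomes a traced value, so branching on it with if fails as described above. Use static_argnums (or static_argnames) to mark arguments that should be treated as compile-time constants. JAX compiles a separate version for each distinct value, so static arguments must be hashable: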
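```python
import functools
import jax

@functools.partial(jax.jit, static_argnums=(1,))
def f(x, training: bool):
    if training:  # safe: training is a static Python bool; JAX recompiles per value
        return dropout(x)  # dropout: stand-in for your own dropout function
    return x
```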
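The sketch below shows both halves of the rule, using hypothetical names (`SCALE`, `scale`, `raise_to`). A global read inside a jitted function is frozen at trace time. A static argument triggers a fresh compilation for each new value.

```python
import functools
import jax
import jax.numpy as jnp

SCALE = 2.0  # module-level Python value

@jax.jit
def scale(x):
    return x * SCALE  # SCALE is read at trace time and embedded as a constant

@functools.partial(jax.jit, static_argnames=("power",))
def raise_to(x, power: int):
    # power is static, so it is a plain Python int here and the loop unrolls
    result = x
    for _ in range(power - 1):
        result = result * x
    return result

x = jnp.array(3.0)
first = scale(x)
SCALE = 10.0
second = scale(x)             # same shape/dtype: cached executable, still uses 2.0
cube = raise_to(x, power=3)   # compiles a version specialized to power=3
square = raise_to(x, power=2) # new static value: compiles again
```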
Recompilation happens whenever input shapes or dtypes change, or when a static argument takes a new value. Avoid dynamic shapes in hot loops. This is a real operational concern for serving systems that process variable-length inputs.
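One common mitigation is to pad variable-length inputs up to a small set of bucket sizes and carry a mask. The number of compilations is then bounded by the number of buckets rather than the number of distinct lengths. This is a sketch under assumptions: the bucket sizes and the `pad_to_bucket` and `masked_mean` helpers are illustrative, not a JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np

BUCKETS = (16, 32, 64, 128)  # illustrative sizes; pick them from your traffic
trace_shapes = []  # appended only when jit traces, i.e., compiles a new shape

def pad_to_bucket(tokens):
    """Pad a 1-D sequence to the smallest bucket that fits; return (padded, mask).

    Raises StopIteration if the sequence is longer than the largest bucket.
    """
    n = len(tokens)
    size = next(b for b in BUCKETS if b >= n)
    padded = np.zeros(size, dtype=np.int32)
    padded[:n] = tokens
    mask = np.arange(size) < n
    return jnp.asarray(padded), jnp.asarray(mask)

@jax.jit
def masked_mean(values, mask):
    trace_shapes.append(values.shape)
    total = jnp.sum(jnp.where(mask, values, 0))  # padding contributes zero
    return total / jnp.sum(mask)

# Lengths 5 and 9 share the 16 bucket; length 20 needs the 32 bucket
results = [masked_mean(*pad_to_bucket(list(range(1, n + 1)))) for n in (5, 9, 20)]
```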
XLA: The Compiler That Makes It Real
XLA is not just a wrapper around cuBLAS. It is a full optimizing compiler for linear algebra, with its own IR, its own optimization passes, and its own code generation backends for NVIDIA GPUs (CUDA), AMD GPUs (ROCm), CPUs, and TPUs.
XLA’s most important optimization for ML workloads is operator fusion: merging multiple elementwise operations into a single GPU kernel. Without fusion, each elementwise operation (relu, add, multiply) launches a separate kernel. Each kernel reads its inputs from HBM (High Bandwidth Memory), does minimal work, and writes outputs back to HBM. The compute-to-memory-bandwidth ratio is terrible.
XLA fuses a sequence of elementwise operations into one kernel that reads inputs once, performs all operations in registers, and writes the final result once. For a chain of 10 elementwise ops, this reduces HBM traffic by roughly 10x. This is a significant fraction of why JAX often outperforms PyTorch eager mode on element-wise-heavy computations.
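You can inspect what XLA did by lowering and compiling ahead of time. This sketch assumes a recent JAX in which `lower(...).compile().as_text()` returns the optimized HLO. The exact fusion instruction names differ between backends and XLA versions, so treat the filter as a rough probe rather than a guarantee.

```python
import jax
import jax.numpy as jnp

def elementwise_chain(x):
    # several elementwise ops that XLA can fuse into a single loop/kernel
    y = x * 2.0
    y = jnp.tanh(y)
    y = y + 1.0
    return jnp.maximum(y, 0.5)

x = jnp.linspace(-1.0, 1.0, 8)
compiled = jax.jit(elementwise_chain).lower(x).compile()
optimized_hlo = compiled.as_text()  # HLO after XLA's optimization passes

# Show a few lines mentioning fusion, if this backend reports any
print([line for line in optimized_hlo.splitlines() if "fusion" in line][:3])
result = compiled(x)
```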
Beyond fusion, XLA performs:
- Layout optimization: choosing row-major vs column-major for each array based on how it is used downstream, to minimize transpose cost.
- Constant folding: evaluating expressions whose operands are all constants at compile time.
- Common subexpression elimination: computing a shared sub-expression once rather than multiple times.
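- Rematerialization: recomputing values from cheaper inputs instead of keeping them in HBM when memory is tight. At the JAX level, gradient checkpointing is requested explicitly with jax.checkpoint, which trades extra compute in the backward pass for lower activation memory.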
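Here is gradient checkpointing in practice. This minimal sketch wraps a toy `block` function of our own in `jax.checkpoint` and checks that the gradients match the unwrapped version. The memory saving is not visible at this size. The point is that checkpointing changes what is stored, not what is computed.

```python
import jax
import jax.numpy as jnp

def block(x):
    # cheap elementwise work whose intermediates are easy to recompute
    return jnp.sin(jnp.cos(x) * 2.0)

# jax.checkpoint (also available as jax.remat) stops autodiff from saving the
# block's intermediates; they are recomputed during the backward pass instead
remat_block = jax.checkpoint(block)

def loss_remat(x):
    return jnp.sum(remat_block(remat_block(x)))

def loss_plain(x):
    return jnp.sum(block(block(x)))

x = jnp.linspace(0.0, 1.0, 5)
g_remat = jax.jit(jax.grad(loss_remat))(x)
g_plain = jax.jit(jax.grad(loss_plain))(x)
```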
The trade-off: XLA compilation is slow. A model with many distinct input shapes can trigger minutes of compilation on first run. Production systems warm up JAX models extensively before serving traffic.
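One way to do that warm-up is JAX's ahead-of-time API: lower and compile for each expected shape before traffic arrives. Using `jax.ShapeDtypeStruct` means no real data is needed. In this sketch, `score` and the list of shapes are placeholders for your model and your serving buckets.

```python
import jax
import jax.numpy as jnp

def score(x):
    return jnp.tanh(x) @ jnp.ones(x.shape[-1])

score_jit = jax.jit(score)

# Compile ahead of time for every input shape you expect to serve
BATCH_SHAPES = [(1, 128), (8, 128), (32, 128)]  # illustrative
compiled = {
    shape: score_jit.lower(jax.ShapeDtypeStruct(shape, jnp.float32)).compile()
    for shape in BATCH_SHAPES
}

x = jnp.ones((8, 128))
out = compiled[x.shape](x)  # no tracing or compilation on the request path
```

A compiled executable only accepts the exact shapes and dtypes it was built for, which is what makes the lookup by shape safe. JAX can also persist compiled executables across processes through its on-disk compilation cache (the `jax_compilation_cache_dir` config option).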
jax.grad: Automatic Differentiation Through Function Composition
jax.grad implements reverse-mode automatic differentiation, but not by symbolic manipulation of formulas and not by finite differences. It traces the function, builds its linearization, and transposes that into a backward pass. The result is exact gradients, up to floating-point rounding.
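```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    predictions = params["w"] @ x + params["b"]
    return jnp.mean((predictions - y) ** 2)

# Example inputs (illustrative shapes): w is (3, 4), x is (4,), y is (3,)
params = {"w": jnp.ones((3, 4)), "b": jnp.zeros(3)}
x = jnp.ones(4)
y = jnp.zeros(3)

# Differentiates w.r.t. the first argument (params) by default
grad_fn = jax.grad(mse_loss)
grads = grad_fn(params, x, y)
# grads has the same structure as params: {"w": ..., "b": ...}

# More efficient: compute loss and gradients in one pass
loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
```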
jax.grad composes with jax.jit: you can JIT-compile the gradient function, and XLA will optimize the backward pass just as aggressively as the forward pass. For a scalar function f of a scalar input, you can also take higher-order derivatives by nesting grad:
```python
d2f = jax.grad(jax.grad(f))  # second derivative
```
This works because jax.grad(f) is itself a function that can be traced. The composition is mathematically sound because of functional purity - no state, no side effects.
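Here is a small check of higher-order differentiation, using made-up test functions `f` and `g`. `jax.hessian` and `jax.jacrev` are JAX's built-in Hessian and Jacobian helpers. `jax.hessian` composes forward mode over reverse mode.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

df = jax.grad(f)             # 3x^2
d2f = jax.grad(jax.grad(f))  # 6x

def g(v):
    # mixed term couples v[0] and v[1], giving off-diagonal Hessian entries
    return jnp.sum(v ** 2) + v[0] * v[1]

hess = jax.hessian(g)(jnp.array([1.0, 2.0]))
jac = jax.jacrev(lambda v: v ** 2)(jnp.array([1.0, 2.0, 3.0]))
```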
jax.vmap: Vectorization as a Transformation
vmap transforms a function written for a single example into a batched function. Crucially, this is not a Python loop - it produces a single vectorized computation that runs in one pass.
```python
def single_loss(w, x, y):
    return jnp.dot(w, x) - y

# Batch over axis 0 of x and y; don't batch w (it's shared)
batched_loss = jax.vmap(single_loss, in_axes=(None, 0, 0))
w = jnp.ones(10)
X = jnp.ones((32, 10))  # 32 examples
Y = jnp.ones(32)
losses = batched_loss(w, X, Y)  # shape (32,)
```
The canonical use case in research is per-example gradients: the gradient of the loss for each individual training example, needed for privacy-preserving methods like differentially private SGD. PyTorch historically required a backward pass per example, or libraries such as Opacus. PyTorch 2.x now ships torch.func.grad and torch.func.vmap, modeled on this JAX design. In JAX:
```python
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))
```
This composes two transformations - vmap and grad - and JAX handles the algebra. The result is a function that computes all per-example gradients in a single vectorized pass with no Python loops.
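Here is a self-contained version of the per-example pattern, assuming a toy linear model (`loss_fn` here is illustrative). It checks the vmap result against an explicit Python loop. It then applies a per-example norm clip of the kind DP-SGD uses. The noise-addition step of DP-SGD is omitted, and `CLIP_NORM` is an arbitrary placeholder.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # squared error for ONE example: x has shape (d,), y is a scalar
    return (jnp.dot(w, x) - y) ** 2

# in_axes: share w, map over axis 0 of x and y
per_example_grads = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))

kx, ky = jax.random.split(jax.random.PRNGKey(0))
w = jnp.ones(4)
X = jax.random.normal(kx, (8, 4))
Y = jax.random.normal(ky, (8,))

grads = per_example_grads(w, X, Y)  # shape (8, 4): one gradient per example
looped = jnp.stack([jax.grad(loss_fn)(w, X[i], Y[i]) for i in range(8)])

# DP-SGD-style clipping: rescale any gradient whose L2 norm exceeds CLIP_NORM
CLIP_NORM = 1.0
norms = jnp.linalg.norm(grads, axis=1)
clipped = grads / jnp.maximum(1.0, norms / CLIP_NORM)[:, None]
```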
jax.pmap: SPMD Across Devices
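pmap (parallel map) replicates a function across multiple devices (GPUs or TPUs). Each device receives a different shard of the input along a leading axis. It implements Single Program Multiple Data (SPMD) parallelism. Newer JAX code increasingly expresses the same patterns with jax.jit plus shardings, or with shard_map.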
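```python
import functools
import jax
from jax import pmap
from jax.lax import pmean

# axis_name binds the name that collectives such as pmean refer to
@functools.partial(pmap, axis_name="devices")
def train_step(params, batch):
    # loss_fn and update are your own functions (not shown)
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = pmean(grads, axis_name="devices")  # AllReduce (mean) across devices
    params = update(params, grads)
    return params, loss
```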
pmean compiles to a collective AllReduce, which runs over NCCL on NVIDIA GPUs and over the inter-chip interconnect (ICI) on TPUs. After the AllReduce, every device holds the same averaged gradient and applies the same update, maintaining exact synchrony. This is synchronous data parallelism. With equal per-device batch sizes and a mean-reduced loss, it is equivalent to training on one large batch formed by concatenating all device batches.
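Here is a runnable version of the data-parallel step, under three assumptions: a toy linear-regression loss, a hand-written SGD update, and parameters replicated by stacking one copy per device along the leading axis that pmap maps over. It runs on however many local devices exist. On a CPU-only machine that is usually one; setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` before importing JAX simulates eight.

```python
import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, batch):
    loss, grads = jax.value_and_grad(loss_fn)(w, batch)
    grads = jax.lax.pmean(grads, axis_name="devices")  # AllReduce (mean)
    loss = jax.lax.pmean(loss, axis_name="devices")
    return w - 0.1 * grads, loss  # plain SGD, learning rate 0.1

w = jnp.zeros(3)
replicated_w = jnp.stack([w] * n_dev)  # shape (n_dev, 3): one copy per device
x = jnp.ones((n_dev, 4, 3))            # 4 examples per device
y = jnp.ones((n_dev, 4))
new_w, loss = train_step(replicated_w, (x, y))
```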
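For model parallelism (sharding model parameters across devices rather than data), JAX provides jax.sharding (Mesh, NamedSharding, PartitionSpec) together with jax.lax.with_sharding_constraint. The older experimental xmap API in jax.experimental.maps has since been removed. These tools let you annotate which axes of an array are sharded across which devices, and XLA's SPMD partitioner inserts the necessary communication automatically. When you want to write the per-device program and its collectives by hand, shard_map fills the role pmap used to.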
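Here is a minimal sharding sketch. It assumes a one-dimensional mesh over all visible devices, with a single axis named "data" (the name is arbitrary). Rows of `x` are split across devices and `w` is replicated. The output is constrained to keep `x`'s row sharding. With a single device everything collapses to one shard, but the code is unchanged.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the rows of x across the "data" mesh axis; replicate w everywhere
x = jax.device_put(jnp.ones((8 * n_dev, 4)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((4, 2)), NamedSharding(mesh, P()))

@jax.jit
def forward(x, w):
    y = x @ w
    # Ask XLA to keep the output sharded by rows, like x
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", None)))

y = forward(x, w)
print(y.sharding)
```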
This is the stack behind Gemini: Google has said Gemini was trained with JAX and its Pathways runtime on large fleets of TPUs. The functional purity that constrains individual programs becomes a massive advantage at scale. There is no shared mutable state between devices, which makes the communication pattern explicit and analyzable.
Pytrees: Operating on Nested Structures
JAX transformations operate on pytrees: arbitrarily nested Python containers (lists, tuples, dicts, None, and any custom classes you register) whose leaves are typically JAX arrays. This is what makes it natural to represent model parameters as nested dicts.
```python
params = {
    "layer1": {"w": jnp.ones((64, 32)), "b": jnp.zeros(64)},
    "layer2": {"w": jnp.ones((10, 64)), "b": jnp.zeros(10)},
}
# grad returns a pytree with the same structure as params
# (loss_fn, x, y are your own model loss and data, not shown)
grads = jax.grad(loss_fn)(params, x, y)
# grads: {"layer1": {"w": ..., "b": ...}, "layer2": {...}}

# jit, vmap, pmap all understand pytrees natively
scaled_grads = jax.tree_util.tree_map(lambda g: g * 0.01, grads)  # scale every leaf
```
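jax.tree_util.tree_map applies a function to every leaf of a pytree and returns a new pytree with the same structure. Given several pytrees of matching structure, it maps over their corresponding leaves together. Combined with jax.grad, this is essentially the entire parameter update loop of neural network training, expressed in a dozen lines.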
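Here is that update loop for a toy linear model, as a minimal sketch. `loss_fn`, `sgd_step`, and `LEARNING_RATE` are illustrative, and plain SGD stands in for whatever optimizer you would actually use. The key line is the two-argument `tree_map`, which pairs each parameter leaf with its gradient leaf.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # tree_map over two pytrees with identical structure: p - lr * g per leaf
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
x = jnp.ones((5, 3))
y = jnp.full(5, 2.0)
new_params, loss_before = sgd_step(params, x, y)
loss_after = loss_fn(new_params, x, y)
```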
Explicit PRNG: Reproducibility by Design
NumPy and PyTorch use global random state: call np.random.randn() twice and you get different values, with the state mutated globally. This is incompatible with JAX's functional model and makes reproducibility across parallel programs hard to guarantee.
JAX requires explicit PRNG keys, split deterministically:
```python
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(100,))
```
Splitting a key produces new keys that are statistically independent and fully determined by the parent. JAX code is exactly reproducible: run the same key on the same hardware and you get the same random numbers, regardless of parallelism. This is operationally important for debugging distributed training.
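Here is a short sketch of the key discipline, using `jax.random.split`, `jax.random.fold_in`, and `vmap` over a batch of keys. It uses the classic `PRNGKey` constructor to match the text. Newer JAX releases also offer `jax.random.key`, which produces typed key arrays.

```python
import jax

key = jax.random.PRNGKey(42)

# Same key, same numbers: reusing a key is deterministic, so it is usually a bug
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split to get fresh, independent keys
key, k1, k2 = jax.random.split(key, 3)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))

# fold_in derives a key from an integer, e.g., a step number or device index
step_key = jax.random.fold_in(key, 7)

# A batch of keys works with vmap: one independent sample stream per key
keys = jax.random.split(k1, 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
```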
JAX vs PyTorch 2.0: Tracing vs Graph Capture
PyTorch 2.0 introduced torch.compile, which takes a different route to similar ends. Rather than tracing with abstract values (JAX's method), its TorchDynamo front end hooks into CPython's frame evaluation, analyzes the Python bytecode as it runs, and extracts FX graphs of the PyTorch operations. The TorchInductor backend then compiles those graphs, generating Triton kernels for GPUs and C++ code for CPUs.
The difference matters in practice:
- JAX JIT is strict: you must write JAX-compatible functional code. Python control flow that depends on array values must use jax.lax primitives. The constraint is tight.
- torch.compile is lenient: it tries to handle arbitrary PyTorch code. When it can't compile something, it inserts a graph break and runs the unsupported part eagerly. The constraint is loose, which makes adoption easier, but graph breaks split the program into smaller compiled regions and can limit optimization.
For research code that changes daily, torch.compile is lower friction. For production workloads with stable computation graphs, JAX’s aggressive compilation (especially on TPU) tends to win. Google uses JAX for Gemini training. OpenAI uses PyTorch. Both are correct choices for their respective contexts.
The Flax/Optax Ecosystem
JAX itself has no built-in neural network library. The ecosystem provides:
- Flax: neural network modules whose parameters live in pytrees. In its Linen API, you call init (returns the initial parameters) and apply (runs the forward pass given those parameters, statelessly). There is no model.train()/model.eval(): state is explicit. Flax's newer NNX API moves toward regular Python objects.
- Optax: a library of composable gradient transformations. An optimizer is a pair of pure functions, init(params) → state and update(grads, state, params) → (updates, new_state). It chains with gradient clipping, weight decay, and learning rate schedules.
- Equinox: an alternative to Flax that defines modules as Python dataclasses that are also pytrees, feeling more like PyTorch.
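To make the Optax shape concrete without depending on Optax itself, here is a hypothetical momentum optimizer written in the same style. `momentum_sgd`, `MomentumState`, and `apply_updates` are our own names, with `apply_updates` mirroring `optax.apply_updates`. This sketch also drops Optax's optional `params` argument to `update`. It shows why the design composes: the optimizer state is just another pytree threaded through a pure, jittable step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class MomentumState(NamedTuple):
    velocity: dict  # pytree with the same structure as params

def momentum_sgd(learning_rate: float, beta: float = 0.9):
    """Hypothetical Optax-style transformation: returns pure (init, update)."""
    def init(params):
        return MomentumState(velocity=jax.tree_util.tree_map(jnp.zeros_like, params))

    def update(grads, state):
        velocity = jax.tree_util.tree_map(
            lambda v, g: beta * v + g, state.velocity, grads
        )
        updates = jax.tree_util.tree_map(lambda v: -learning_rate * v, velocity)
        return updates, MomentumState(velocity)

    return init, update

def apply_updates(params, updates):
    return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)

init, update = momentum_sgd(learning_rate=0.1)

def loss(params):
    return jnp.sum((params["p"] - 3.0) ** 2)  # minimized at p = 3

@jax.jit
def step(params, state):
    grads = jax.grad(loss)(params)
    updates, state = update(grads, state)
    return apply_updates(params, updates), state

params = {"p": jnp.zeros(2)}
state = init(params)
for _ in range(200):
    params, state = step(params, state)
```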
The composability is elegant but the learning curve is real. Newcomers expecting model.forward() and optimizer.step() will find JAX’s explicit state management jarring. It becomes natural after a week, but that week is genuinely hard.
Critique: What JAX Gets Wrong
JAX’s functional purity is a double-edged sword. The constraint that makes XLA optimization safe also makes debugging painful. When your JIT-compiled function fails, you get an XLA error - sometimes a reasonable one, sometimes a cryptic HLO shape mismatch with no reference to your original Python code. The mental model of “trace once, compile, run” means errors manifest at trace time or at runtime, and mapping them back to source lines requires experience.
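Debugging strategies:

- Remove jax.jit, or wrap the call in the jax.disable_jit() context manager, to run eagerly so errors surface in Python with normal tracebacks.
- Use jax.debug.print, which works inside JIT by emitting a debug callback.
- Turn on the jax_debug_nans config flag to locate the operation that first produces a NaN.
- Use jax.debug.breakpoint() (experimental but useful).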
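Here is a minimal sketch of the first two strategies, using an illustrative `normalize` function. `jax.debug.print` reports the runtime value inside a jitted function. The `jax.disable_jit()` context manager runs the same code op by op.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    norm = jnp.linalg.norm(x)
    # A plain print(norm) here would show a tracer at trace time;
    # jax.debug.print shows the actual value each time the function runs
    jax.debug.print("norm = {n}", n=norm)
    return x / norm

out = normalize(jnp.array([3.0, 4.0]))

with jax.disable_jit():
    # Runs eagerly, so ordinary print() and pdb see concrete values
    out_eager = normalize(jnp.array([3.0, 4.0]))
```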
The ecosystem is also genuinely more fragmented than PyTorch's. Flax (with its Linen and newer NNX APIs), Haiku (now in maintenance mode), and Equinox coexist without a clear winner. Optax is well-regarded but has a different API than PyTorch optimizers. Libraries like Hugging Face Transformers have JAX/Flax ports, but they lag the PyTorch versions. If you want to fine-tune a Llama model in JAX today, you will find far fewer tutorials and much less community support than for the PyTorch equivalent.
Finally: JAX on GPU outside of Google infrastructure is good but not seamless. XLA's GPU backend targets NVIDIA well, but GPU kernel profiling tools such as Nsight and rocprof are less integrated with XLA-generated kernels than with PyTorch/CUDA. AMD GPU support exists but is less mature.
Future: Differentiable Programming and IREE
JAX’s core contribution is not just a faster NumPy. It is a demonstration that differentiable programming - the idea that any computation can be differentiated through if it is written in a functional style - is practically useful and not prohibitively complex.
The principles JAX embodies (composable transformations, functional purity, compilation through tracing) are spreading. Julia's Enzyme.jl brings compiler-level automatic differentiation to Julia. Swift for TensorFlow experimented with language-level differentiation. Micrograd, a tiny educational autograd engine, and tinygrad, a compact deep learning framework, are small, readable implementations of the core idea.
IREE (Intermediate Representation Execution Environment) is an open-source, MLIR-based compiler and runtime that offers a portable alternative to XLA's compilation stack, targeting a wider range of hardware including edge devices. Because JAX lowers programs to StableHLO, they can be compiled by IREE instead of XLA. That may eventually allow JAX-style programming to reach hardware targets (mobile, embedded, FPGA) that XLA doesn't support well.
Summary
| Transformation | What It Does | Composes With | Key Gotcha |
|---|---|---|---|
| jax.jit | Trace + compile through XLA | All others | Python control flow runs at trace time; branching on traced values fails |
| jax.grad | Reverse-mode autodiff | jit, vmap, pmap | Function must return a scalar |
| jax.vmap | Auto-batching | jit, grad, pmap | in_axes must match the argument structure; mapped axes must have equal sizes |
| jax.pmap | SPMD multi-device | jit, grad, vmap | Mapped axis size must not exceed the device count; shared parameters must be replicated |
| Feature | JAX | PyTorch 2.0 |
|---|---|---|
| Compilation approach | Tracing (abstract values) | Graph capture (bytecode analysis) |
| Control flow | Must use jax.lax.* | Mostly transparent (graph breaks) |
| Functional purity | Required | Not required |
| Ecosystem maturity | Growing | Dominant |
| TPU support | Native (Google) | Limited (via PyTorch/XLA) |
| Debugging experience | Harder (XLA errors) | Better (eager fallback) |