Inside JAX & XLA
Prerequisite: Autodiff
NumPy is elegant, but it runs only on CPU and has no autodiff. PyTorch fixed both. However, in its default eager mode every operation is dispatched from Python and launches its own GPU kernel, so no compiler ever sees the whole computation and optimizes across operations. JAX takes a different approach: it exposes NumPy's API but compiles computations through XLA (Accelerated Linear Algebra), a compiler that produces highly optimized code for CPU, GPU, and TPU. The key insight is that functional purity (no in-place mutation, explicit state, no side effects) makes the computation graph amenable to aggressive compiler optimization.
Core Design: Functional and Compiled
JAX’s design enforces a functional style. Arrays are immutable. Operations return new arrays; they don’t modify existing ones. This constraint is not just philosophical - it is what makes JAX’s transformations composable and its compiler sound.
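The immutability rule shows up as soon as you try to assign into an array. The sketch below uses only the standard jax.numpy API. It shows that item assignment is rejected and that the functional replacement, the .at[...] indexed-update syntax, returns a new array. Under jit, XLA can often perform such an update in place when the old array is no longer used.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# Item assignment is rejected: JAX arrays are immutable.
raised = False
try:
    x[0] = 1.0
except TypeError as err:
    raised = True
    print("TypeError:", err)

# The functional alternative: .at[...].set(...) returns a new array.
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0.]  (x itself is unchanged)
print(y)  # [1. 0. 0.]

# Other indexed updates follow the same pattern.
z = y.at[1:].add(5.0)
print(z)  # [1. 5. 5.]
```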
When you call jax.jit(f)(x), JAX traces f by running it with abstract values (tracers). These carry the shape and dtype of x but not its concrete values. Tracing records the operations as a jaxpr, JAX's own intermediate representation, which JAX then lowers to XLA's HLO (High Level Operations) IR. XLA compiles that HLO to machine code for your hardware, and JAX caches the compiled executable. Subsequent calls with inputs of the same shape and dtype run the cached executable directly, without re-executing the Python body of f.
This is why JIT-compiled JAX code is fast: the Python interpreter overhead disappears, and XLA can see the entire computation and optimize it globally.
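You can inspect both stages of this pipeline yourself. The sketch below uses a toy function f, chosen only for illustration. jax.make_jaxpr prints the traced jaxpr: note that the input appears only as an abstract f32[4], with no values. jax.jit(f).lower(x).as_text() prints the module handed to XLA; in recent JAX versions this is StableHLO text, and the exact format varies by version.

```python
import jax
import jax.numpy as jnp

def f(x):  # toy function for illustration
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.ones((4,))

# The jaxpr: recorded from abstract values (shape and dtype only).
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# The lowered module that JAX passes to XLA for compilation.
lowered_text = jax.jit(f).lower(x).as_text()
print(lowered_text)
```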
jax.jit: Tracing and Compilation
```python
import jax
import jax.numpy as jnp

def softmax(x):
    x_max = jnp.max(x)  # subtract the max for numerical stability
    e_x = jnp.exp(x - x_max)
    return e_x / jnp.sum(e_x)

softmax_jit = jax.jit(softmax)
x = jnp.array([1.0, 2.0, 3.0])

# First call: traces and compiles (often ~100 ms)
out = softmax_jit(x)

# Subsequent calls with the same shape/dtype: execute the cached binary (~microseconds)
out = softmax_jit(x)
```
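Measuring the compile-versus-cached difference has one catch: JAX dispatches work asynchronously, so a call can return before the computation finishes. The timing sketch below calls block_until_ready() so that it measures the computation rather than the dispatch. The helper name timed is hypothetical, and absolute numbers depend heavily on hardware and JAX version.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def softmax(x):
    e_x = jnp.exp(x - jnp.max(x))
    return e_x / jnp.sum(e_x)

x = jnp.arange(1000, dtype=jnp.float32)

def timed(fn, arg):  # hypothetical helper for this sketch
    start = time.perf_counter()
    # Block so we time the actual computation, not just the async dispatch.
    out = fn(arg).block_until_ready()
    return out, time.perf_counter() - start

out, first = timed(softmax, x)   # includes tracing + XLA compilation
out, second = timed(softmax, x)  # cache hit: runs the compiled executable
print(f"first call: {first * 1e3:.1f} ms, second call: {second * 1e3:.3f} ms")
```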
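The tracing step is the critical one. JAX sees the abstract computation, not the concrete numbers, and this splits Python control flow into two cases:

- Control flow that depends only on static information (shapes, dtypes, Python constants, static arguments) runs once at trace time, and its outcome is baked into the compiled binary. A Python for loop over range(n), for example, is unrolled.
- Control flow that depends on array values, such as if x > 0 where x is a traced JAX array, cannot be resolved at trace time. The tracer has no concrete value, so JAX raises a jax.errors.ConcretizationTypeError.

To branch or loop on traced values, use jax.lax primitives (lax.cond, lax.while_loop, lax.fori_loop, lax.scan) or jnp.where, which lower to control flow or selects in HLO. Shapes must still be static, because XLA compiles for fixed shapes.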
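Here is a minimal demonstration of the failure and two common fixes. The relu_* function names are made up for this sketch. The Python if fails under jit because it asks a tracer for a concrete bool. lax.cond stages both branches into the compiled program and selects one at runtime. jnp.where computes both sides elementwise and selects per element.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def relu_python_if(x):
    if x > 0:  # needs a concrete bool, but x is an abstract tracer here
        return x
    return 0.0 * x

failed = False
try:
    relu_python_if(jnp.array(2.0))
except jax.errors.ConcretizationTypeError as err:
    failed = True
    print("tracing failed with", type(err).__name__)

@jax.jit
def relu_lax_cond(x):
    # Both branches are compiled; the predicate is evaluated at runtime.
    return lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)

@jax.jit
def relu_where(x):
    # Elementwise select: no Python branching at all.
    return jnp.where(x > 0, x, 0.0)

print(relu_lax_cond(jnp.array(2.0)), relu_lax_cond(jnp.array(-3.0)))
print(relu_where(jnp.array([-1.0, 2.0])))
```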
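Gotchas:
- Python side effects inside jit (print, appending to a list, mutating a global) run only once, at trace time, not on every call. Use jax.debug.print to print runtime values. Arrays cannot be mutated in place; use x.at[i].set(v) instead.
- Python values that the function reads from its enclosing scope are baked into the compiled binary as constants, and the cached binary will not notice if they change later. Pass such values as arguments. If an argument must remain a plain Python value (a loop count, a mode flag), mark it with static_argnums, which recompiles for each new value.
- Recompilation happens whenever input shapes or dtypes change, so avoid varying shapes in hot loops (for example, pad inputs to a few fixed sizes).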
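The gotchas above can be observed directly. The sketch below uses a Python list as a trace log, so it shows when tracing happens. The names traces, scale, and power are hypothetical, introduced only for illustration. The print call fires once per trace, while jax.debug.print is staged into the program and fires on every call. power shows static_argnums: n stays a Python int, so the loop unrolls at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp

traces = []  # Python-side log: appended to only while JAX is tracing

@jax.jit
def scale(x):
    traces.append(x.shape)        # Python side effect: runs at trace time only
    print("tracing, x is", x)     # prints an abstract tracer, once per trace
    jax.debug.print("x = {}", x)  # staged into the program: prints on every call
    return 2.0 * x

scale(jnp.ones(3))
scale(jnp.ones(3))  # same shape and dtype: cache hit, no retrace
scale(jnp.ones(4))  # new shape: retrace and recompile
print(traces)       # [(3,), (4,)]

@partial(jax.jit, static_argnums=1)
def power(x, n):
    # n is static (a plain Python int), so this loop unrolls at trace time;
    # each distinct n triggers its own compilation.
    result = x
    for _ in range(n - 1):
        result = result * x
    return result

print(power(jnp.array(2.0), 3))  # 8.0
```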
XLA: Operator Fusion
One of XLA's most important optimizations is operator fusion: merging multiple operations into a single GPU kernel. Without fusion, each elementwise operation (relu, add, multiply) launches a separate kernel. Each kernel reads its inputs from HBM, does minimal computation, and writes its outputs back to HBM, so the bottleneck is memory bandwidth, not compute.
XLA fuses a sequence of elementwise operations into one kernel that reads the inputs once, performs all operations in registers, and writes the result once. Consider a chain of 10 unary elementwise ops on an array of N bytes. Unfused, the chain moves about 20N bytes through HBM (one read and one write per op); the fused kernel moves about 2N, roughly a 10x reduction in HBM traffic. This is a significant part of why JAX code often outperforms PyTorch eager mode on elementwise-heavy computations.
XLA also performs layout optimization (choosing whether arrays are row-major or column-major based on how they’re used), constant folding, and common subexpression elimination.
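You can check what XLA actually did by printing the optimized HLO of a compiled function. The sketch below applies four elementwise ops to a toy function, chain, chosen for illustration. It then lists the lines of the compiled module that mention fusion. Instruction names and the amount of fusion differ between CPU, GPU, and TPU backends and across XLA versions, so treat the printout as something to explore rather than a fixed format.

```python
import jax
import jax.numpy as jnp

def chain(x):  # toy function: four elementwise ops
    return jnp.tanh(jnp.exp(x) * 2.0 + 1.0)

x = jnp.ones((1024, 1024))

# Optimized HLO after XLA's passes (fusion, CSE, constant folding, ...).
compiled = jax.jit(chain).lower(x).compile()
hlo_text = compiled.as_text()

fusion_lines = [line for line in hlo_text.splitlines() if "fusion" in line]
print(len(fusion_lines), "HLO lines mention fusion")
for line in fusion_lines[:5]:
    print(line.strip())
```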
jax.grad: Automatic Differentiation
jax.grad returns a function that computes the gradient of a scalar-valued function using reverse-mode automatic differentiation. It is neither symbolic differentiation nor numerical (finite-difference) differentiation. Instead, it transforms the traced computation into a backward pass that computes exact gradients, up to floating-point rounding.
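```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    predictions = params["w"] @ x + params["b"]
    return jnp.mean((predictions - y) ** 2)

# Illustrative data: w is (3,), x holds 5 examples as columns (3, 5), y is (5,)
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
x, y = jnp.ones((3, 5)), jnp.ones(5)

grad_fn = jax.grad(mse_loss)  # differentiates w.r.t. the first argument by default
grads = grad_fn(params, x, y)
# grads has the same structure as params: {"w": ..., "b": ...}
```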
jax.value_and_grad returns both the loss and the gradients in one pass, avoiding double computation:
```python
loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
```
jax.grad composes with jax.jit - you can JIT-compile the gradient function, and XLA will optimize the backward pass just as it does the forward pass.
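To check the claim that the gradients are exact, the sketch below composes jit with value_and_grad on the same linear MSE loss. It then compares the result with the hand-derived gradient. For residuals r = w @ x + b - y over n examples, that gradient is (2/n) x @ r for w and (2/n) sum(r) for b. The data values are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    predictions = params["w"] @ x + params["b"]
    return jnp.mean((predictions - y) ** 2)

# Illustrative data: w is (3,), x holds 5 examples as columns (3, 5).
params = {"w": jnp.array([0.5, -1.0, 2.0]), "b": jnp.array(0.1)}
x = jnp.arange(15.0).reshape(3, 5) / 10.0
y = jnp.linspace(0.0, 1.0, 5)

# Compose the transformations: XLA compiles forward and backward together.
loss_and_grad = jax.jit(jax.value_and_grad(mse_loss))
loss, grads = loss_and_grad(params, x, y)

# Closed-form gradient of mean((w @ x + b - y)^2) for comparison.
residual = params["w"] @ x + params["b"] - y
n = y.shape[0]
expected_w = 2.0 / n * (x @ residual)
expected_b = 2.0 / n * jnp.sum(residual)
print(jnp.allclose(grads["w"], expected_w, atol=1e-5),
      jnp.allclose(grads["b"], expected_b, atol=1e-5))
```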
jax.vmap: Automatic Batching
vmap (vectorizing map) transforms a function written for a single example into a batched function that operates on an array of examples. Critically, it does this without for loops - it produces a single vectorized computation.
```python
import jax
import jax.numpy as jnp

def single_loss(w, x, y):
    return jnp.dot(w, x) - y

# Batch with vmap instead of writing a loop over examples
batched_loss = jax.vmap(single_loss, in_axes=(None, 0, 0))
# None = don't batch w (shared), 0 = batch x and y along axis 0
w = jnp.ones(10)
X = jnp.ones((32, 10))  # batch of 32
Y = jnp.ones(32)
losses = batched_loss(w, X, Y)  # shape (32,)
```
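The power of vmap is that you write the logic for a single input, which is cleaner and easier to reason about, and get efficient batched execution for free. Combined with grad, vmap(grad(f)) computes per-example gradients in a single vectorized pass rather than a Python loop of separate backward passes. The result materializes one gradient per example, so it needs more memory than an averaged batch gradient.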
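A minimal per-example-gradient sketch follows. It assumes a squared-error loss per example, single_sq_loss, which is a hypothetical helper rather than the single_loss above. grad is taken with respect to w for one example, and vmap maps that over the batch while sharing w. The result is checked against the analytic per-example gradient 2 (w·x - y) x.

```python
import jax
import jax.numpy as jnp

def single_sq_loss(w, x, y):
    # Hypothetical helper: squared error for ONE example (x: (10,), y: scalar).
    return (jnp.dot(w, x) - y) ** 2

# grad w.r.t. w for one example, then vmap over examples (w shared).
per_example_grads = jax.vmap(jax.grad(single_sq_loss), in_axes=(None, 0, 0))

w = jnp.linspace(-1.0, 1.0, 10)
X = jax.random.normal(jax.random.PRNGKey(0), (32, 10))
Y = jnp.ones(32)

g = per_example_grads(w, X, Y)  # shape (32, 10): one gradient per example

# Analytic check: d/dw (w.x - y)^2 = 2 (w.x - y) x
expected = 2.0 * (X @ w - Y)[:, None] * X
print(g.shape, jnp.allclose(g, expected, atol=1e-5))
```

Averaging g over axis 0 recovers the ordinary batch gradient of the mean loss, so the per-example version strictly contains more information.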
jax.pmap: Data Parallelism Across Devices
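pmap (parallel map) replicates a function across multiple devices, with each device receiving a different shard of the input along its leading axis. It is JAX's original transformation for data-parallel training; newer code often expresses the same pattern with jax.jit on sharded arrays or with shard_map.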
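```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import pmap
from jax.lax import pmean

def loss_fn(params, batch):  # hypothetical least-squares loss for illustration
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

def update(params, grads, lr=0.1):  # hypothetical plain SGD step
    return params - lr * grads

@partial(pmap, axis_name="devices")  # binds the axis name that pmean reduces over
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = pmean(grads, axis_name="devices")  # AllReduce: average gradients
    params = update(params, grads)
    return params, loss
```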
pmean performs a collective AllReduce across all devices, averaging the gradients. Because the parameters start out replicated, every device then holds identical gradients and applies an identical update, so the replicas stay in sync. This is synchronous data parallelism.
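The AllReduce semantics can be seen in isolation, without a model. In the sketch below, each device holds one stand-in "gradient": device i holds the value i. pmean replaces every copy with the mean across devices. The function must be mapped with the same axis_name that pmean refers to. On a machine with a single device this runs trivially, with n_dev = 1.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 1 on a plain CPU host; more on multi-GPU/TPU

# One stand-in "gradient" per device: device i holds the value i.
per_device = jnp.arange(n_dev, dtype=jnp.float32)

# axis_name binds the mapped axis so the collective knows what to reduce over.
all_reduce_mean = jax.pmap(lambda g: jax.lax.pmean(g, axis_name="devices"),
                           axis_name="devices")

averaged = all_reduce_mean(per_device)
print(averaged)  # every entry equals the mean of 0 .. n_dev - 1
```

To try several devices on a CPU-only machine, you can set the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=8 before JAX is imported.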
Pytrees: Operating on Nested Structures
JAX operates on pytrees: arbitrarily nested Python containers (lists, tuples, dicts, and None) whose leaves are JAX arrays, plus any custom class registered with jax.tree_util. All JAX transformations (jit, grad, vmap, pmap) understand pytrees natively. When jax.grad differentiates with respect to a dict of parameters, it returns a dict of gradients with the same structure. This makes it natural to represent model parameters as nested dicts, or as dataclasses once they are registered as pytree nodes.
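```python
import jax
import jax.numpy as jnp
params = {
    "layer1": {"w": jnp.ones((64, 32)), "b": jnp.zeros(64)},
    "layer2": {"w": jnp.ones((10, 64)), "b": jnp.zeros(10)},
}
def loss_fn(params, x, y):  # hypothetical two-layer MLP with squared error
    h = jax.nn.relu(params["layer1"]["w"] @ x + params["layer1"]["b"])
    out = params["layer2"]["w"] @ h + params["layer2"]["b"]
    return jnp.mean((out - y) ** 2)
grads = jax.grad(loss_fn)(params, jnp.ones(32), jnp.zeros(10))
# grads: {"layer1": {"w": ..., "b": ...}, "layer2": {"w": ..., "b": ...}}
```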
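Because params and grads share one tree structure, an optimizer step can be written once, leaf by leaf, with jax.tree_util.tree_map. The sketch below uses stand-in all-ones gradients and a hypothetical sgd_step helper. A real training loop would take grads from jax.grad and would more likely use an optimizer library.

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.ones((64, 32)), "b": jnp.zeros(64)},
    "layer2": {"w": jnp.ones((10, 64)), "b": jnp.zeros(10)},
}
# Stand-in gradients with the same structure (a real run would use jax.grad).
grads = jax.tree_util.tree_map(jnp.ones_like, params)

def sgd_step(params, grads, lr=0.01):  # hypothetical plain-SGD helper
    # tree_map walks matching pytrees and applies the function leaf by leaf.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

new_params = sgd_step(params, grads)

same_structure = (jax.tree_util.tree_structure(new_params)
                  == jax.tree_util.tree_structure(params))
leaf_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(new_params)]
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(same_structure, leaf_shapes, n_params)  # dict leaves come out in sorted-key order
```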
Explicit PRNG
NumPy's legacy API and PyTorch use global random state by default: call np.random.randn() twice and you get different values, because each call mutates hidden global state. This is incompatible with JAX's functional model. JAX instead requires explicit PRNG keys, which are split deterministically:
```python
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(100,))
```
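Splitting a key produces new keys that are statistically independent and fully determined by the parent. JAX code is therefore exactly reproducible and safe to parallelize, since there is no shared mutable state. The flip side is that reusing a key reuses its randomness: the same key always produces the same numbers, so each key should be consumed once and split for further draws.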
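The sketch below shows both sides of this design. Drawing twice from one key gives identical numbers. Splitting into several subkeys (named k_init and k_dropout here purely for illustration) gives independent streams, each used exactly once.

```python
import jax

key = jax.random.PRNGKey(42)

# Same key -> same numbers: randomness is a pure function of the key.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(bool((a == b).all()))  # True: reusing a key reuses its randomness

# Split once into several subkeys and use each subkey exactly once.
key, k_init, k_dropout = jax.random.split(key, 3)
w = jax.random.normal(k_init, (3,))
mask = jax.random.bernoulli(k_dropout, p=0.9, shape=(3,))
```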
Examples
JIT-Compiled Neural Network Layer
```python
import jax
import jax.numpy as jnp

def linear_relu(params, x):
    return jax.nn.relu(x @ params["w"] + params["b"])

# JIT-compile for fast repeated calls
linear_relu_jit = jax.jit(linear_relu)

# Split the key so the weights and the inputs use independent randomness
key = jax.random.PRNGKey(0)
key_w, key_x = jax.random.split(key)
params = {
    "w": jax.random.normal(key_w, (128, 64)),
    "b": jnp.zeros(64),
}
x = jax.random.normal(key_x, (32, 128))  # batch of 32
out = linear_relu_jit(params, x)  # shape (32, 64)
```
vmap for Batched Loss
```python
import jax
import jax.numpy as jnp

def single_cross_entropy(logits, label):
    """Loss for a single example."""
    log_probs = jax.nn.log_softmax(logits)
    return -log_probs[label]

# Batch over axis 0 of logits and labels
batched_loss = jax.vmap(single_cross_entropy)
logits = jax.random.normal(jax.random.PRNGKey(0), (32, 10))
labels = jnp.zeros(32, dtype=jnp.int32)
losses = batched_loss(logits, labels)  # shape (32,)
mean_loss = jnp.mean(losses)
```
grad of a Custom Loss Function
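```python
import jax
import jax.numpy as jnp

def contrastive_loss(embeddings, labels, margin=1.0):
    """Simplified pairwise contrastive loss on squared distance."""
    dists = jnp.sum((embeddings[0] - embeddings[1]) ** 2)
    same = labels[0] == labels[1]
    # Same class: pull together; different class: push apart until dists >= margin
    loss = jnp.where(same, dists, jnp.maximum(0., margin - dists))
    return loss

grad_fn = jax.jit(jax.grad(contrastive_loss))
key = jax.random.PRNGKey(0)
emb = jax.random.normal(key, (2, 128))
# Random 128-d vectors are far apart (squared distance ~256), so for a
# different-class pair the hinge would be inactive and all gradients 0;
# a same-class pair gives a nonzero gradient.
labs = jnp.array([0, 0])
grads = grad_fn(emb, labs)  # gradient w.r.t. embeddings, shape (2, 128)
```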
JAX rewards understanding its execution model: trace once, compile, reuse. The functional constraints feel restrictive at first, but they are precisely what allows the transformations to compose correctly and XLA to optimize aggressively.