Training a neural network is an exercise in managing uncertainty. The code runs without errors, the loss decreases - and yet the model produces garbage predictions. Or the loss goes to NaN after 10,000 steps, wasting two days of GPU time. Unlike a web service that crashes with a traceback, deep learning failures are often silent and delayed. A systematic debugging workflow is not optional - it is what separates engineers who iterate quickly from those who burn compute on broken experiments.

The Golden Rule: Overfit One Batch First

Before training on the full dataset, verify that the training loop is correct by running it on a single batch and driving the loss to near zero. If your model cannot memorize 4 samples, something is wrong with the forward pass, the loss function, or the data pipeline - and you have found it cheaply, before wasting GPU hours.

# Grab one batch and repeatedly train on it
batch = next(iter(train_loader))
for step in range(200):
    optimizer.zero_grad()
    loss = model_loss(model, batch)
    loss.backward()
    optimizer.step()
    if step % 20 == 0:
        print(f"step {step}: loss={loss.item():.6f}")
# Expected: loss should approach ~0 (or the irreducible minimum)

If loss does not decrease, the bug is structural. If it does, the model and loss are correct and the problem is elsewhere (data loading, regularization, learning rate schedule).

Training Loss Goes to NaN

NaN loss is the loudest failure, but its causes are often subtle.

Learning rate too high is the most common cause. An LR that is orders of magnitude too large causes gradient updates that overshoot, producing activations that overflow to infinity, which propagate as NaN. Fix: reduce LR by 10x and retry.

Exploding gradients occur in deep networks and RNNs where gradients compound multiplicatively over many layers or time steps. The gradient norm grows exponentially, producing infinite updates. Detect this by logging gradient norms per layer. Fix with gradient clipping:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Numerical instability in the loss function is less obvious. Cross-entropy loss involves taking the log of a softmax output. If you compute log(softmax(logits)) directly, extreme logits can drive some softmax outputs to exactly 0 before the log, yielding -inf. Use F.cross_entropy(logits, labels), which internally applies the numerically stable log-sum-exp:

$$\log \sum_i e^{x_i} = x_{\max} + \log \sum_i e^{x_i - x_{\max}}$$

Always use the framework’s fused loss functions rather than composing softmax and log separately.
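
A minimal demonstration of the failure mode, with logits chosen deliberately to force underflow:

import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0]])   # extreme logits force underflow
labels = torch.tensor([1])

# Naive composition: softmax underflows to exactly 0 for the true class,
# so the log produces -inf.
naive = torch.log(torch.softmax(logits, dim=-1))
print(naive)                  # tensor([[0., -inf]])

# Fused loss: log-sum-exp keeps everything finite.
stable = F.cross_entropy(logits, labels)
print(stable)                 # tensor(1000.)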

NaN in the data propagates through every operation. Check your data pipeline: assert not torch.isnan(batch["x"]).any().
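
A cheap guard for the top of the training step; this sketch assumes batches are dicts of tensors (adapt the keys to your pipeline):

import torch

def assert_finite_batch(batch: dict, step: int) -> None:
    """Fail fast with a useful message instead of training on corrupt data."""
    for key, value in batch.items():
        if torch.is_tensor(value) and not torch.isfinite(value).all():
            raise ValueError(f"Non-finite values in batch[{key!r}] at step {step}")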

Loss Does Not Decrease

A flat or slowly decreasing loss when the single-batch test passes means the training loop mechanics are sound; the failure is in how training scales to the full dataset.

Learning rate too low: The model updates in the right direction but too slowly. Try the learning rate finder: increase the LR exponentially over a few hundred steps and plot loss vs LR. The LR just before loss starts to increase is a good maximum; use 1/10th of that as the initial LR.
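
A bare-bones version of the range test; the step count and LR bounds are conventional choices, not canon, and the loader is assumed to yield (x, y) pairs:

import math
import torch

def lr_range_test(model, loss_fn, loader, lr_min=1e-7, lr_max=1.0, n_steps=200):
    """Sweep the LR exponentially and record (lr, loss) pairs for plotting."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_min)
    gamma = (lr_max / lr_min) ** (1.0 / n_steps)  # per-step multiplier
    history, data = [], iter(loader)
    for _ in range(n_steps):
        x, y = next(data)                # assumes loader has >= n_steps batches
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        history.append((optimizer.param_groups[0]["lr"], loss.item()))
        if not math.isfinite(loss.item()):
            break                        # loss blew up; the sweep is done
        for group in optimizer.param_groups:
            group["lr"] *= gamma
    return history

Run this on a throwaway copy of the model: the sweep deliberately destroys the weights.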

Bug in the forward pass: A layer that is accidentally skipped, a transposition that swaps batch and feature dimensions, or a missing activation function. Add shape assertions and test intermediate outputs.
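
One way to inspect intermediate outputs without editing the model is a forward hook that prints every submodule's output shape; a sketch:

import torch
import torch.nn as nn

def attach_shape_logging(model: nn.Module) -> list:
    """Print every submodule's output shape on the next forward pass."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            if torch.is_tensor(output):
                print(f"{name or 'model'}: {tuple(output.shape)}")
        handles.append(module.register_forward_hook(hook))
    return handles

# Usage: attach, run one batch, eyeball the shapes, then detach.
# handles = attach_shape_logging(model)
# model(sample_batch)
# for h in handles: h.remove()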

Wrong loss function: Regression problems with cross-entropy, or classification with MSE, will not converge correctly. Check that the loss function matches the task.

Data not shuffled: If training data is ordered by class and each batch contains a single class, successive updates pull the model toward one class and then another, undoing each other's progress. Always shuffle - in PyTorch it is a single flag, shown below.
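
from torch.utils.data import DataLoader

# train_dataset stands in for your own Dataset object
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)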

Overfitting: Loss Diverges Between Train and Validation

When training loss decreases but validation loss increases, the model is memorizing training data rather than learning generalizable features.

Remedies:

  • Dropout: Randomly zero activations during training, forcing the network not to rely on any individual neuron.
  • Weight decay (L2 regularization): Penalizes large weights, baked into the optimizer as AdamW(params, weight_decay=1e-4).
  • Data augmentation: Synthetically expand the training set with transforms (flips, crops, color jitter for images; synonym replacement for text).
  • More data: The most reliable regularizer.
  • Early stopping: Monitor validation loss and stop when it stops improving (a combined sketch follows this list).
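
A minimal sketch combining three remedies - dropout, weight decay via AdamW, and early stopping. The layer sizes are arbitrary, and train_one_epoch / evaluate are hypothetical placeholders for your own loops:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 256), nn.ReLU(),
    nn.Dropout(p=0.3),            # randomly zero 30% of activations in training
    nn.Linear(256, 10),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

best_val, patience, bad_epochs = float("inf"), 5, 0
for epoch in range(100):
    train_one_epoch(model, optimizer)     # placeholder: your training loop
    val_loss = evaluate(model)            # placeholder: your validation loop
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "best.pt")   # keep the best checkpoint
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # validation stopped improving; stop early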

Underfitting: Both Losses Plateau Too High

When neither training nor validation loss reaches an acceptable level, the model lacks capacity or training is insufficient.

  • Larger model: More layers, wider layers, larger embedding dimensions.
  • More epochs: Sometimes the model simply hasn’t trained long enough. Check whether training loss is still decreasing slowly.
  • Lower learning rate (with more epochs): If LR is too high the model oscillates around a minimum without descending into it.
  • Data pipeline bug: If data is corrupted, mislabeled, or loaded incorrectly, no model size will fix it. Add visualizations: log a few raw samples and their labels to W&B at the start of every run (sketched below).
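
A small sketch of that sanity check, assuming the dataset returns (image, label) pairs:

import wandb

def log_raw_samples(dataset, n: int = 8) -> None:
    """Log the first n samples so mislabeled or corrupt data is visible at a glance."""
    images = [wandb.Image(dataset[i][0], caption=f"label={dataset[i][1]}")
              for i in range(n)]
    wandb.log({"debug/raw_samples": images})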

Gradient Flow Diagnostics

Dead ReLUs occur when a ReLU unit’s input is always negative - the gradient is always zero and the unit never updates. This happens with high learning rates that push weights into regions where all inputs are negative. Detect by logging the fraction of zero activations per layer. Fix by using leaky ReLU or GELU, or reducing the learning rate.
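One way to measure this, using a forward hook on every ReLU; a sketch over a single batch x:

import torch
import torch.nn as nn

def relu_zero_fractions(model: nn.Module, x: torch.Tensor) -> dict:
    """Fraction of exactly-zero outputs per ReLU over one forward pass."""
    rates, handles = {}, []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            def hook(mod, inputs, output, name=name):
                rates[name] = (output == 0).float().mean().item()
            handles.append(module.register_forward_hook(hook))
    model(x)
    for h in handles:
        h.remove()
    return rates  # a value near 1.0 flags a layer full of dead units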

Vanishing gradients in very deep networks: gradients in early layers are the product of many Jacobians, which can shrink exponentially. Check gradient norms by layer:

for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_norm={param.grad.norm().item():.4f}")

If gradients in early layers are many orders of magnitude smaller than in later layers, add residual connections, use batch normalization, or switch to an architecture designed for deep networks.

Loss Curve Patterns

  • Spikes in loss: Sudden jumps indicate instability - often a bad batch (corrupt data) or a LR that is near the instability threshold. Add gradient clipping and check data quality.
  • Loss decreases then flatlines: Possible local minimum, LR too small to escape, or the model has reached its capacity. Try a cyclic learning rate schedule that occasionally increases the LR to escape flat regions (see the sketch after this list).
  • Decreasing training loss, flat validation loss from the start: The validation set distribution differs from training - a train/validation distribution shift, or a preprocessing step applied to train but not val.
  • Both losses flat from step 1: The most common cause is a zero gradient - check that loss.backward() is called, optimizer.zero_grad() is called at the right time, and parameters are actually connected to the loss.

Monitoring Weight and Activation Statistics

Log the following to W&B or TensorBoard every N steps:

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        wandb.log({
            f"{name}/weight_norm": module.weight.norm().item(),
            f"{name}/weight_mean": module.weight.mean().item(),
            f"{name}/weight_std": module.weight.std().item(),
        })

Weight norms that grow without bound indicate exploding parameters. Weight norms that collapse to near zero indicate dead units. Activation distributions that are always zero indicate dead ReLUs. These diagnostics catch slow-moving failures before they manifest as NaN or bad metrics.

Examples

Grad Norm Logging in PyTorch

def compute_grad_norm(model: nn.Module) -> float:
    """Global L2 norm across all parameter gradients."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5

# In the training loop:
loss.backward()
grad_norm = compute_grad_norm(model)
wandb.log({"grad_norm": grad_norm})

# Clip AFTER computing norm so you log the pre-clip value
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

nan_to_num Pitfalls

# WRONG: masking NaN hides the bug and corrupts training
loss = torch.nan_to_num(raw_loss, nan=0.0)
optimizer.step()

# RIGHT: assert, diagnose the source, fix it
assert torch.isfinite(raw_loss), (
    f"Loss is {raw_loss.item()} at step {step}. "
    f"Input stats: mean={x.mean():.3f}, max={x.abs().max():.3f}"
)

nan_to_num is sometimes used defensively, but in training it is almost always wrong - it lets a broken forward pass continue, corrupting weights silently.

Systematic Batch Overfit Test

import torch
import torch.nn as nn

def overfit_test(model: nn.Module, loss_fn, input_shape: tuple, n_classes: int,
                 n_samples: int = 4, n_steps: int = 500,
                 target_loss: float = 0.01) -> bool:
    """Returns True if the model can memorize a small random batch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x = torch.randn(n_samples, *input_shape)
    y = torch.randint(0, n_classes, (n_samples,))

    model.train()
    for step in range(n_steps):
        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

    final_loss = loss.item()
    passed = final_loss < target_loss
    print(f"Overfit test: final_loss={final_loss:.6f} | {'PASS' if passed else 'FAIL'}")
    return passed

Run this before every serious training job. If it fails, you have a bug. If it passes, you know the model and loss are correct and can proceed to train on the full dataset with confidence.

A disciplined debugging workflow - overfit one batch, monitor gradients, check loss curves - compresses what would otherwise be days of confused iteration into hours of systematic diagnosis.
