Standard self-attention is $O(n^2)$ in both time and memory with respect to sequence length $n$. For a sequence of 16k tokens, the attention score matrix for a single head in a single layer already holds $16384^2 \approx 268$ million entries, roughly 1 GB at fp32. This post surveys the main architectural and algorithmic responses to that bottleneck.
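The memory figure is simple arithmetic, and a throwaway helper reproduces it (the name `score_matrix_bytes` is hypothetical). It counts only one dense score matrix for one head in one layer, ignoring batch size, the number of heads, and anything kept for the backward pass.

```python
def score_matrix_bytes(n: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold one dense n x n attention score matrix (fp32 by default)."""
    return n * n * bytes_per_element

for n in (2_048, 16_384, 131_072):
    gib = score_matrix_bytes(n) / 2**30
    print(f"n={n:>7}: {gib:8.2f} GiB at fp32")
```

Because the cost is quadratic, 16k tokens gives exactly 1 GiB, while 128k tokens (8 times longer) gives 64 GiB for that single matrix.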

The Quadratic Bottleneck

Given queries $Q \in \mathbb{R}^{n \times d_k}$, keys $K \in \mathbb{R}^{n \times d_k}$, and values $V \in \mathbb{R}^{n \times d_v}$, standard attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In a straightforward implementation, the matrix $QK^T \in \mathbb{R}^{n \times n}$ is fully materialised in memory before the row-wise softmax is applied. This creates both a memory problem and a bandwidth problem: for attention on modern GPUs, the bottleneck is usually not floating-point throughput but moving data between HBM (high-bandwidth memory) and the much smaller, faster on-chip SRAM.
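A minimal NumPy sketch of the standard computation makes the $n \times n$ intermediate explicit. It is a single-head, unbatched reference written for illustration (the function name is hypothetical), not an optimised kernel. Subtracting the row maximum before exponentiating is the usual stability trick and does not change the result.

```python
import numpy as np

def naive_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Reference softmax attention that materialises the full n x n score matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # (n, n): the quadratic object
    scores -= scores.max(axis=-1, keepdims=True)     # row max subtracted for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V                               # (n, d_v)

rng = np.random.default_rng(0)
n, d_k, d_v = 1024, 64, 64
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))
out = naive_attention(Q, K, V)
print(out.shape)  # (1024, 64)
```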

Sparse Attention

Sparse attention methods restrict which pairs of tokens are allowed to attend to each other, replacing the dense $n \times n$ matrix with a sparse pattern.

Longformer combines a local sliding window of size $w$ (each token attends to $w/2$ neighbours on each side) with a small set of global tokens that attend to, and are attended to by, every position. The local part costs $O(n \cdot w)$ and $g$ global tokens add $O(n \cdot g)$, so the total is linear in $n$ for fixed $w$ and $g$.
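The pattern can be pictured as a boolean mask. This sketch builds it densely only for illustration (real implementations use banded kernels precisely to avoid materialising anything $n \times n$); `longformer_mask`, `w` and `global_idx` are hypothetical names chosen to match the description.

```python
import numpy as np

def longformer_mask(n: int, w: int, global_idx: list[int]) -> np.ndarray:
    """Boolean (n, n) mask: True where query i may attend to key j.

    Sliding window: |i - j| <= w // 2. Global tokens attend to every position,
    and every position attends to them.
    """
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w // 2
    mask[global_idx, :] = True   # global tokens attend everywhere
    mask[:, global_idx] = True   # every token attends to the global tokens
    return mask

m = longformer_mask(n=4096, w=256, global_idx=[0])
print(m.sum(), 4096 * 4096)      # allowed pairs vs. dense pairs
```

With $w = 256$ and one global token, only about 6% of the $4096^2$ pairs survive, and that fraction keeps shrinking as $n$ grows because the allowed count is linear in $n$.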

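BigBird extends this with three types of attention: random edges (each token attends to $r$ randomly selected tokens), a local window, and global tokens. The resulting attention graph is sparse, and the BigBird paper proves that, with the global tokens included, this sparse attention keeps the theoretical expressive power of full attention: it remains a universal approximator of sequence-to-sequence functions and is Turing complete. For a fixed window size, $r$, and number of global tokens, the cost is $O(n)$.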
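Adding random edges to the same mask construction gives a token-level sketch of the BigBird pattern. Real BigBird works on blocks of tokens (for example 64 at a time) for GPU efficiency, so this per-token version, and the names `bigbird_mask`, `r` and `g`, are simplifying assumptions.

```python
import numpy as np

def bigbird_mask(n: int, w: int, r: int, g: int, seed: int = 0) -> np.ndarray:
    """Boolean (n, n) mask combining window, random and global attention.

    - window: |i - j| <= w // 2
    - random: each query attends to r keys drawn uniformly without replacement
    - global: the first g tokens attend to, and are attended by, all positions
    """
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w // 2
    rand_cols = np.stack([rng.choice(n, size=r, replace=False) for _ in range(n)])  # (n, r)
    mask[idx[:, None], rand_cols] = True
    mask[:g, :] = True
    mask[:, :g] = True
    return mask

for n in (1024, 2048, 4096):
    m = bigbird_mask(n, w=64, r=3, g=2)
    print(n, m.sum(axis=1).mean())   # average keys per query stays roughly constant
```

Since the keys per query do not grow with $n$, the total work grows linearly, which is the $O(n)$ claim in concrete form.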

The drawback of both approaches is that the sparsity pattern is fixed at design time. Tasks requiring arbitrary long-range interactions may not align with the chosen pattern.

Linear Attention via Kernel Approximation

Softmax attention can be viewed as a kernel smoother. Define $\text{sim}(q, k) = \exp(q \cdot k / \sqrt{d_k})$. The attention output at position $i$ is:

$$o_i = \frac{\sum_j \text{sim}(q_i, k_j) v_j}{\sum_j \text{sim}(q_i, k_j)}$$

If we can factor $\text{sim}(q, k) \approx \phi(q)^T \phi(k)$ for some feature map $\phi : \mathbb{R}^{d_k} \to \mathbb{R}^r$, the computation becomes:

$$o_i \approx \frac{\phi(q_i)^T \left(\sum_j \phi(k_j) v_j^T\right)}{\phi(q_i)^T \left(\sum_j \phi(k_j)\right)}$$

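The inner sums $\sum_j \phi(k_j) v_j^T \in \mathbb{R}^{r \times d_v}$ and $\sum_j \phi(k_j) \in \mathbb{R}^r$ do not depend on $i$, so they can be computed once and reused for every query. This reduces the cost to $O(n \cdot r \cdot d_v)$, linear rather than quadratic in $n$. The Performer approximates the exponential kernel with positive orthogonal random features (its FAVOR+ mechanism). Linformer takes a related low-rank route rather than a kernel one: it projects the keys and values along the sequence dimension, from length $n$ down to a fixed length $k \ll n$, so the attention matrix is $n \times k$ instead of $n \times n$.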
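The saving is just associativity of matrix products. To isolate that step, the sketch below uses the positive feature map $\phi(x) = \text{elu}(x) + 1$ from the linear transformer of Katharopoulos et al. (2020). This is an illustrative stand-in, not the softmax kernel and not the Performer's features. The check confirms that the linear form matches the explicit $n \times n$ form for the same $\phi$. Function names are hypothetical, and causal masking is left out.

```python
import numpy as np

def elu_plus_one(x: np.ndarray) -> np.ndarray:
    """Positive feature map phi(x) = elu(x) + 1, so r = d_k here."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention(Q, K, V, phi=elu_plus_one):
    """O(n * r * d_v): key/value summaries are built once and shared by all queries."""
    fq, fk = phi(Q), phi(K)                # (n, r)
    kv = fk.T @ V                          # (r, d_v) = sum_j phi(k_j) v_j^T
    z = fk.sum(axis=0)                     # (r,)     = sum_j phi(k_j)
    return (fq @ kv) / (fq @ z)[:, None]   # numerator / denominator per query

def quadratic_kernel_attention(Q, K, V, phi=elu_plus_one):
    """Same kernel evaluated the O(n^2) way, with an explicit (n, n) similarity matrix."""
    sim = phi(Q) @ phi(K).T
    return (sim @ V) / sim.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 32)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), quadratic_kernel_attention(Q, K, V)))  # True
```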

The fundamental limitation is approximation error: the kernel decomposition does not reproduce the softmax exactly, and for tasks where attention sharpness matters (e.g., retrieval-heavy reasoning) this can hurt performance meaningfully.
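The approximation error can be seen directly with the positive random features behind the Performer's FAVOR+. For $w \sim \mathcal{N}(0, I)$, $\mathbb{E}[e^{w \cdot q - \|q\|^2/2} e^{w \cdot k - \|k\|^2/2}] = e^{q \cdot k}$. In this sketch the projections are i.i.d. Gaussian, skipping the orthogonalisation step for brevity, and the inputs are scaled to small norms on purpose. The estimator's variance grows with $\|q\|^2 + \|k\|^2$, which is one way to see why sharp, large-logit attention is hard to approximate. All names are hypothetical.

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Exact reference that materialises the (n, n) matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def positive_features(X, W):
    """phi(x) = exp(W x - |x|^2 / 2) / sqrt(r); E[phi(q) . phi(k)] = exp(q . k)."""
    r = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X**2, axis=1, keepdims=True)) / np.sqrt(r)

def performer_attention(Q, K, V, r, seed=0):
    """Linear-time approximation of softmax attention with r random features."""
    d_k = Q.shape[1]
    W = np.random.default_rng(seed).standard_normal((r, d_k))  # i.i.d., not orthogonalised
    s = d_k ** -0.25                                           # (s q).(s k) = q.k / sqrt(d_k)
    fq, fk = positive_features(Q * s, W), positive_features(K * s, W)
    return (fq @ (fk.T @ V)) / (fq @ fk.sum(axis=0))[:, None]

rng = np.random.default_rng(1)
n, d_k = 256, 16
Q, K = 0.3 * rng.standard_normal((n, d_k)), 0.3 * rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_k))
exact = softmax_attention(Q, K, V)
for r in (8, 64, 512, 4096):
    err = np.abs(performer_attention(Q, K, V, r) - exact).max()
    print(r, err)   # error typically shrinks roughly like 1 / sqrt(r)
```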

Flash Attention

Flash Attention (Dao et al., 2022) achieves exact attention in $O(n)$ memory without any approximation, by reordering the computation to avoid materialising the full $n \times n$ matrix.

The key observation is the online softmax trick: softmax normalisation can be computed incrementally as blocks of keys are processed. For the logits $s_j$ seen so far, keep the running maximum $m$ and the shifted partition sum $l = \sum_j e^{s_j - m}$. When a new block arrives with local maximum $m'$ and local sum $l' = \sum_{j \in \text{block}} e^{s_j - m'}$, the statistics are updated as:

$$m_{\text{new}} = \max(m, m')$$

$$l_{\text{new}} = e^{m - m_{\text{new}}}\, l + e^{m' - m_{\text{new}}}\, l'$$

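Alongside $(m, l)$, each query row keeps an unnormalised output accumulator that is rescaled by $e^{m - m_{\text{new}}}$ whenever the running maximum changes, so the correctly normalised output is recovered at the end without ever storing the full attention matrix. The algorithm tiles $Q$, $K$, and $V$ into blocks that fit in SRAM, processes one pair of blocks at a time, and carries the per-row statistics $(m, l)$ from tile to tile. The paper shows that this reduces HBM accesses from $\Theta(n d + n^2)$ for standard attention to $\Theta(n^2 d^2 / M)$, where $d$ is the head dimension and $M$ the SRAM size. For typical values ($d$ of 64 to 128, $M$ around 100 KB), $d^2/M$ is well below 1, and the block size is chosen so each tile stays entirely in fast on-chip memory.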
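The tiling scheme fits in a short NumPy sketch. It computes exact single-head attention one (query block, key block) tile at a time, applies the $(m, l)$ update, and rescales an output accumulator by the same factor. NumPy gives no control over SRAM, so this illustrates only the arithmetic, not the memory-traffic savings. The loop order (query blocks outer) and the single normalisation at the end follow the Flash Attention 2 formulation. `tiled_attention` and `block` are hypothetical names.

```python
import numpy as np

def tiled_attention(Q, K, V, block=64):
    """Exact single-head softmax attention, one (query block, key block) tile at a time."""
    n, d_k = Q.shape
    out = np.empty((n, V.shape[1]))
    for qs in range(0, n, block):
        q = Q[qs:qs + block] / np.sqrt(d_k)
        m = np.full(q.shape[0], -np.inf)            # running row maximum
        l = np.zeros(q.shape[0])                    # running sum of exp(s - m)
        acc = np.zeros((q.shape[0], V.shape[1]))    # unnormalised output: sum of exp(s - m) v
        for ks in range(0, n, block):
            s = q @ K[ks:ks + block].T              # one score tile, never the full (n, n) matrix
            m_new = np.maximum(m, s.max(axis=1))
            scale = np.exp(m - m_new)               # e^{m - m_new}; 0 on the first tile
            p = np.exp(s - m_new[:, None])          # block terms taken directly relative to m_new
            l = scale * l + p.sum(axis=1)           # the (m, l) update rule
            acc = scale[:, None] * acc + p @ V[ks:ks + block]
            m = m_new
        out[qs:qs + block] = acc / l[:, None]       # normalise once per query block
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((250, 32)) for _ in range(3))   # 250 is not a multiple of 64
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(tiled_attention(Q, K, V, block=64), reference))  # True
```

The result matches the materialised reference up to floating-point rounding for any block size, because the rescaling makes the blockwise sums exactly equivalent to the full softmax.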

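Flash Attention 2 further reduces non-matrix-multiply FLOPs, parallelises over the sequence length as well as over batch and heads, and partitions work between warps more efficiently. Its authors report roughly a $2\times$ speedup over the original and forward-pass throughput of about 50–73% of the A100's theoretical peak FLOPs/s.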

Mixture of Experts (MoE)

Rather than making attention cheaper, MoE decouples parameter count from per-token compute by making the FFN layers conditional. Each MoE layer contains $E$ expert FFNs $\{E_1, \ldots, E_E\}$ and a gating network $G$.

The output for input $x$ is:

$$y = \sum_{i=1}^{E} G(x)_i \, E_i(x)$$

where $G(x) \in \mathbb{R}^E$ assigns routing weights. In practice, top-$k$ routing keeps only the $k$ largest weights (typically $k=1$ or $k=2$) and sets the rest to zero, so only $k$ experts are actually computed per token.
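A small NumPy sketch of top-$k$ routing shows that only $k$ experts run per token. The gate takes a softmax over the $k$ largest router logits only, which is the convention the Mixtral paper describes. Some other MoE designs take the softmax over all $E$ logits first and then keep the top $k$. `moe_layer`, `make_expert` and the toy ReLU experts are hypothetical, and the per-token loop is kept for clarity; practical implementations group tokens by expert instead.

```python
import numpy as np

def moe_layer(x, W_gate, experts, k=2):
    """Top-k mixture-of-experts FFN for a batch of tokens.

    x: (T, d) tokens; W_gate: (d, E) router weights; experts: list of E callables
    mapping a (d,) vector to a (d,) vector.
    """
    logits = x @ W_gate                                    # (T, E)
    top = np.argsort(logits, axis=1)[:, -k:]               # (T, k) indices of the k largest logits
    top_logits = np.take_along_axis(logits, top, axis=1)   # (T, k)
    gates = np.exp(top_logits - top_logits.max(axis=1, keepdims=True))
    gates /= gates.sum(axis=1, keepdims=True)              # softmax over the k survivors
    y = np.zeros_like(x)
    for t in range(x.shape[0]):
        for slot in range(k):                              # only k experts run for token t
            y[t] += gates[t, slot] * experts[top[t, slot]](x[t])
    return y, top

def make_expert(seed, d):
    """Hypothetical toy expert: a two-layer ReLU FFN with random weights."""
    r = np.random.default_rng(seed)
    W1 = 0.1 * r.standard_normal((d, 4 * d))
    W2 = 0.1 * r.standard_normal((4 * d, d))
    return lambda v: np.maximum(v @ W1, 0.0) @ W2

rng = np.random.default_rng(0)
d, E, T = 16, 8, 5
experts = [make_expert(s, d) for s in range(E)]
x = rng.standard_normal((T, d))
W_gate = rng.standard_normal((d, E))
y, chosen = moe_layer(x, W_gate, experts, k=2)
print(y.shape, chosen.shape)  # (5, 16) (5, 2)
```

Setting `k=E` recovers the dense weighted sum over all experts, which is a convenient sanity check on the gating.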

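A key challenge is load imbalance: without intervention, the router tends to collapse onto a few popular experts, leaving most experts idle. A common remedy, in the form used by the Switch Transformer, is an auxiliary load-balancing loss:

$$L_{\text{aux}} = \alpha E \sum_{i=1}^{E} f_i \cdot P_i$$

where, over a batch, $f_i$ is the fraction of tokens routed to expert $i$ and $P_i$ is the average routing probability assigned to expert $i$. The factor $E$ keeps the loss at exactly $\alpha$ under perfectly uniform routing, whatever the number of experts. Because $f_i$ comes from a discrete choice, gradients flow through $P_i$; minimising the product pushes probability away from overloaded experts and encourages uniform utilisation.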
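The loss is easy to compute from a batch of router probabilities. This sketch follows the Switch Transformer form, $\alpha E \sum_i f_i P_i$ with $f_i$ taken from each token's top-1 expert. For top-$k$ routers with $k > 1$, counting conventions vary. The function name and the toy probability tables are hypothetical.

```python
import numpy as np

def load_balancing_loss(router_probs: np.ndarray, alpha: float = 0.01) -> float:
    """Auxiliary loss alpha * E * sum_i f_i * P_i for (T, E) routing probabilities."""
    T, E = router_probs.shape
    top1 = router_probs.argmax(axis=1)
    f = np.bincount(top1, minlength=E) / T   # fraction of tokens per expert (no gradient)
    P = router_probs.mean(axis=0)            # mean probability per expert (carries gradient)
    return float(alpha * E * np.sum(f * P))

alpha, E, T = 0.01, 4, 1000
balanced = np.full((T, E), 0.1)
balanced[np.arange(T), np.arange(T) % E] = 0.7   # each expert preferred by a quarter of tokens
collapsed = np.full((T, E), 0.1)
collapsed[:, 0] = 0.7                            # every token prefers expert 0
print(load_balancing_loss(balanced, alpha), load_balancing_loss(collapsed, alpha))
```

The balanced table scores $\alpha$, while the collapsed one scores $2.8\alpha$ ($\alpha \cdot 4 \cdot 1 \cdot 0.7$), so the optimiser is pushed back towards spreading tokens out.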

Mixtral 8x7B applies top-2 MoE routing at every FFN layer, giving a model with about 47B total parameters of which only about 13B are active per token. Per-token compute is therefore close to that of a 13B dense model, while the model draws on the capacity of all 47B parameters, which still have to be held in memory.

Structured State Space Models (SSMs)

SSMs such as Mamba (Gu & Dao, 2023) abandon attention entirely. A linear time-invariant state space model maps an input signal $u(t)$ to an output $y(t)$ through a hidden state $x(t) \in \mathbb{R}^N$:

$$x'(t) = Ax(t) + Bu(t), \quad y(t) = Cx(t)$$

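Discretised with step size $\Delta$ (for example by a zero-order hold), this becomes a linear recurrence $x_t = \bar{A} x_{t-1} + \bar{B} u_t$, $y_t = C x_t$. Because $\bar{A}$, $\bar{B}$, and $C$ are the same at every step, the recurrence can be unrolled into a convolution with kernel $\bar{K}_j = C \bar{A}^j \bar{B}$ for training (enabling parallelism) and evaluated step by step at inference, with a per-step cost independent of sequence length. S4 constrains $A$ to a structured diagonal-plus-low-rank form for efficiency, and later models such as H3 use purely diagonal state matrices.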
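The equivalence of the two modes can be checked numerically. This sketch assumes a real-valued diagonal $A$ (a simplification; S4-family models typically use complex-valued parameterisations) and zero-order-hold discretisation. It uses `np.convolve` where real implementations would use an FFT. All function names are hypothetical.

```python
import numpy as np

def discretise_zoh(a, b, delta):
    """Zero-order hold for a diagonal SSM: a (N,) negative reals, b (N,), step delta.

    A_bar = exp(delta * A), B_bar = A^{-1} (exp(delta * A) - I) B, elementwise here.
    """
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def ssm_recurrent(a_bar, b_bar, c, u):
    """Inference mode: fixed-size state, O(N) work per step regardless of sequence length."""
    x = np.zeros_like(a_bar)
    ys = []
    for u_t in u:
        x = a_bar * x + b_bar * u_t
        ys.append(c @ x)
    return np.array(ys)

def ssm_convolutional(a_bar, b_bar, c, u):
    """Training mode: precompute K_j = C A_bar^j B_bar, then one causal convolution."""
    L = len(u)
    powers = a_bar[None, :] ** np.arange(L)[:, None]   # (L, N): A_bar^j for each lag j
    kernel = powers @ (c * b_bar)                       # (L,)
    return np.convolve(u, kernel)[:L]                   # keep the causal part

rng = np.random.default_rng(0)
N, L = 16, 200
a = -np.exp(rng.standard_normal(N))    # negative real parts keep the system stable
b, c = rng.standard_normal(N), rng.standard_normal(N)
a_bar, b_bar = discretise_zoh(a, b, delta=0.1)
u = rng.standard_normal(L)
print(np.allclose(ssm_recurrent(a_bar, b_bar, c, u), ssm_convolutional(a_bar, b_bar, c, u)))  # True
```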

Mamba introduces a selection mechanism: $B$, $C$, and $\Delta$ become functions of the input $u(t)$ rather than fixed parameters. This breaks time-invariance, so the convolutional view no longer applies, but it lets the model selectively retain or discard information based on content, partially recovering the expressiveness of attention. Because the recurrence is still linear in the state, it can be computed with a parallel scan, and Mamba uses a hardware-aware scan implementation to keep training efficient.
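The scan works because composing two steps of $x_t = a_t x_{t-1} + b_t$ is associative: applying $(a_1, b_1)$ then $(a_2, b_2)$ equals the single step $(a_2 a_1, a_2 b_1 + b_2)$. The sketch below is a toy single-channel selective SSM under stated assumptions: $B_t$ and $C_t$ are linear in $u_t$, $\Delta_t$ is a softplus of $u_t$, and $\bar{B}_t \approx \Delta_t B_t$ (an Euler-style simplification). It uses a Hillis–Steele scan, which does $O(L \log L)$ work and is not Mamba's hardware-aware kernel. Every weight and function name is hypothetical.

```python
import numpy as np

def parallel_linear_scan(a, b):
    """Hillis-Steele inclusive scan for x_t = a_t * x_{t-1} + b_t with x_{-1} = 0.

    a, b: (L, N). Runs ceil(log2 L) vectorised steps using the associative rule
    (a1, b1) then (a2, b2) == (a2 * a1, a2 * b1 + b2).
    """
    a, b = a.copy(), b.copy()
    step = 1
    while step < a.shape[0]:
        a_prev = np.concatenate([np.ones_like(a[:step]), a[:-step]])   # identity (1, 0) pads the front
        b_prev = np.concatenate([np.zeros_like(b[:step]), b[:-step]])
        b = a * b_prev + b          # uses the old a
        a = a * a_prev
        step *= 2
    return b                        # b_t now equals x_t

def selective_ssm(u, A, w_delta, w_B, w_C, use_scan=True):
    """Toy single-channel selective SSM: u (L,), A (N,) negative diagonal."""
    delta = np.log1p(np.exp(w_delta * u))          # softplus keeps step sizes positive, (L,)
    B = u[:, None] * w_B[None, :]                  # (L, N) input-dependent B_t
    C = u[:, None] * w_C[None, :]                  # (L, N) input-dependent C_t
    a_bar = np.exp(delta[:, None] * A[None, :])    # (L, N) per-step decay in (0, 1)
    b_bar_u = delta[:, None] * B * u[:, None]      # (L, N) Delta_t * B_t * u_t
    if use_scan:
        x = parallel_linear_scan(a_bar, b_bar_u)
    else:                                          # sequential reference recurrence
        x = np.zeros_like(a_bar)
        state = np.zeros(A.shape[0])
        for t in range(len(u)):
            state = a_bar[t] * state + b_bar_u[t]
            x[t] = state
    return np.sum(C * x, axis=1)                   # y_t = C_t . x_t

rng = np.random.default_rng(0)
L, N = 300, 8
u = rng.standard_normal(L)
A = -np.exp(rng.standard_normal(N))
w_delta, w_B, w_C = rng.standard_normal(), rng.standard_normal(N), rng.standard_normal(N)
y_scan = selective_ssm(u, A, w_delta, w_B, w_C, use_scan=True)
y_loop = selective_ssm(u, A, w_delta, w_B, w_C, use_scan=False)
print(np.allclose(y_scan, y_loop))  # True
```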

Examples

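Flash Attention speedups. On a 2k-token sequence on an A100 GPU, Flash Attention achieves roughly $3\times$ end-to-end speedup over the naive PyTorch implementation and uses 5–20$\times$ less memory. Its outputs are exact up to floating-point rounding: the summation order differs, so results match the reference closely but are not bitwise identical.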

Mixtral MoE architecture. Mixtral 8x7B uses 32 layers, each with 8 expert FFNs of hidden (intermediate) dimension 14336 and top-2 routing. On standard benchmarks it matches or exceeds LLaMA 2 70B while activating about 13B parameters per token, roughly a fifth of the 70B that the dense LLaMA 2 70B uses for every token.
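The 47B total and 13B active figures can be reproduced by counting weights. Only the layer count, expert count, FFN width and top-2 routing come from this post. The remaining shapes are assumptions taken from the publicly released Mixtral configuration: model width 4096, 32 query heads and 8 key/value heads of size 128, SwiGLU experts with three weight matrices, a 32000-token vocabulary, and untied input and output embeddings.

```python
# Shapes: layers, experts, FFN width and top-k from the post; the rest assumed from the public config.
d_model, n_layers, vocab = 4096, 32, 32_000
ffn_hidden, n_experts, top_k = 14_336, 8, 2
n_heads, n_kv_heads, head_dim = 32, 8, 128

expert = 3 * d_model * ffn_hidden                       # SwiGLU: gate, up and down projections
attn = (2 * d_model * n_heads * head_dim                # query and output projections
        + 2 * d_model * n_kv_heads * head_dim)          # key and value projections (grouped-query)
router = d_model * n_experts
norms = 2 * d_model                                     # two RMSNorm weight vectors per layer
embed = 2 * vocab * d_model                             # input embedding and output head

shared = n_layers * (attn + router + norms) + embed + d_model   # + final norm
total = shared + n_layers * n_experts * expert
active = shared + n_layers * top_k * expert
print(f"total  {total / 1e9:.1f}B")    # total  46.7B
print(f"active {active / 1e9:.1f}B")   # active 12.9B
```

Almost all of the weights sit in the experts (about 45.1B of the 46.7B), which is why top-2 routing cuts the active count so sharply while attention and embeddings are always active.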

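Mamba scaling. On long-sequence tasks (e.g., sequence lengths of 1M tokens), Mamba processes tokens in $O(n)$ time, and at inference its recurrent state has a fixed size independent of $n$. Standard attention at this length is impractical: a single $n \times n$ fp32 score matrix for one head would need $(10^6)^2 \times 4$ bytes, about 4 TB, and even with Flash Attention the time cost remains quadratic.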

