Efficient Transformers - Attention Without the Quadratic Cost
Helpful context:
- Transformers From First Principles - Why Attention Changed Everything
- The Chain Rule - How Derivatives Compose Through Layers
The attention mechanism computes $n^2$ dot products for a sequence of $n$ tokens. GPT-4 Turbo accepts up to 128,000 tokens of context. That’s $128000^2 = 16{,}384{,}000{,}000$ attention scores, about 16 billion dot products, for a single attention head in a single layer, and a model of that scale has dozens of layers with many heads each. But compute is only half the problem. The attention matrix itself, the $n \times n$ array of scores before softmax, would require $128000 \times 128000 \times 4$ bytes $\approx 65.5$ GB of memory in 32-bit floats - for one head in one layer, and just for the matrix. The GPU you’re running on, say an A100 or H100, has 80 GB total.
Standard attention does not scale to long contexts. Efficient transformers are the collection of engineering and mathematical ideas that make it possible anyway. Some are approximations that trade a small amount of accuracy for large gains in efficiency. Others - most famously FlashAttention - are exact but reorganize computation to avoid the memory bottleneck entirely.
The Quadratic Bottleneck
Let’s be precise about what “quadratic” means here.
Standard scaled dot-product attention computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $Q$, $K$, $V \in \mathbb{R}^{n \times d_k}$.
The matrix product $QK^T$ produces an $n \times n$ matrix. This costs $O(n^2 d_k)$ time and $O(n^2)$ memory. For $n = 1000$: the attention matrix has one million entries. For $n = 100,000$: ten billion entries. Memory is often the binding constraint, not compute - modern GPUs are fast enough to crunch the numbers, but can’t fit the matrix in their fast memory.
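To make the quadratic object concrete, here is a minimal NumPy sketch of standard attention that explicitly builds the $n \times n$ score matrix and reports its size, plus a helper that reproduces the 65.5 GB figure from the introduction. The function names are illustrative, and float32 storage (4 bytes per score) is an assumption.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard scaled dot-product attention that materializes the full n x n score matrix.
    Returns the output and the number of bytes the score matrix occupied."""
    scores = Q @ K.T / Q.shape[-1] ** 0.5            # (n, n): the quadratic object
    scores -= scores.max(axis=-1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V, scores.nbytes

def score_matrix_bytes(n, bytes_per_element=4):
    """Bytes for one n x n score matrix (one head, one layer)."""
    return n * n * bytes_per_element

rng = np.random.default_rng(0)
n, d_k = 1024, 64
Q, K, V = (rng.standard_normal((n, d_k)).astype(np.float32) for _ in range(3))
out, nbytes = naive_attention(Q, K, V)
print(out.shape, nbytes)                      # (1024, 64) 4194304
print(score_matrix_bytes(128_000) / 1e9)      # 65.536 (GB, 32-bit floats)
```

Doubling $n$ quadruples `nbytes`, which is the whole problem in one line.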
The bottleneck has two axes:
- Compute: reducing the number of operations (approximation methods).
- Memory I/O: reorganizing operations so the $n \times n$ matrix never has to exist (IO-aware methods).
Different techniques target different axes. FlashAttention targets memory I/O without approximating. Sparse attention and linear attention reduce compute, usually with some approximation.
Sparse Attention
The key observation: most attention weights are small. In trained transformers, the softmax over a long sequence is often dominated by a few high-weight positions; the rest contribute negligibly. So why compute them all?
Sparse attention restricts which $(i, j)$ pairs token $i$ can attend to. Instead of attending to all $n$ positions, each token attends to a structured subset of size $k \ll n$. The cost drops from $O(n^2)$ to $O(nk)$.
Local window attention: token $i$ attends to tokens $i - w, \ldots, i + w$ - a sliding window of width $2w + 1$. Cost: $O(nw)$. This captures local context well. Think of how you understand a sentence: nearby words matter most, and words 500 positions away rarely help. Window size $w = 512$ on a sequence of 100,000 tokens reduces compute by roughly 100× ($100{,}000 / 1025 \approx 98$).
Dilated (strided) attention: token $i$ attends to every $d$-th token: $\{i - d, i, i + d, i + 2d, \ldots\}$. This captures long-range periodic dependencies (useful for structured documents, code with repeated patterns).
Global tokens: some special tokens (like [CLS] in BERT, or an explicit “summary” token) attend to and are attended by every position. This provides a bottleneck through which global information can flow even when most pairs don’t attend to each other directly.
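The three patterns above can be expressed as boolean masks over $(i, j)$ pairs. Below is a minimal NumPy sketch, assuming a hypothetical `span` parameter that limits how many dilated steps a token takes in each direction. It applies the mask to dense scores for clarity; a real sparse kernel would skip the masked pairs rather than compute and discard them.

```python
import numpy as np

def local_window_mask(n, w):
    """True where token i may attend to token j, i.e. |i - j| <= w."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def dilated_mask(n, d, span):
    """True where j = i + m*d for an integer m with |m| <= span (every d-th token)."""
    idx = np.arange(n)
    offset = idx[None, :] - idx[:, None]
    return (offset % d == 0) & (np.abs(offset) <= span * d)

def add_global_tokens(mask, global_idx):
    """Global tokens attend to every position, and every position attends to them."""
    mask = mask.copy()
    mask[global_idx, :] = True
    mask[:, global_idx] = True
    return mask

def masked_attention(Q, K, V, mask):
    """Softmax attention restricted to allowed (i, j) pairs (dense, for illustration)."""
    scores = Q @ K.T / Q.shape[-1] ** 0.5
    scores = np.where(mask, scores, -np.inf)          # disallowed pairs get zero weight
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

n = 16
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((n, 8)) for _ in range(3))
mask = add_global_tokens(local_window_mask(n, w=2), global_idx=[0])
out = masked_attention(Q, K, V, mask)
print(mask.sum(axis=1))   # at most 2w + 2 keys for ordinary tokens, all n for the global token
```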
Longformer (Beltagy et al., 2020) combines local window attention with global tokens. Most tokens use a local window; a few task-specific tokens (for example [CLS] for classification, or the question tokens in question answering) use full attention. With a fixed window size and a fixed number of global tokens, this gives $O(n)$ complexity while preserving global information flow. Longformer was designed for document-level tasks.
BigBird (Zaheer et al., 2020) combines local attention, global tokens, and a random component - each token also attends to a small random subset of positions. The random component ensures that, in expectation, information can flow between any two tokens in $O(\log n)$ steps (like a random graph expander). BigBird proved that this combination is Turing-complete - it can simulate any Turing machine computation - while remaining $O(n)$.
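A rough way to see why this stays linear: every ordinary row of the mask has at most (window + global + random) entries, independent of $n$. The sketch below builds a simplified token-level BigBird-style mask; the parameter names are illustrative, and BigBird itself operates on blocks of tokens rather than individual tokens.

```python
import numpy as np

def bigbird_style_mask(n, w, num_global, num_random, seed=0):
    """Sketch of a BigBird-style pattern: sliding window + global tokens + random links.
    Token-level simplification; BigBird works on blocks of tokens."""
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w      # local window
    mask[:num_global, :] = True                          # global tokens see everything
    mask[:, :num_global] = True                          # ...and everything sees them
    for i in range(n):                                   # a few random keys per query
        mask[i, rng.choice(n, size=num_random, replace=False)] = True
    return mask

for n in (256, 512, 1024, 2048):
    m = bigbird_style_mask(n, w=3, num_global=2, num_random=3)
    # total allowed pairs grows linearly in n, so the ratio stays roughly constant
    print(n, int(m.sum()), int(m.sum()) / n)
```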
The limitation of sparse attention: if two tokens need to interact directly but aren’t in each other’s windows, they can only communicate indirectly across layers. This increases depth requirements for certain tasks. For tasks with genuinely long-range dependencies - cross-reference resolution in long documents, for example - pure sparse attention can miss interactions that full attention would catch.
Linear Attention: The Kernel Trick
Sparse attention approximates the structure of which tokens interact. Linear attention approximates the softmax function itself.
Standard attention can be written row by row (folding the $1/\sqrt{d_k}$ scaling into $Q$ to keep the notation light):
$$\text{Attention}(Q, K, V)_i = \frac{\sum_j \exp(Q_i K_j^T) V_j}{\sum_j \exp(Q_i K_j^T)}$$
The difficulty is the $\exp(Q_i K_j^T)$ terms - to compute them, you need all $(i, j)$ pairs, which is $O(n^2)$.
The linear attention trick: replace $\exp(Q_i K_j^T)$ with a kernel function $\phi(Q_i)^T \phi(K_j)$, where $\phi: \mathbb{R}^d \to \mathbb{R}^r$ is a feature map. Then:
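$$\text{Attention}(Q, K, V)_i \approx \frac{\sum_j \phi(Q_i)^T \phi(K_j) V_j}{\sum_j \phi(Q_i)^T \phi(K_j)} = \frac{\phi(Q_i)^T \left(\sum_j \phi(K_j) V_j\right)}{\phi(Q_i)^T \left(\sum_j \phi(K_j)\right)}$$
Here $\phi(\cdot)$ returns a column vector and $V_j$, the $j$-th row of $V$, is a row vector, so $\phi(K_j) V_j$ is an $r \times d$ outer product. The parenthesized sums - the $r \times d$ matrix $\sum_j \phi(K_j) V_j$ and the $r$-vector $\sum_j \phi(K_j)$ - don’t depend on $i$, so you compute them once for the whole sequence, in $O(nrd)$ time. Then for each query $Q_i$, you just multiply $\phi(Q_i)^T$ against these precomputed quantities, at $O(rd)$ per query. Total cost: $O(nrd)$ instead of $O(n^2 d)$.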
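The speedup is pure associativity, so it is easy to check numerically. This sketch computes kernelized attention twice: once by building the $n \times n$ kernel matrix, once by precomputing $\sum_j \phi(K_j) V_j$ and $\sum_j \phi(K_j)$. The feature map $\phi(x) = \mathrm{elu}(x) + 1$ is an assumption chosen only because it is simple and positive; it is elementwise, so here $r = d$.

```python
import numpy as np

def feature_map(x):
    """A simple positive feature map, elu(x) + 1, used here only for illustration."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def quadratic_kernel_attention(Q, K, V):
    """Kernelized attention computed the 'obvious' way: builds an n x n kernel matrix."""
    A = feature_map(Q) @ feature_map(K).T           # (n, n)
    return (A @ V) / A.sum(axis=1, keepdims=True)

def linear_kernel_attention(Q, K, V):
    """Same result via associativity: no n x n matrix is ever formed."""
    phi_q, phi_k = feature_map(Q), feature_map(K)   # (n, r) each
    KV = phi_k.T @ V                                # (r, d): sum_j phi(K_j) V_j
    z = phi_k.sum(axis=0)                           # (r,):   sum_j phi(K_j)
    return (phi_q @ KV) / (phi_q @ z)[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((200, 16)) for _ in range(3))
print(np.abs(quadratic_kernel_attention(Q, K, V) - linear_kernel_attention(Q, K, V)).max())
```

The printed difference is at floating-point rounding level: the two functions compute the same quantity in a different order.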
If $r \ll n$, this is a dramatic speedup. The catch: the approximation quality depends on how well $\phi(Q_i)^T \phi(K_j)$ approximates $\exp(Q_i K_j^T)$. Performer (Choromanski et al., 2020) uses positive orthogonal random features, which give an unbiased estimate of each kernel value, but the estimate’s variance can be large for very concentrated attention distributions (where a few weights are dominant and the rest are near zero).
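A minimal sketch of the random-feature idea, assuming i.i.d. Gaussian projections (Performer additionally orthogonalizes them to reduce variance) and small-norm inputs, since the estimator’s variance grows with $\|Q_i\|$ and $\|K_j\|$. The feature map $\phi(x) = \exp(Wx - \|x\|^2/2)/\sqrt{m}$ satisfies $\mathbb{E}[\phi(x)^T \phi(y)] = \exp(x^T y)$.

```python
import numpy as np

def positive_random_features(X, W):
    """phi(x) = exp(W x - |x|^2 / 2) / sqrt(m), rows of W drawn i.i.d. from N(0, I).
    E[phi(x) . phi(y)] = exp(x . y): an unbiased estimate of the softmax kernel."""
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X**2, axis=-1, keepdims=True)) / np.sqrt(m)

def softmax_attention(Q, K, V):
    """Exact reference: row-wise softmax of Q K^T (1/sqrt(d) folded into Q)."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def performer_style_attention(Q, K, V, num_features, seed=0):
    """Linear-time approximation: random features stand in for exp(Q_i . K_j)."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((num_features, Q.shape[-1]))
    phi_q, phi_k = positive_random_features(Q, W), positive_random_features(K, W)
    return (phi_q @ (phi_k.T @ V)) / (phi_q @ phi_k.sum(axis=0))[:, None]

rng = np.random.default_rng(1)
n, d = 128, 8
Q = 0.3 * rng.standard_normal((n, d))    # small norms keep the estimator's variance modest
K = 0.3 * rng.standard_normal((n, d))
V = rng.standard_normal((n, d))
exact = softmax_attention(Q, K, V)
errors = {m: np.abs(performer_style_attention(Q, K, V, m) - exact).max()
          for m in (16, 256, 4096)}
print(errors)   # max abs error shrinks roughly like 1 / sqrt(num_features)
```

Scaling `Q` and `K` up makes attention sharper and the error much larger for the same feature count, which is the variance problem described above.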
Mamba (Gu & Dao, 2023) approaches this differently using state-space models - a recurrent formulation that computes attention-like operations in $O(n)$ time and $O(1)$ memory per step, without explicitly computing an attention matrix at all. Mamba has shown strong results on long-sequence tasks and is one of the more promising alternatives to standard transformers.
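Mamba’s actual layer is a selective state-space model: its step size and $B$, $C$ projections depend on the input, and it uses a hardware-aware parallel scan. The toy sketch below strips all of that away to show only the recurrent core, a fixed (non-selective) diagonal linear SSM $h_t = a \odot h_{t-1} + B x_t$, $y_t = C^T h_t$, whose state has a fixed size regardless of sequence length. The second function unrolls the same model into an explicit sum over the past, only to check the recurrence.

```python
import numpy as np

def ssm_recurrent(x, a, B, C):
    """Toy diagonal linear state-space model run as a recurrence.
    h_t = a * h_{t-1} + B * x_t,  y_t = C . h_t
    The state h has fixed size N: O(1) memory per step in the sequence length."""
    h = np.zeros_like(a)
    ys = []
    for x_t in x:                  # one pass over the sequence: O(n) time
        h = a * h + B * x_t
        ys.append(C @ h)
    return np.array(ys)

def ssm_convolution(x, a, B, C):
    """Same model unrolled: y_t = sum_{s<=t} (C . (a**(t-s) * B)) x_s. Quadratic; for checking."""
    n = len(x)
    kernel = np.array([C @ (a**k * B) for k in range(n)])
    return np.array([sum(kernel[t - s] * x[s] for s in range(t + 1)) for t in range(n)])

rng = np.random.default_rng(0)
x = rng.standard_normal(40)                 # a single input channel
a = rng.uniform(0.5, 0.95, size=4)          # per-state decay, |a| < 1 for stability
B, C = rng.standard_normal(4), rng.standard_normal(4)
print(np.abs(ssm_recurrent(x, a, B, C) - ssm_convolution(x, a, B, C)).max())
```

Making `a`, `B`, `C` depend on `x_t` (the "selective" part) breaks the convolution view but keeps the constant-size recurrence.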
Discomfort check. Linear attention methods have better theoretical complexity than standard attention, but theoretical complexity and wall-clock speed don’t always match on real hardware. Modern GPUs are highly optimized for the specific matrix operations in standard attention, while linear attention’s sequential structure can be harder to parallelize. The practical speedup of linear attention over FlashAttention (which is exact and highly optimized) is often smaller than the asymptotic analysis suggests. This is a recurring theme in efficient ML: the “efficient” algorithm on paper may not be the faster one in practice.
FlashAttention: The Right Bottleneck
The most practically important advance in transformer efficiency is FlashAttention (Dao et al., 2022), and it works by solving a different problem than you might expect.
The insight: the bottleneck is not arithmetic throughput - it’s memory bandwidth between HBM (high bandwidth memory, main GPU RAM) and SRAM (on-chip cache, fast but small).
Modern GPUs have a memory hierarchy:
- SRAM (on-chip shared memory / L1 cache): 192 KB per streaming multiprocessor on an A100, about 20 MB across all 108 SMs. Extremely fast: ~19 TB/s estimated bandwidth.
- HBM (GPU DRAM): 40-80 GB on an A100. Slower: ~1.5-2 TB/s bandwidth.
Standard attention’s memory access pattern:
- Load $Q$, $K$ from HBM to SRAM. Compute $QK^T / \sqrt{d_k}$. Write $n \times n$ result back to HBM.
- Read $n \times n$ matrix from HBM. Apply softmax. Write result back to HBM.
- Read result and $V$ from HBM. Multiply. Write output to HBM.
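That’s three passes through HBM, and the $n \times n$ matrix is written twice and read twice. At $n = 16{,}000$ in 16-bit precision, one $n \times n$ matrix is 512 MB, so that’s about 2 GB of HBM traffic for a single head’s attention matrix - repeated for every head in every layer.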
FlashAttention’s solution: tile the computation so that $Q$, $K$, $V$ blocks always fit in SRAM, and never materialize the full $n \times n$ matrix in HBM.
The key algorithmic challenge is softmax: the softmax denominator for row $i$ is $\sum_j \exp(Q_i K_j^T)$, which requires seeing all $j$ to normalize. FlashAttention uses an online softmax algorithm - it maintains a running maximum and running sum as it processes blocks, updating both incrementally. When processing block $b$ of keys, it adjusts the previously computed partial softmax using the new maximum.
The math, for one query row $i$: suppose you’ve processed key blocks $1, \ldots, b-1$ (write $j < b$ for the keys in those blocks and $j \in b$ for the keys in block $b$) and computed:
- Running max $m_{b-1} = \max_{j < b} Q_i K_j^T$
- Running sum $\ell_{b-1} = \sum_{j < b} \exp(Q_i K_j^T - m_{b-1})$
When you process block $b$, the new max is $m_b = \max\left(m_{b-1}, \max_{j \in b} Q_i K_j^T\right)$, and:
$$\ell_b = e^{m_{b-1} - m_b} \ell_{b-1} + \sum_{j \in b} \exp(Q_i K_j^T - m_b)$$
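This recurrence computes the exact softmax denominator in a single pass, without ever storing the full $n \times n$ matrix. The unnormalized output row is rescaled by the same factor:
$$o_b = e^{m_{b-1} - m_b} o_{b-1} + \sum_{j \in b} \exp(Q_i K_j^T - m_b) V_j$$
and divided by $\ell_b$ once, after the last block.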
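The online-softmax recurrence can be sketched in NumPy. This is a simplified model of the math, not of the kernel: there is no real SRAM/HBM split here, all query rows are processed together, and only the keys and values are tiled, whereas FlashAttention also tiles queries and runs each tile in on-chip memory. The `block_size` parameter is illustrative.

```python
import numpy as np

def tiled_attention(Q, K, V, block_size=64):
    """FlashAttention-style forward pass: K/V are processed in blocks with an online
    softmax, so only an (n, block_size) slice of scores exists at any time."""
    n, d = Q.shape
    scale = 1.0 / d ** 0.5
    m = np.full(n, -np.inf)              # running max per query row
    l = np.zeros(n)                      # running softmax denominator per row
    acc = np.zeros((n, V.shape[1]))      # running unnormalized output per row
    for start in range(0, K.shape[0], block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        S = (Q @ Kb.T) * scale                   # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))     # updated running max
        correction = np.exp(m - m_new)           # rescales the earlier partial sums
        P = np.exp(S - m_new[:, None])
        l = correction * l + P.sum(axis=1)
        acc = correction[:, None] * acc + P @ Vb
        m = m_new
    return acc / l[:, None]                      # normalize once at the end

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((300, 32)) for _ in range(3))
S = Q @ K.T / 32 ** 0.5                          # reference: the full n x n matrix
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.abs(tiled_attention(Q, K, V) - reference).max())
```

On the first block `m` is `-inf`, so `correction` is 0 and the empty accumulators are simply overwritten; no special case is needed.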
The result: the same mathematical output as standard attention, typically reported as 2-4× faster, with $O(n)$ extra memory instead of $O(n^2)$. No approximation.
FlashAttention-2 (Dao, 2023) improved parallelism by also splitting work across the sequence length (not just the batch and head dimensions), achieving roughly a further 2× speedup. FlashAttention-3 adds optimizations for the H100’s asynchronous execution.
FlashAttention-style kernels are now the default path through most of the stack: PyTorch’s torch.nn.functional.scaled_dot_product_attention dispatches to a FlashAttention backend when the inputs and GPU support it, HuggingFace Transformers and vLLM support it, and so do most production LLM serving systems.
Mixture of Experts (MoE)
A different axis of efficiency: instead of making attention cheaper, make the feedforward network sparse.
In a standard transformer, every layer has one feedforward network (FFN), with weight matrices of shape $d_{\text{model}} \times d_{\text{ffn}}$ and $d_{\text{ffn}} \times d_{\text{model}}$, applied to every token. The FFN is often the largest component by parameter count.
Mixture of Experts replaces the single FFN with $E$ “expert” FFNs and a router. For each token, the router selects $k$ of the $E$ experts (typically $k = 2$ out of $E = 8$ or $E = 64$). Only the selected experts process that token. The outputs are summed, weighted by the router’s scores.
The result:
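- Active parameters per token: $k/E$ of the total FFN parameters. With $k=2, E=8$: 25% of FFN parameters are active per token.
- Total FFN parameter count: $E$ times that of a dense model with a single FFN of the same size (attention parameters are shared, so the whole model grows by less).
- Compute per token: roughly that of a dense model whose FFN is $k$ times as large, plus the small cost of the router.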
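A minimal top-$k$ MoE layer in NumPy makes the "only $k$ experts run" point explicit. Assumptions, all illustrative: experts are small ReLU FFNs, the router is a single linear map `W_router`, and gate weights are a softmax over the $k$ selected logits (the choice Mixtral makes; other designs weight by the full softmax probabilities). The per-token loop is for readability; real implementations batch tokens per expert.

```python
import numpy as np

def moe_layer(x, W_router, experts, k=2):
    """Top-k mixture-of-experts FFN (a sketch). x: (n, d); W_router: (d, E);
    experts: list of E (W_in, W_out) pairs, each a ReLU FFN."""
    logits = x @ W_router                                  # (n, E) router scores
    topk = np.argsort(logits, axis=1)[:, -k:]              # k best experts per token
    out = np.zeros_like(x)
    calls = 0
    for t in range(x.shape[0]):
        sel = logits[t, topk[t]]
        gates = np.exp(sel - sel.max())
        gates /= gates.sum()                               # softmax over the selected k only
        for g, e in zip(gates, topk[t]):
            W_in, W_out = experts[e]
            out[t] += g * (np.maximum(x[t] @ W_in, 0.0) @ W_out)   # only k experts run
            calls += 1
    return out, calls

rng = np.random.default_rng(0)
n, d, d_ff, E = 10, 16, 32, 8
x = rng.standard_normal((n, d))
W_router = rng.standard_normal((d, E))
experts = [(rng.standard_normal((d, d_ff)) / 4, rng.standard_normal((d_ff, d)) / 4)
           for _ in range(E)]
out, calls = moe_layer(x, W_router, experts, k=2)
print(out.shape, calls)   # (10, 16) 20: each token ran exactly k = 2 of the 8 experts
```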
“Parameters don’t equal compute.” You can have a 400B-parameter model that costs as much to run per token as a 50B dense model.
GPT-4 is reportedly a mixture of experts (though the architecture hasn’t been officially disclosed). Mixtral 8×7B (Mistral AI, 2023) is a publicly available MoE with 8 experts per layer. Only the FFNs are replicated (attention is shared), so the model has about 47B total parameters rather than 8 × 7B = 56B, and each token activates 2 experts (about 13B active parameters). Mistral reports that it matches or exceeds LLaMA-2 70B on most benchmarks while being significantly faster at inference. Gemini 1.5 Pro is also described by Google as a sparse MoE model, and it supports contexts of up to a million tokens.
The challenges with MoE:
- Load balancing: without explicit regularization, the router tends to send all tokens to a few popular experts, leaving others unused. This is fixed with an auxiliary load-balancing loss during training.
- Expert specialization: do experts learn different things? Evidence suggests they partially specialize by topic, syntax, and language - but the specialization is soft, not crisp.
- Communication overhead: in distributed training, expert outputs from different GPUs must be aggregated. This “all-to-all” communication becomes a bottleneck.
Load balancing is the hardest part. A naive MoE collapses: the router learns to always send tokens to the same 1-2 experts (they get more gradient signal, get better, attract more tokens - a rich-get-richer collapse). Every serious MoE implementation adds a load balancing auxiliary loss:
$$\mathcal{L}_{\text{aux}} = \alpha \sum_{i=1}^{E} f_i \cdot P_i$$
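where $f_i$ is the fraction of tokens routed to expert $i$, $P_i$ is the average routing probability for expert $i$, and $\alpha$ controls the strength. Only $P_i$ is differentiable, so the gradient flows through the router probabilities. In perfect balance, $f_i = P_i = 1/E$ for all $i$, so $\mathcal{L}_{\text{aux}} = \alpha/E$. Because tokens go where the router assigns high probability, $f_i$ and $P_i$ tend to move together, and the sum then behaves like $\sum_i f_i^2$, which is minimized by uniform routing; concentrating tokens on a few experts drives the loss up, toward $\alpha$ when everything goes to one expert. The coefficient $\alpha$ must be tuned carefully - too small and experts collapse, too large and the router becomes uniform (ignoring token content entirely).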
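A direct transcription of $\mathcal{L}_{\text{aux}} = \alpha \sum_i f_i P_i$ in NumPy, assuming top-1 routing when counting $f_i$ (with top-$k$ routing, each of a token’s $k$ selections would be counted). The $\alpha = 0.01$ default is illustrative. The usage compares a perfectly balanced router with a fully collapsed one.

```python
import numpy as np

def load_balancing_loss(router_probs, expert_index, alpha=0.01):
    """L_aux = alpha * sum_i f_i * P_i.
    router_probs: (n, E) softmax outputs; expert_index: (n,) top-1 expert per token."""
    n, E = router_probs.shape
    f = np.bincount(expert_index, minlength=E) / n     # fraction of tokens sent to each expert
    P = router_probs.mean(axis=0)                      # average routing probability per expert
    return alpha * np.sum(f * P)

E, n = 4, 8
balanced = load_balancing_loss(np.full((n, E), 1 / E), np.arange(n) % E)
collapsed = load_balancing_loss(np.eye(E)[np.zeros(n, dtype=int)], np.zeros(n, dtype=int))
print(balanced, collapsed)   # alpha / E versus alpha
```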
Shared experts (used in DeepSeek-V2/V3) are always-active experts that absorb common patterns, freeing the routed experts to specialize in rarer ones. Only a few are needed: DeepSeek-V2 uses two per layer and DeepSeek-V3 uses one. Expert granularity - more experts with smaller dimensions vs fewer with larger - trades off parameter efficiency against specialization. Recent models use high sparsity: DeepSeek-V3 activates 8 of 256 routed experts per token, and Kimi K2 activates 8 of 384. Loss-free load balancing (the DeepSeek-V3 approach) replaces the loss term with a per-expert bias on the routing scores that is used only for top-$k$ selection and is nudged up for underloaded experts and down for overloaded ones during training, avoiding interference with the primary training objective.
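A minimal simulation of bias-based balancing, assuming a fixed step size `gamma` and a hypothetical skewed router whose scores favor low-index experts. The bias only changes which experts are selected; in DeepSeek-V3 the gate weights are still computed from the unbiased scores, which this sketch does not model.

```python
import numpy as np

def route_with_bias(scores, bias, k):
    """Pick the top-k experts using scores + bias (the bias affects selection only)."""
    return np.argsort(scores + bias, axis=1)[:, -k:]

def update_bias(bias, chosen, num_experts, gamma=0.01):
    """Nudge underloaded experts up and overloaded experts down by a fixed step."""
    load = np.bincount(chosen.ravel(), minlength=num_experts)
    return bias + gamma * np.sign(load.mean() - load)

rng = np.random.default_rng(0)
n, E, k = 2048, 8, 2
skew = np.linspace(1.0, 0.0, E)          # hypothetical router that favors low-index experts
bias = np.zeros(E)
for step in range(300):                  # one bias update per simulated training step
    scores = rng.standard_normal((n, E)) + skew
    chosen = route_with_bias(scores, bias, k)
    bias = update_bias(bias, chosen, E)

test_scores = np.random.default_rng(1).standard_normal((n, E)) + skew
before = np.bincount(route_with_bias(test_scores, np.zeros(E), k).ravel(), minlength=E)
after = np.bincount(route_with_bias(test_scores, bias, k).ravel(), minlength=E)
print(before, after)   # per-expert load without and with the learned offsets
```

The bias ends up roughly cancelling the skew, with no gradient flowing from a balancing term into the model.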
Grouped Query Attention and Multi-Query Attention
During inference, the KV cache is a major memory consumer. For every token generated, the model must store the key and value vectors for all previous tokens in all layers - because attention is computed against the full history.
For a model with $H$ attention heads, $d$ head dimension, $L$ layers, and context length $n$:
$$\text{KV cache size} = 2 \times H \times d \times L \times n \times \text{bytes per element}$$
For a 70B-parameter model with standard multi-head attention, $H = 64$ heads, $d = 128$, $L = 80$ layers, at $n = 4096$ tokens and 16-bit precision: $2 \times 64 \times 128 \times 80 \times 4096 \times 2 \approx 10.7$ GB. That is the KV cache for a single sequence, and it grows linearly with batch size. The model weights alone (~140 GB at 16-bit) already exceed the capacity of a single 80 GB GPU.
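The KV cache formula as a small calculator, with an added `batch_size` factor (an assumption beyond the formula, since each sequence in a batch needs its own cache). The second configuration uses 8 KV heads, the LLaMA-2 70B GQA setting discussed below.

```python
def kv_cache_bytes(num_kv_heads, head_dim, num_layers, seq_len,
                   bytes_per_element=2, batch_size=1):
    """2 (K and V) x KV heads x head dim x layers x tokens x bytes, times batch size."""
    return 2 * num_kv_heads * head_dim * num_layers * seq_len * bytes_per_element * batch_size

mha = kv_cache_bytes(num_kv_heads=64, head_dim=128, num_layers=80, seq_len=4096)
gqa = kv_cache_bytes(num_kv_heads=8, head_dim=128, num_layers=80, seq_len=4096)
print(mha / 1e9, gqa / 1e9)   # about 10.7 GB vs about 1.3 GB per sequence
```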
Multi-Query Attention (MQA) (Shazeer, 2019): use one shared $K$ head and one shared $V$ head across all query heads. Queries retain their per-head projections. This reduces the KV cache by $H \times$ - for 64 heads, a 64× reduction. The tradeoff: slightly lower model quality, because the keys and values can’t specialize per head.
Grouped Query Attention (GQA) (Ainslie et al., 2023): a middle ground. Split the $H$ query heads into $G$ groups ($G$ divides $H$). Each group shares one $K$ head and one $V$ head, so there are $G$ KV heads instead of $H$. KV cache reduction: $H/G \times$ instead of $H \times$, but with less quality degradation than MQA.
LLaMA-2 70B uses GQA with $G = 8$ (8 KV heads for 64 query heads). This gives an 8× reduction in KV cache compared to standard MHA, enabling much longer contexts at the same memory budget.
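A minimal GQA forward pass in NumPy, assuming per-head tensors are already projected (shapes `(H, n, d)` for queries and `(G, n, d)` for keys and values). Query head `h` reads KV head `h // (H // G)`, so $G = H$ recovers standard MHA and $G = 1$ is MQA. Only `K` and `V` would be cached.

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (H, n, d) query heads; K, V: (G, n, d) shared KV heads, with G dividing H."""
    H, G = Q.shape[0], K.shape[0]
    assert H % G == 0, "number of KV heads must divide number of query heads"
    group_size = H // G
    out = np.empty_like(Q)
    for h in range(H):
        g = h // group_size                       # the shared KV head this query head reads
        S = Q[h] @ K[g].T / Q.shape[-1] ** 0.5
        P = np.exp(S - S.max(axis=-1, keepdims=True))
        out[h] = (P / P.sum(axis=-1, keepdims=True)) @ V[g]
    return out

rng = np.random.default_rng(0)
H, G, n, d = 8, 2, 5, 4
Q = rng.standard_normal((H, n, d))
K, V = rng.standard_normal((G, n, d)), rng.standard_normal((G, n, d))
out = grouped_query_attention(Q, K, V)
print(out.shape, K.nbytes + V.nbytes)   # cache holds G = 2 KV heads instead of H = 8
```

Repeating each KV head $H/G$ times and running ordinary MHA gives identical outputs at $H/G$ times the cache size, which is exactly the memory GQA avoids storing.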
Multi-Head Latent Attention
Grouped Query Attention reduces the KV cache by sharing key-value heads across query groups. Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, takes a different approach: compress the KV cache into a low-dimensional latent vector, then project it back to full keys and values at inference time.
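Instead of caching keys and values for all $n_h$ heads of dimension $d_h$ per token (as in MHA), MLA caches a single compressed latent $c_t^{KV} \in \mathbb{R}^{d_c}$, computed from the token’s hidden state $h_t$ by a learned down-projection $W^{DKV}$, where $d_c \ll n_h \cdot d_h$: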
$$c_t^{KV} = W^{DKV} h_t$$
At attention time, keys and values are recovered via learned up-projections:
$$K_t = W^{UK} c_t^{KV}, \qquad V_t = W^{UV} c_t^{KV}$$
The KV cache stores only $c_t^{KV}$ (dimension $d_c$ per token), not the full $2 n_h d_h$ key and value entries. The compression ratio is set by $d_c$ relative to $2 n_h d_h$; DeepSeek-V2 reports a per-token cache comparable to GQA with only a few groups, with quality matching or exceeding MHA.
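A complication: RoPE applies position-dependent rotations to queries and keys. Without RoPE, the key up-projection $W^{UK}$ can be folded into the query projection, so attention scores are computed directly against the cached latents and full keys never need to be materialized. A position-dependent rotation sitting between $W^{UK}$ and the query breaks that folding, and keys for the whole prefix would have to be recomputed - defeating the compression. MLA handles this by splitting queries and keys into two parts: a position-independent component that goes through the latent (no RoPE) and a small position-dependent component that carries RoPE. The RoPE part of the key is computed directly from $h_t$, shared across heads, and cached alongside $c_t^{KV}$. Queries are also compressed through their own low-rank latent $c_t^Q$, from which both query parts are produced.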
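A NumPy sketch of the latent cache and the absorption trick, ignoring RoPE and query compression entirely. All sizes are hypothetical, and the weight names only mirror the notation $W^{DKV}$, $W^{UK}$, $W^{UV}$. It scores one head’s query two ways: by reconstructing keys from the latents, and by folding that head’s slice of $W^{UK}$ into the query so the cached latents are scored directly. The value side can be folded into the output projection in the same way.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_h, d_h, d_c, n = 256, 8, 32, 64, 10   # hypothetical sizes; d_c << n_h * d_h

W_DKV = rng.standard_normal((d_c, d_model)) / np.sqrt(d_model)   # down-projection
W_UK = rng.standard_normal((n_h * d_h, d_c)) / np.sqrt(d_c)      # latent -> keys
W_UV = rng.standard_normal((n_h * d_h, d_c)) / np.sqrt(d_c)      # latent -> values

hidden = rng.standard_normal((n, d_model))   # hidden states h_t for n cached tokens
C_KV = hidden @ W_DKV.T                      # (n, d_c): the only thing the cache stores

# Option 1: reconstruct per-head keys from the latents, then score head 0's query.
K = (C_KV @ W_UK.T).reshape(n, n_h, d_h)     # (n, n_h, d_h)
q0 = rng.standard_normal(d_h)                # query vector for head 0
scores_full = K[:, 0, :] @ q0

# Option 2 (absorption): fold head 0's slice of W_UK into the query instead.
W_UK_head0 = W_UK.reshape(n_h, d_h, d_c)[0]  # (d_h, d_c)
q0_latent = W_UK_head0.T @ q0                # (d_c,)
scores_absorbed = C_KV @ q0_latent           # no keys materialized

print(np.abs(scores_full - scores_absorbed).max())
print(C_KV.size / (2 * n * n_h * d_h))       # cache size relative to MHA's K and V: 0.125
```

Inserting a position-dependent rotation on the key side between `W_UK_head0` and `C_KV` is what breaks Option 2, which is why MLA keeps RoPE in a separate, small component.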
MHA vs GQA vs MQA vs MLA - when to use which:
- MHA: highest quality, largest KV cache, use when memory is not a constraint
- GQA (typically 4-8 query heads per KV head): 4-8x KV cache reduction, minimal quality loss, current default for most models
- MQA: maximum compression, but quality can degrade noticeably in practice
- MLA: comparable KV cache to GQA, MHA-level quality, but complex implementation (RoPE decomposition, custom kernels required)
Summary
| Method | Complexity | Memory | Exact? | Key idea |
|---|---|---|---|---|
| Standard attention | $O(n^2 d)$ | $O(n^2)$ | Yes | Baseline |
| Local window attention | $O(nwd)$ | $O(nw)$ | Approx | Restrict to $2w+1$ neighbors |
| Longformer | $O(nwd)$ | $O(nw)$ | Approx | Window + global tokens |
| Linear attention / Performer | $O(nrd)$ | $O(nr)$ | Approx | Kernel feature map |
| FlashAttention | $O(n^2 d)$ | $O(n)$ | Exact | Tiling to avoid HBM |
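| MoE FFN | $O(n k \, d_{\text{model}} d_{\text{ffn}})$ for the FFN | All $E$ experts’ weights | N/A (different architecture) | Sparse FFN activation |
| GQA / MQA | $O(n^2 d)$ | $O(nGd)$ KV cache vs. $O(nHd)$ | N/A (different architecture; small quality cost) | Shared KV heads |
| MLA | $O(n^2 d)$ | $O(n d_c)$ KV cache | N/A (different architecture) | Low-rank KV latent compression |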
The most impactful result in this list is FlashAttention: it achieves $O(n)$ memory without any approximation by changing the order of computation, not the computation itself. If you’re working with transformers in 2024, FlashAttention is not a research curiosity - it’s the default implementation.