Autoregressive decoding in a transformer requires a forward pass for each generated token. Without any caching, this means recomputing the key and value vectors for every previous token at every step - an $O(n^2)$ total cost for a sequence of length $n$. KV caching eliminates this redundancy, and the systems built around it define how modern LLM serving works.

The Autoregressive Bottleneck

At decoding step $t$, the transformer attends over the full prefix $y_1, \ldots, y_{t-1}$. For each attention head with query weight $W_Q$, key weight $W_K$, and value weight $W_V$, the keys and values for position $s$ are:

$$k_s = W_K h_s, \quad v_s = W_V h_s$$

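where $h_s$ is the hidden state at position $s$. In naive decoding, these are recomputed at every step $t > s$. Each projection is a matrix-vector product costing $O(d^2)$ FLOPs per layer (summed over heads). So for a sequence of final length $n$ with $L$ layers and hidden dimension $d$, recomputing keys and values alone costs $O(n^2 L d^2)$ FLOPs in total, which is quadratic in sequence length.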
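
To make the quadratic growth concrete, the sketch below counts per-layer key/value projections with and without a cache. It assumes generation starts from an empty prefix and adds one token per step. The helper name is hypothetical.

```python
def kv_projections_per_layer(n: int, cached: bool) -> int:
    """Count token-level K/V projections per layer while generating a
    sequence of final length n (hypothetical helper).

    Each counted item is one key and one value projection for one token.
    """
    if cached:
        # With a KV cache, each token is projected once and then reused.
        return n
    # Without a cache, step t re-projects all t tokens of the current prefix,
    # so the total is 1 + 2 + ... + n = n(n + 1) / 2.
    return n * (n + 1) // 2


for n in (512, 4096):
    naive = kv_projections_per_layer(n, cached=False)
    cached = kv_projections_per_layer(n, cached=True)
    print(f"n={n}: naive={naive:,} cached={cached:,}")
```

The ratio between the two counts is $(n+1)/2$, so the waste grows linearly with the final length.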

KV Cache

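The KV cache stores $K_{\le t} = [k_1, \ldots, k_t]$ and $V_{\le t} = [v_1, \ldots, v_t]$ for each layer and key-value head. At step $t+1$, only the new token’s hidden state needs to be computed. Its keys and values are appended to the cache, and attention is computed against the full cached $K$ and $V$. Because each token is now projected exactly once, the per-step cost falls to $O(L d^2)$ for the new token’s projections plus $O(t L d)$ for attention over the cache, or $O(n L d^2 + n^2 L d)$ in total. The remaining quadratic term comes from attention itself. In practice, each step is bound by memory bandwidth rather than computation, since the weights and the entire cache must be read to produce a single token.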
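
A minimal single-head sketch of the mechanism follows. It assumes one attention layer whose input hidden states are given (in a real model they come from the previous layer), random weights, and no batching. `KVCache`, `decode_step_cached`, and `decode_step_naive` are hypothetical names. The sketch checks that cached decoding returns the same attention output as recomputing every key and value.

```python
import math
import random


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z
            for j in range(len(values[0]))]


class KVCache:
    """Cache of past keys and values for one layer and one head."""

    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)


def decode_step_cached(h_new, W_Q, W_K, W_V, cache):
    # Project only the newest hidden state, append its k/v, attend over the cache.
    cache.append(matvec(W_K, h_new), matvec(W_V, h_new))
    return attend(matvec(W_Q, h_new), cache.keys, cache.values)


def decode_step_naive(prefix, W_Q, W_K, W_V):
    # Recompute keys and values for every position in the prefix.
    keys = [matvec(W_K, h) for h in prefix]
    values = [matvec(W_V, h) for h in prefix]
    return attend(matvec(W_Q, prefix[-1]), keys, values)


rng = random.Random(0)
d = 8
W_Q, W_K, W_V = ([[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)]
                 for _ in range(3))
hidden = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(10)]
cache = KVCache()
for t in range(1, len(hidden) + 1):
    fast = decode_step_cached(hidden[t - 1], W_Q, W_K, W_V, cache)
    slow = decode_step_naive(hidden[:t], W_Q, W_K, W_V)
    assert max(abs(a - b) for a, b in zip(fast, slow)) < 1e-9
print("cached and naive decoding agree at every step")
```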

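For standard multi-head attention, where the keys (and likewise the values) of each layer span $d_{\text{model}}$ dimensions per token, the memory footprint of the KV cache is:

$$\text{Memory} = 2 \times n_{\text{layers}} \times d_{\text{model}} \times \text{seq\_len} \times \text{batch\_size} \times \text{bytes\_per\_element}$$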

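Consider a model with LLaMA 2 70B’s dimensions ($n_{\text{layers}} = 80$, $d_{\text{model}} = 8192$, bf16) but standard multi-head attention. A single sequence of length 4096 requires $2 \times 80 \times 8192 \times 4096 \times 2 \approx 10.7$ GB. With a batch of 16, this is about 172 GB, comparable to the roughly 140 GB of bf16 model weights. (LLaMA 2 70B itself uses grouped query attention with 8 key-value heads, which shrinks these figures by a factor of 8, as described below.) KV cache memory is often the binding constraint on serving throughput.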
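
The formula is easy to check in code. This sketch generalises $d_{\text{model}}$ to $n_{\text{kv\_heads}} \times d_{\text{head}}$ so that it also covers the grouped-query variant described below. The LLaMA 2 70B shapes used here (64 query heads of dimension 128, 8 KV heads) come from its published configuration. The function name is hypothetical.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch_size,
                   bytes_per_element=2):
    """KV cache size in bytes; the leading 2 counts keys and values.

    For standard MHA, n_kv_heads * d_head == d_model. bf16 uses 2 bytes.
    """
    return (2 * n_layers * n_kv_heads * d_head * seq_len * batch_size
            * bytes_per_element)


GB = 1e9
# LLaMA 2 70B shapes: 80 layers, 64 query heads of size 128 (d_model = 8192).
mha_one = kv_cache_bytes(80, 64, 128, 4096, batch_size=1)
mha_batch = kv_cache_bytes(80, 64, 128, 4096, batch_size=16)
# Same shapes with grouped query attention using 8 KV heads.
gqa_batch = kv_cache_bytes(80, 8, 128, 4096, batch_size=16)

print(f"MHA, 1 sequence:  {mha_one / GB:.1f} GB")
print(f"MHA, batch of 16: {mha_batch / GB:.1f} GB")
print(f"GQA (8 KV heads), batch of 16: {gqa_batch / GB:.1f} GB")
```

This prints about 10.7 GB, 171.8 GB, and 21.5 GB respectively.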

Multi-Query Attention (MQA)

Multi-Query Attention (Shazeer, 2019) uses a single set of key and value heads shared across all query heads. If there are $h$ query heads each of dimension $d_k$, MQA uses one key head and one value head of the same dimension, and all queries attend to the same $K$ and $V$.

This reduces KV cache memory by a factor of $h$ (from $h$ key-value head pairs per layer to 1), at the cost of some expressiveness. Shazeer reported that MQA models decode much faster with only minor quality degradation relative to the MHA baseline. The GQA paper notes that MQA can still degrade quality. It also found that converting a trained MHA checkpoint to MQA by mean-pooling its key and value heads needs a short additional "uptraining" phase, about 5% of the original pre-training compute, to recover quality.
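
The head-averaging conversion mentioned above can be sketched in a few lines. The sketch assumes each head's key (or value) projection is stored as its own $d_k \times d$ matrix and that adjacent heads are pooled together. `mean_pool_kv_heads` is a hypothetical helper, and the uptraining that should follow the conversion is not shown. With `n_groups=1` it produces MQA; larger values give the grouped variant discussed next.

```python
def mean_pool_kv_heads(head_weights, n_groups):
    """Average per-head K (or V) projection matrices within each group.

    head_weights: list of h matrices, each a list of d_k rows of length d.
    Returns n_groups matrices; n_groups=1 gives a single MQA head.
    """
    h = len(head_weights)
    if h % n_groups != 0:
        raise ValueError("number of heads must be divisible by n_groups")
    size = h // n_groups
    rows, cols = len(head_weights[0]), len(head_weights[0][0])
    pooled = []
    for g in range(n_groups):
        group = head_weights[g * size:(g + 1) * size]
        pooled.append([[sum(W[r][c] for W in group) / size for c in range(cols)]
                       for r in range(rows)])
    return pooled


# Four toy heads with 1 x 2 projection matrices.
heads = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]], [[7.0, 8.0]]]
print(mean_pool_kv_heads(heads, n_groups=1))  # [[[4.0, 5.0]]]
print(mean_pool_kv_heads(heads, n_groups=2))  # [[[2.0, 3.0]], [[6.0, 7.0]]]
```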

Grouped Query Attention (GQA)

Grouped Query Attention (Ainslie et al., 2023) interpolates between standard Multi-Head Attention (MHA) and MQA. The $h$ query heads are divided into $g$ groups, each with its own pair of key and value heads. Group $i$ handles query heads $(i-1)(h/g)+1$ through $i(h/g)$.

KV cache memory relative to MHA is $g/h$, a reduction by a factor of $h/g$ (since there are $g$ KV heads instead of $h$). GQA with $g = h$ recovers MHA; $g = 1$ recovers MQA. A common choice is 8 KV heads, which gives near-MQA memory efficiency with near-MHA quality. For example, LLaMA 2 70B uses 8 KV heads for its 64 query heads ($g = h/8$), and Mistral 7B uses 8 KV heads for its 32 query heads ($g = h/4$).
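
Written 0-indexed, the grouping rule above maps query head $j$ to KV head $\lfloor j / (h/g) \rfloor$. The sketch below implements that mapping; the helper names are hypothetical. It copies references for clarity, whereas real kernels typically index the shared heads directly.

```python
def kv_head_for_query_head(j: int, h: int, g: int) -> int:
    """0-indexed form of: group i serves query heads (i-1)(h/g)+1 .. i(h/g)."""
    if h % g != 0 or not 0 <= j < h:
        raise ValueError("need g to divide h and 0 <= j < h")
    return j // (h // g)


def expand_kv_heads(kv_heads, h):
    """Give each of the h query heads its group's K/V head."""
    g = len(kv_heads)
    return [kv_heads[kv_head_for_query_head(j, h, g)] for j in range(h)]


def kv_cache_fraction(h: int, g: int) -> float:
    """KV cache size relative to MHA: g KV heads instead of h."""
    return g / h


# 64 query heads sharing 8 KV heads, as in LLaMA 2 70B.
print([kv_head_for_query_head(j, 64, 8) for j in (0, 7, 8, 63)])  # [0, 0, 1, 7]
print(kv_cache_fraction(64, 8))  # 0.125
```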

Paged Attention

A practical problem in LLM serving is KV cache fragmentation. Sequences in a batch grow at different rates and have different final lengths, which are unknown at the start of generation. Naive allocation reserves the maximum possible sequence length upfront, wasting memory for sequences that terminate early.

Paged Attention (Kwon et al., vLLM, 2023) adapts the virtual memory abstraction from operating systems. The KV cache is divided into fixed-size pages (blocks of $B$ token slots each). A page table maps (sequence, layer, block index) to a physical page in GPU memory. When a sequence needs more cache space, new pages are allocated on demand; when a sequence terminates, its pages are immediately freed and can be reassigned to new sequences.

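This eliminates external fragmentation and bounds internal fragmentation: the only wasted memory is at most $B - 1$ slots per sequence per layer (the partially filled last page). The vLLM paper reports $2\text{–}4\times$ higher throughput than state-of-the-art serving systems such as FasterTransformer and Orca at the same latency. The gain comes largely from the reclaimed memory, which lets more sequences share a batch.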
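
A toy allocator in the spirit of paged attention follows. It assumes one block table per sequence, with the same physical block ids indexing every layer's cache tensor, so per-layer storage is implied rather than modelled. It tracks only slot bookkeeping, not the KV tensors. All names are hypothetical and this is not vLLM's API.

```python
class PagedKVAllocator:
    """Toy paged KV cache bookkeeping (hypothetical)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        # Stack of free physical block ids; pop() hands out the lowest id first.
        self.free = list(range(num_blocks - 1, -1, -1))
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve one token slot; returns (physical_block, offset_in_block)."""
        n = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if n % self.block_size == 0:
            # No partially filled page left: allocate a new page on demand.
            if not self.free:
                raise MemoryError("out of KV cache blocks")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free_sequence(self, seq_id):
        """Return all of a finished sequence's pages to the free pool."""
        for block in reversed(self.block_tables.pop(seq_id, [])):
            self.free.append(block)
        self.lengths.pop(seq_id, None)

    def wasted_slots(self, seq_id):
        """Unused slots in the sequence's last page (at most block_size - 1)."""
        return len(self.block_tables[seq_id]) * self.block_size - self.lengths[seq_id]


alloc = PagedKVAllocator(num_blocks=4, block_size=16)
for _ in range(20):
    slot = alloc.append_token("a")
print(slot, alloc.wasted_slots("a"))  # (1, 3) 12
alloc.free_sequence("a")
print(len(alloc.free))  # 4
```

Because pages are handed out one at a time, a sequence that stops early never holds more than one partially used page.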

Speculative Decoding

Autoregressive decoding is memory-bandwidth bound: each forward pass processes one token but must load all model weights from HBM. Speculative decoding (Leviathan et al., 2022; Chen et al., 2023) exploits the observation that a small draft model can propose tokens quickly, and the large target model can verify a batch of them in a single forward pass.

The procedure for one round, which emits between 1 and $\gamma + 1$ tokens:

  1. Run the draft model autoregressively to propose $\gamma$ tokens $\tilde{y}_1, \ldots, \tilde{y}_\gamma$, recording the draft distributions $q_1, \ldots, q_\gamma$.
  2. Run the target model on the current prefix plus all $\gamma$ draft tokens in a single forward pass, obtaining target distributions $p_1, \ldots, p_{\gamma+1}$.
  3. For $i = 1, \ldots, \gamma$ in order, accept draft token $\tilde{y}_i$ with probability $\min(1, p_i(\tilde{y}_i) / q_i(\tilde{y}_i))$.
  4. If $\tilde{y}_i$ is rejected, sample a correction from $\max(0, p_i - q_i)$ (renormalised) and stop. If all $\gamma$ draft tokens are accepted, sample one extra token from $p_{\gamma+1}$.
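
The accept-reject round can be sketched in pure Python. The sketch assumes distributions are lists of probabilities indexed by token id and that each draft token was sampled from its `q` list. All function names are hypothetical. `first_token_distribution` computes the exact output distribution at one position, which lets you confirm that the procedure reproduces $p$.

```python
import random


def residual(p, q):
    """Renormalised max(0, p - q): what a rejected position resamples from."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    z = sum(r)
    return [ri / z for ri in r]


def sample(dist, rng):
    """Draw a token index from a probability vector."""
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


def speculative_step(draft_tokens, q_dists, p_dists, rng):
    """One verification round.

    draft_tokens[i] was sampled from q_dists[i]; p_dists holds the target's
    len(draft_tokens) + 1 distributions from the single verification pass.
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        if rng.random() < min(1.0, p_dists[i][tok] / q_dists[i][tok]):
            out.append(tok)  # accepted
        else:
            out.append(sample(residual(p_dists[i], q_dists[i]), rng))
            return out  # stop at the first rejection
    # Every draft token was accepted: take one extra token from the target.
    out.append(sample(p_dists[len(draft_tokens)], rng))
    return out


def first_token_distribution(p, q):
    """Exact distribution of the first emitted token.

    P(emit x) = q(x) * min(1, p(x)/q(x)) + P(reject) * residual(x), which
    simplifies to min(p(x), q(x)) + max(0, p(x) - q(x)) = p(x).
    """
    accept = [qi * min(1.0, pi / qi) if qi > 0 else 0.0 for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accept)
    if reject_mass < 1e-12:  # draft already matches the target
        return accept
    return [a + reject_mass * r for a, r in zip(accept, residual(p, q))]


p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]
print(first_token_distribution(p, q))  # equal to p up to float rounding
```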

The accept-reject procedure guarantees that the output distribution is exactly the target distribution: speculative decoding is not an approximation. The speedup comes from amortising the cost of each target forward pass, which is dominated by loading the weights from HBM, over several tokens. Wall-clock speedups of $2\text{–}3\times$ are typical when two conditions hold: the draft model is much cheaper than the target (often an order of magnitude or more smaller), and per-token acceptance rates are around 0.7–0.9.
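
Leviathan et al. give a closed form for this trade-off. They assume each draft token is accepted independently with the same probability $\alpha$. Under that assumption, a round yields $(1 - \alpha^{\gamma+1})/(1 - \alpha)$ tokens on average. If one draft step costs a fraction $c$ of a target step, the expected wall-clock improvement is that count divided by $\gamma c + 1$. A sketch follows; the function names are hypothetical.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Mean tokens emitted per target call, including the correction/bonus token."""
    if alpha == 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Wall-clock improvement when one draft step costs c target steps."""
    return expected_tokens_per_round(alpha, gamma) / (gamma * c + 1)


for c in (0.05, 0.3):
    print(f"alpha=0.8, gamma=4, c={c}: {expected_speedup(0.8, 4, c):.2f}x")
```

With $\alpha = 0.8$ and $\gamma = 4$, this gives roughly $2.8\times$ for $c = 0.05$ but only about $1.5\times$ for $c = 0.3$. The gain shrinks quickly as the draft's relative cost grows.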

Prefix Caching

When many requests share a common prefix (e.g., a system prompt), recomputing the KV cache for that prefix on every request is wasteful. Prefix caching stores the KV cache for frequently seen prefixes and reuses them across requests, with the cost of initial computation amortised over many calls. This is particularly impactful for long system prompts or multi-turn conversations.
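
One common design works at block granularity, roughly as described in vLLM's automatic prefix caching documentation. Each full block of tokens is keyed by a hash of its own tokens plus everything before it, so a lookup can reuse the longest cached run of leading blocks. The sketch below tracks only the keys, with a placeholder block id standing in for the stored KV tensors. It omits eviction. `PrefixCache` and its methods are hypothetical.

```python
import hashlib


def prefix_block_hashes(tokens, block_size):
    """Chained hash for each full block; it identifies the block and its prefix."""
    hashes, parent = [], b""
    n_full = len(tokens) - len(tokens) % block_size
    for start in range(0, n_full, block_size):
        block = tokens[start:start + block_size]
        parent = hashlib.sha256(parent + repr(block).encode()).digest()
        hashes.append(parent)
    return hashes


class PrefixCache:
    """Maps prefix-block hashes to (placeholder) cached KV block ids."""

    def __init__(self, block_size):
        self.block_size = block_size
        self.blocks = {}

    def insert(self, tokens):
        """Record the KV blocks computed for a request's full blocks."""
        for h in prefix_block_hashes(tokens, self.block_size):
            self.blocks.setdefault(h, len(self.blocks))

    def cached_prefix_len(self, tokens):
        """Number of leading tokens whose KV entries can be reused."""
        hit = 0
        for h in prefix_block_hashes(tokens, self.block_size):
            if h not in self.blocks:
                break
            hit += self.block_size
        return hit


system_prompt = list(range(32))  # stand-in token ids for a shared system prompt
cache = PrefixCache(block_size=16)
cache.insert(system_prompt + [1001, 1002])
print(cache.cached_prefix_len(system_prompt + [2001, 2002, 2003]))  # 32
```

Chaining the hashes means a block is only reused when everything before it also matches, which is required because a token's keys and values depend on its whole prefix.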

Flash Decoding

Standard Flash Attention parallelises over the batch, the heads, and blocks of query positions. That keeps the GPU busy during the prefill phase (processing the full prompt in parallel), but not during decoding, where each sequence contributes a single query token attending over a long context. Flash Decoding (Dao et al., 2023) additionally splits the key-value sequence across multiple thread blocks. Each block computes a partial attention output with its own softmax normaliser, and the partial results are then combined using the online softmax identity. This adds parallelism over the sequence length dimension, which is the bottleneck for long contexts with small batches, and yields significant speedups for long-context generation.
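
The combination step is the online softmax identity. Each split reports its unnormalised output, its maximum score $m_j$, and its sum of exponentials $\ell_j$. The final result rescales each split by $e^{m_j - m}$, where $m$ is the global maximum. Below is a sequential, pure-Python sketch; the real kernel runs the splits in parallel on separate thread blocks, and the function names are hypothetical.

```python
import math
import random


def split_attention(q, keys, values):
    """Attention over one KV split: (unnormalised output, max score, sum of exps)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    out = [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(len(values[0]))]
    return out, m, sum(w)


def combine_splits(partials):
    """Merge split results by rescaling each one to the global max score."""
    m_global = max(m for _, m, _ in partials)
    dim = len(partials[0][0])
    numer, denom = [0.0] * dim, 0.0
    for out, m, l in partials:
        factor = math.exp(m - m_global)
        denom += factor * l
        for j in range(dim):
            numer[j] += factor * out[j]
    return [x / denom for x in numer]


def split_kv_decode(q, keys, values, num_splits):
    """Single-query attention computed split by split over the sequence."""
    chunk = math.ceil(len(keys) / num_splits)
    partials = [split_attention(q, keys[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(keys), chunk)]
    return combine_splits(partials)


rng = random.Random(0)
d, n = 8, 37
q = [rng.gauss(0, 1) for _ in range(d)]
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
full = split_kv_decode(q, keys, values, num_splits=1)   # ordinary attention
split = split_kv_decode(q, keys, values, num_splits=4)  # 4 partial results merged
print(max(abs(a - b) for a, b in zip(full, split)))     # only float rounding
```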

Examples

vLLM throughput vs naive serving. On LLaMA 13B with a batch of 256 requests and average output length 200 tokens, vLLM achieves approximately $24\times$ higher throughput than the Hugging Face naive implementation by combining paged attention (eliminating fragmentation), continuous batching (immediately reusing freed pages from completed sequences), and scheduling that maximises GPU utilisation.

Speculative decoding with a small draft model. Chen et al. accelerate Chinchilla 70B using a 4B-parameter draft model and report a $2\text{–}2.5\times$ decoding speedup on summarisation (XSum) and coding (HumanEval) tasks. Output quality is unchanged, since accept-reject sampling preserves the target distribution exactly.

