KV Caching - Never Recompute What You've Already Attended To
Helpful context:
- Transformers From First Principles - Why Attention Changed Everything
- Efficient Transformers - Attention Without the Quadratic Cost
Generating 1000 tokens with a 7B parameter transformer naively requires 1000 separate forward passes, each processing the full sequence of up to 1000 tokens. That’s about 500,000 token-computations (1 + 2 + ... + 1000), with the earliest tokens reprocessed up to 1000 times, to produce 1000 output tokens.
KV caching reduces this to 1000 token-computations: each token passes through the model’s layers exactly once. That is roughly a 500× reduction in per-token work, although each new token still has to attend over everything cached before it.
For anyone deploying language models at scale, understanding where this gain comes from - and where it goes away - is not optional. It determines latency, throughput, and cost.
Autoregressive Generation
Transformers generate text one token at a time. To produce token $t$, the model takes the full sequence of previous tokens as input and outputs a probability distribution over the vocabulary. You sample from this distribution, append the token, and repeat.
Formally: let $x_1, x_2, \ldots, x_T$ be the generated sequence. At step $t$, the model computes:
$$P(x_t \mid x_1, \ldots, x_{t-1})$$
and samples $x_t$ from it. Then at step $t+1$, it recomputes using $x_1, \ldots, x_t$.
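The loop described above can be sketched in a few lines. This is a minimal illustration, not a real model: `toy_next_token_probs` and `VOCAB` are hypothetical stand-ins for a transformer and its vocabulary, chosen only so the sample-append-repeat structure is visible.

```python
import random

VOCAB = ["the", "cat", "sat", "on", "mat", "."]

def toy_next_token_probs(tokens):
    """Hypothetical stand-in for a transformer: returns P(x_t | x_1..x_{t-1}).

    Toy rule (an assumption for illustration): strongly prefer the word that
    follows the last token in VOCAB order.
    """
    last = VOCAB.index(tokens[-1])
    probs = [0.02] * len(VOCAB)
    probs[(last + 1) % len(VOCAB)] = 1.0 - 0.02 * (len(VOCAB) - 1)
    return probs

def generate(prompt, num_new_tokens, seed=0):
    rng = random.Random(seed)
    tokens = list(prompt)
    for _ in range(num_new_tokens):
        probs = toy_next_token_probs(tokens)      # the model sees the full prefix
        next_token = rng.choices(VOCAB, weights=probs, k=1)[0]
        tokens.append(next_token)                 # append, then repeat
    return tokens

sample = generate(["the"], num_new_tokens=5)
```

Note that the full prefix is handed to the model on every iteration, which is where the redundant work discussed next comes from.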
The naive approach: feed the entire sequence seen so far through the transformer at every step. At step $t$, you run a forward pass on a sequence of length $t-1$. The cost of a transformer forward pass on sequence length $n$ is $O(n^2)$ (due to attention). Summing over $T$ steps:
$$\text{Total cost} \approx \sum_{t=1}^{T} O(t^2) = O(T^3).$$
For $T = 1000$ tokens, that’s on the order of a billion ($10^9$) operations. For $T = 100{,}000$ tokens (long context), it’s on the order of a quadrillion ($10^{15}$). Completely impractical.
The Redundancy in Attention
Let’s look more carefully at what the transformer computes during attention.
At each attention layer, for a sequence of length $n$, the model computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V.$$
Here $Q, K, V \in \mathbb{R}^{n \times d_k}$ are the query, key, and value matrices for one attention head - linear projections of the layer’s input (the token embeddings at the first layer, hidden states at later layers).
Now suppose we’ve already processed tokens $1, \ldots, t-1$ and we’re adding token $t$. We need the new attention output for position $t$ only (we’re doing causal/autoregressive generation, so position $t$ can only attend to positions $1, \ldots, t$).
The query at position $t$ is $q_t \in \mathbb{R}^{d_k}$. The attention output at position $t$ is:
$$o_t = \sum_{i=1}^{t} \alpha_{ti}v_i, \quad \text{where } \alpha_{ti} = \frac{\exp\left(q_t \cdot k_i / \sqrt{d_k}\right)}{\sum_{j=1}^{t} \exp\left(q_t \cdot k_j / \sqrt{d_k}\right)}.$$
To compute this, we need:
- $q_t$: the query for the new token. Fresh computation.
- $k_i$ for $i = 1, \ldots, t$: the keys for all previous tokens plus the new one.
- $v_i$ for $i = 1, \ldots, t$: the values for all previous tokens plus the new one.
Here’s the critical observation: the keys and values for positions $1, \ldots, t-1$ were already computed at step $t-1$. They haven’t changed. The attention mask is causal, so earlier tokens don’t depend on later ones - the keys and values for position $i$ depend only on $x_1, \ldots, x_i$. Adding $x_t$ doesn’t change any of them.
We can cache them.
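A quick numerical check of this claim, as a sketch: the toy model below is two causal attention layers with a residual connection and random weights (no MLP, no normalization, a single head, all assumptions made for brevity). The layer-2 keys for the first five positions come out the same whether or not a sixth token is present.

```python
import numpy as np

def causal_attention(x, Wq, Wk, Wv):
    """Single-head causal self-attention; also returns the keys and values."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)    # position i cannot see j > i
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, k, v

def layer2_keys(x, layers):
    """Keys at the second layer, which depend on the first layer's outputs."""
    h, _, _ = causal_attention(x, *layers[0])
    h = x + h                                     # residual connection
    _, k2, _ = causal_attention(h, *layers[1])
    return k2

rng = np.random.default_rng(0)
d = 8
layers = [tuple(rng.normal(size=(d, d)) for _ in range(3)) for _ in range(2)]
x = rng.normal(size=(6, d))

keys_before = layer2_keys(x[:5], layers)   # sequence of 5 tokens
keys_after = layer2_keys(x, layers)        # same sequence plus a 6th token
unchanged = np.allclose(keys_before, keys_after[:5])
```

Without the causal mask the check would fail, because first-layer outputs for early positions would then depend on the appended token.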
The KV Cache
The KV cache stores the key and value matrices for all previously processed tokens. At each new step, we:
- Compute only $Q$, $K$, $V$ for the new token (a single row, not the full matrix).
- Append the new $k_t$ and $v_t$ to the cache.
- Compute attention for the new token using the full cached key and value sequences.
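The three steps above can be sketched for a single attention head in NumPy. This is an illustrative toy under stated assumptions (one head, one layer, random projection weights, a hypothetical `KVCache` class), not how any particular framework implements it; production code preallocates the cache rather than growing it with `np.vstack`.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(x, Wq, Wk, Wv):
    """Reference: recompute attention for every position from scratch."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    return softmax(np.where(future, -np.inf, scores)) @ v

class KVCache:
    """Hypothetical minimal cache: one growing matrix each for keys and values."""
    def __init__(self, d_k):
        self.keys = np.empty((0, d_k))
        self.values = np.empty((0, d_k))

    def append(self, k_t, v_t):
        self.keys = np.vstack([self.keys, k_t])
        self.values = np.vstack([self.values, v_t])

def decode_step(x_t, cache, Wq, Wk, Wv):
    q_t, k_t, v_t = x_t @ Wq, x_t @ Wk, x_t @ Wv          # step 1: one row each
    cache.append(k_t, v_t)                                # step 2: extend the cache
    scores = cache.keys @ q_t / np.sqrt(q_t.shape[-1])    # step 3: O(t) work
    return softmax(scores) @ cache.values

rng = np.random.default_rng(0)
d_model, d_k, n = 16, 8, 5
Wq, Wk, Wv = (rng.normal(size=(d_model, d_k)) for _ in range(3))
x = rng.normal(size=(n, d_model))

cache = KVCache(d_k)
cached_out = np.stack([decode_step(x[t], cache, Wq, Wk, Wv) for t in range(n)])
reference_out = full_causal_attention(x, Wq, Wk, Wv)
```

Feeding tokens one at a time through `decode_step` reproduces the rows of the full causal attention output, which is the sense in which the cache changes the cost but not the result.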
Instead of an $O(t^2)$ forward pass at each step, we do an $O(t)$ computation (attending over $t$ cached keys/values). Total cost over $T$ steps:
$$\sum_{t=1}^{T} O(t) = O(T^2).$$
Still quadratic in $T$, but that is a full factor of $T$ below the naive approach, and the practical speedup is enormous. For $T = 1000$ tokens, naive is $O(T^3) \sim 10^9$; with a KV cache it’s $O(T^2) \sim 10^6$. Roughly a 1000× improvement.
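Big-O notation drops constant factors, so it can help to evaluate the two sums exactly. The sketch below assumes a unit-cost model (step $t$ costs $t^2$ without a cache and $t$ with one) and ignores everything except attention.

```python
def naive_cost(T):
    """Sum of t^2 over decode steps: full attention recomputed each step."""
    return sum(t * t for t in range(1, T + 1))

def cached_cost(T):
    """Sum of t over decode steps: one query attends over t cached entries."""
    return sum(range(1, T + 1))

T = 1000
ratio = naive_cost(T) / cached_cost(T)
print(f"naive={naive_cost(T):,} cached={cached_cost(T):,} ratio={ratio:.0f}x")
```

Under this model the exact ratio for $T = 1000$ is about 667 rather than 1000: the same order of magnitude as the big-O estimate, with the difference coming from the $1/3$ and $1/2$ constants in the two sums.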
Discomfort check. “With KV caching, inference is $O(n)$ not $O(n^2)$, right?” Not exactly. The computation per new token is $O(n)$ (you attend over $n$ cached key-value pairs). But generating $T$ tokens with a prompt of length $P$ costs $O((P + T) \cdot T)$ in total for decoding, plus a one-time $O(P^2)$ prefill - still quadratic in $T$. What caching removes is the recomputation: without it, each of the $T$ decode steps reruns the full forward pass over all $P + t$ tokens seen so far, recomputing every key, value, and attention row, for $O((P + T)^2 \cdot T)$ total. With KV caching, each key and value is computed exactly once.
Memory Cost of the KV Cache
The cache doesn’t come free. You’re trading compute for memory.
For a model with $L$ transformer layers, $H$ attention heads, and head dimension $d_{\text{head}}$, the KV cache for $n$ tokens stores:
$$\text{Memory} = 2 \times L \times H \times d_{\text{head}} \times n \times \text{bytes per element}.$$
The factor of 2 is for both keys and values.
Concrete example: LLaMA-7B (32 layers, 32 heads, $d_{\text{head}} = 128$, FP16):
$$\text{Memory per token} = 2 \times 32 \times 32 \times 128 \times 2 \text{ bytes} = 524{,}288 \text{ bytes} \approx 0.5 \text{ MB/token}.$$
At different context lengths:
| Context Length | KV Cache Memory |
|---|---|
| 512 tokens | 256 MB |
| 4096 tokens | 2 GB |
| 32768 tokens | 16 GB |
| 131072 tokens (128K) | 64 GB |
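At a 128K context (the window Llama 3.1 supports), a cache of this shape would take 64 GB, which together with the ~14 GB of FP16 weights nearly fills an A100-80GB. Llama 3.1 itself avoids this by using grouped-query attention, covered below.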
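The memory formula is simple enough to turn into a helper for sizing a deployment. A minimal sketch, assuming standard multi-head attention and the LLaMA-7B shape quoted above; the function name is illustrative.

```python
def kv_cache_bytes(n_layers, n_heads, d_head, n_tokens, bytes_per_element=2):
    """2 (keys and values) x layers x heads x head_dim x tokens x element size."""
    return 2 * n_layers * n_heads * d_head * n_tokens * bytes_per_element

MIB = 1024 ** 2
GIB = 1024 ** 3

per_token = kv_cache_bytes(32, 32, 128, n_tokens=1)     # LLaMA-7B shape, FP16
for n in (512, 4096, 32768):
    size = kv_cache_bytes(32, 32, 128, n_tokens=n)
    print(f"{n:>6} tokens: {size / GIB:.2f} GiB")
```

The figures in the table use binary units (1 MB = $2^{20}$ bytes), which is why 524,288 bytes reads as exactly 0.5 MB per token.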
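Larger models are worse. A 70B-scale model with full multi-head attention (80 layers, 64 heads, $d_{\text{head}} = 128$, FP16) would use $2 \times 80 \times 64 \times 128 \times 2 = 2{,}621{,}440$ bytes $\approx 2.5$ MB per token. At 4096 tokens: 10 GB just for the cache. (The released Llama 2 and Llama 3 70B models use grouped-query attention, discussed below, which shrinks this considerably.)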
This is why KV cache memory management is a first-class concern in LLM inference infrastructure.
Two Phases of LLM Inference
Running a language model in production involves two qualitatively different phases:
Prefill phase: process the input prompt. The full prompt (say, 500 tokens) is passed through the transformer in one parallel forward pass. All 500 KV pairs are computed and cached simultaneously. This is compute-bound: the GPU is doing lots of matrix multiplications in parallel.
Decode phase: generate new tokens one at a time, using the cached KVs. At each step, only one token’s computation runs - but it needs to access the full cache. This is memory-bandwidth-bound: the dominant cost is reading the model weights and KV cache from GPU memory, not the computation itself.
This distinction matters for hardware utilization:
- During prefill, GPU compute is close to fully utilized. An A100 peaks at 312 TFLOPS of dense FP16 tensor throughput. Fast.
- During decode, GPU compute is mostly idle. The bottleneck is the A100’s roughly 2 TB/s of memory bandwidth. Every generated token requires reading ~14 GB of model weights (7B params × 2 bytes) plus the growing KV cache.
For a 7B model at batch size 1: reading 14 GB at 2 TB/s takes 7 ms. That’s a ceiling of about 143 generated tokens per second - and that’s before accounting for the KV cache reads, which at long contexts add another few GB per step.
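The bandwidth ceiling is a one-line calculation. The sketch below is a rough roofline-style estimate under stated assumptions: batch size 1, every weight and every cached byte read once per token, and no overlap or other overheads.

```python
def decode_tokens_per_second(weight_bytes, cache_bytes, bandwidth_bytes_per_s):
    """Upper bound on decode speed when memory reads dominate."""
    seconds_per_token = (weight_bytes + cache_bytes) / bandwidth_bytes_per_s
    return 1.0 / seconds_per_token

WEIGHTS_7B_FP16 = 14e9          # 7B params x 2 bytes
A100_BANDWIDTH = 2e12           # ~2 TB/s

no_cache = decode_tokens_per_second(WEIGHTS_7B_FP16, 0, A100_BANDWIDTH)
# 4096 cached tokens at 0.5 MiB each adds 2 GiB of reads per step.
with_cache = decode_tokens_per_second(WEIGHTS_7B_FP16, 2 * 1024**3, A100_BANDWIDTH)
print(f"{no_cache:.0f} tok/s without cache reads, {with_cache:.0f} tok/s at 4096 tokens")
```

With these inputs the ceiling drops from about 143 to about 124 tokens per second once a 4096-token cache is included, which shows how the cache gradually eats into decode speed as the context grows.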
Continuous Batching
The simplest way to serve multiple users is to form batches: group several requests together and process them in parallel. Batching lets the GPU amortize the weight reads across multiple sequences simultaneously, improving throughput.
But requests have different lengths. A batch of 8 sequences might have 3 sequences that finish at step 50 and 5 that continue to step 200. With naive batching, you wait for the longest sequence before releasing any GPU slots - the slots of the 3 finished sequences sit idle for 150 steps.
Continuous batching (Orca, Yu et al., 2022): on every decode step, check if any sequence has finished generating. If so, immediately inject the next waiting request into the batch, filling the empty slot. The batch size stays constant; its membership changes dynamically.
Result: GPU utilization improves dramatically. Sequences aren’t waiting in a queue behind slower requests. Throughput improvements of 5 - 10× over naive batching are typical.
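A toy simulation makes the scheduling difference concrete. This sketch counts decode steps only and rests on simplifying assumptions: every request needs at least one step, prefill is free, and all requests are already waiting at time zero. It is not Orca’s actual scheduler.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest sequence finishes."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        total += max(lengths[i:i + batch_size])
    return total

def continuous_batching_steps(lengths, batch_size):
    """Finished sequences are replaced from the queue before the next step."""
    waiting = deque(lengths)
    active = []                     # remaining tokens per running sequence
    steps = 0
    while waiting or active:
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())    # fill free slots
        steps += 1                              # one decode step for the batch
        active = [r - 1 for r in active if r - 1 > 0]
    return steps

lengths = [50, 200, 50, 200, 50, 200, 50, 200]
static_steps = static_batching_steps(lengths, batch_size=4)          # 400
continuous_steps = continuous_batching_steps(lengths, batch_size=4)  # 350
```

The gain in this small example is modest because only eight requests are queued; the advantage grows with a longer queue and more variance in response length, since every freed slot is reused immediately.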
PagedAttention: Managing KV Cache Memory
Even with continuous batching, KV cache memory management is tricky.
Each request needs a KV cache proportional to its maximum sequence length. But you don’t know in advance how long a response will be. If you over-allocate (reserve memory for the maximum context length), you waste memory. If you under-allocate, the request may fail mid-generation.
The standard solution before vLLM: pre-allocate a large contiguous block per request. This causes severe fragmentation - gaps of unused memory between allocations, similar to heap fragmentation in C programs. GPU memory utilization of 20 - 40% was common.
PagedAttention (vLLM, Kwon et al., 2023): inspired by OS virtual memory paging.
Instead of one contiguous KV cache block per sequence, allocate KV cache in fixed-size pages (blocks), each holding the key-value pairs for $B$ tokens. Maintain a page table per sequence mapping logical positions to physical pages. Pages can be non-contiguous in GPU memory.
When a sequence needs more KV cache space, allocate a new page - which can go anywhere in memory. When a sequence finishes, its pages are returned to a free pool.
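The bookkeeping behind this can be sketched with a free list and a per-sequence page table. The class below is a hypothetical illustration of the idea only: it tracks which physical page each token slot lives in and does not store tensors or reflect vLLM’s actual API.

```python
class PagedKVAllocator:
    """Toy page-table allocator for KV cache slots (illustrative only)."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))   # pool of physical page ids
        self.page_tables = {}                      # seq_id -> [physical page ids]
        self.lengths = {}                          # seq_id -> tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one more token; return (physical_page, offset)."""
        table = self.page_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.page_size == 0:                # no page yet, or last page full
            if not self.free_pages:
                raise MemoryError("no free KV pages")
            table.append(self.free_pages.pop())    # any free page will do
        self.lengths[seq_id] = n + 1
        return table[n // self.page_size], n % self.page_size

    def free(self, seq_id):
        """Return all of a finished sequence's pages to the pool."""
        self.free_pages.extend(self.page_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_pages=4, page_size=16)
for _ in range(20):
    alloc.append_token("a")        # 20 tokens -> 2 pages, second one partly filled
for _ in range(3):
    alloc.append_token("b")        # a second sequence takes whichever page is free
```

Because a page is only claimed when the previous one fills up, the waste per sequence is bounded by one partially filled page regardless of how long the response turns out to be.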
Benefits:
- Near-zero fragmentation: the only wasted memory is at most one page per sequence (the last, partially filled page).
- Memory utilization near 100%: across a batch of sequences, almost all KV memory is actively used.
- Sharing: if multiple sequences share a prefix (common in instruction-tuned models where the system prompt is repeated), the shared prefix pages are stored once and referenced by all sequences. Copy-on-write semantics for the diverging parts.
In practice, vLLM’s PagedAttention enables 2 - 4× more concurrent requests for the same GPU memory budget, compared to prior approaches.
Speculative Decoding
The decode phase bottleneck - sequential token generation - seems fundamental to autoregressive models. But there’s a clever trick.
Speculative decoding (Leviathan et al., 2023; Chen et al., 2023): use a small “draft” model to propose $k$ tokens speculatively, then verify all $k$ with the large target model in parallel.
The procedure:
- The small draft model (e.g., 7B) auto-regressively generates $k$ token proposals $\hat{x}_1, \ldots, \hat{x}_k$ (cheap, fast).
- The large target model (e.g., 70B) processes all $k$ proposed tokens in a single forward pass (parallel, like prefill). It computes the probability of each proposed token.
- Accept or reject each proposal using a rejection sampling criterion that guarantees the accepted tokens follow the target model’s distribution exactly (not an approximation).
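- If all $k$ are accepted, you’ve generated $k$ tokens with roughly one large model forward pass (plus one bonus token sampled from the target’s distribution at the next position). If a proposal is rejected, you keep the accepted prefix and sample a replacement for the rejected position from a corrected target distribution, so every pass yields at least one token.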
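The accept/reject step can be sketched as follows, using the rule described in the cited papers: accept a proposed token with probability $\min(1, p/q)$, where $p$ and $q$ are the target and draft probabilities of that token, and on rejection resample from the normalized residual $\max(p - q, 0)$. The function name and array layout are assumptions made for illustration, and the probabilities here are supplied directly rather than produced by real models.

```python
import numpy as np

def verify_draft(draft_tokens, draft_probs, target_probs, rng):
    """Return the tokens emitted by one verification pass.

    draft_probs:  (k, V)   draft distribution at each proposed position
    target_probs: (k+1, V) target distribution at those positions plus one more
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        if rng.random() < min(1.0, p[tok] / q[tok]):
            out.append(int(tok))                   # proposal accepted
            continue
        residual = np.maximum(p - q, 0.0)          # corrected distribution
        residual /= residual.sum()
        out.append(int(rng.choice(len(p), p=residual)))
        return out                                 # stop at the first rejection
    # All k accepted: sample one bonus token from the target's next distribution.
    out.append(int(rng.choice(target_probs.shape[1], p=target_probs[-1])))
    return out

rng = np.random.default_rng(0)
uniform = np.full((4, 5), 0.2)
# Identical draft and target distributions: every proposal is accepted.
all_accepted = verify_draft([1, 2, 3], uniform[:3], uniform, rng)
```

When the draft and target agree exactly the ratio $p/q$ is 1 and nothing is ever rejected; the more the two distributions diverge, the earlier the loop tends to exit.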
Expected tokens per large model forward pass: $\frac{1 - \alpha^{k+1}}{1 - \alpha}$, where $\alpha$ is the per-token acceptance rate (assumed independent across positions). This is at most $k + 1$. For draft/target model pairs with similar distributions, $\alpha \approx 0.7$ - $0.9$; with $k = 5$ that gives roughly 3 - 5 tokens per large model call instead of 1.
Speedup: 2 - 3× for typical use cases, with no change to output quality (the output distribution is mathematically identical to the target model’s).
KV Cache Quantization
Recall the LLaMA-7B cache: 0.5 MB per token. At 4096 tokens, that’s 2 GB in FP16. If you quantize the cached KV values to INT8 (1 byte instead of 2), the cache halves to 1 GB. INT4 cuts it to 0.5 GB.
KV cache quantization compresses the cached keys and values to lower precision. The quality impact is typically small - the attention mechanism is robust to quantization noise in the keys and values, especially compared to quantizing model weights. INT8 KV cache is nearly lossless on most benchmarks; INT4 shows slight degradation on demanding reasoning tasks.
Combining KV cache quantization with PagedAttention: you can serve sequences 2 - 4× longer for the same GPU memory, or run 2 - 4× more concurrent sequences.
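A minimal sketch of INT8 cache quantization, assuming symmetric quantization with one scale per token row. This is one common scheme chosen for illustration; real systems differ in granularity (per-channel, per-group) and in how they handle outliers.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-row INT8 quantization: x is approximated by q * scale."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)       # avoid dividing by zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
keys = rng.normal(size=(64, 128)).astype(np.float32)   # 64 cached tokens, d_head=128

q_keys, scales = quantize_int8(keys)
restored = dequantize_int8(q_keys, scales)
max_error = np.abs(restored - keys).max()
```

The quantized payload is half the size of an FP16 cache; the per-row scales add a small overhead (4 bytes per 128 values here), and the rounding error per element is bounded by half of that row’s scale.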
Grouped-Query Attention
Standard multi-head attention (MHA) uses $H$ query heads, $H$ key heads, and $H$ value heads. The KV cache stores $H$ heads of keys and values per token.
Multi-Query Attention (MQA) (Shazeer, 2019): use one key head and one value head shared across all $H$ query heads. KV cache memory: $H$× smaller. Modest quality loss.
Grouped-Query Attention (GQA) (Ainslie et al., 2023): use $G$ key/value heads ($1 < G < H$) shared across groups of $H/G$ query heads. A middle ground between MHA and MQA.
LLaMA-3, Mistral, and most recent models use GQA, typically with $G = 8$. This reduces the KV cache by a factor of $H/G$ compared to MHA ($4\times$ for a 32-head model with $G = 8$, $8\times$ for a 64-head model) with minimal quality loss - a major reason modern models can handle longer contexts.
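A single decode step with grouped-query attention can be sketched as below. The shapes and the function name are assumptions for illustration: the cache holds only $G$ key/value heads, and each one is shared by $H/G$ consecutive query heads.

```python
import numpy as np

def gqa_decode_step(q, k_cache, v_cache):
    """q: (H, d) queries for the new token; k_cache, v_cache: (G, n, d).

    H must be divisible by G. Returns the attention output, shape (H, d).
    """
    H, d = q.shape
    G = k_cache.shape[0]
    group = H // G
    # Expand the G cached heads so each serves `group` consecutive query heads.
    k = np.repeat(k_cache, group, axis=0)          # (H, n, d)
    v = np.repeat(v_cache, group, axis=0)
    scores = np.einsum("hd,hnd->hn", q, k) / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("hn,hnd->hd", weights, v)

rng = np.random.default_rng(0)
H, G, n, d = 8, 2, 5, 4
k_cache = rng.normal(size=(G, n, d))
v_cache = rng.normal(size=(G, n, d))
q = np.tile(rng.normal(size=(1, d)), (H, 1))       # same query in every head

out = gqa_decode_step(q, k_cache, v_cache)
cache_reduction = H // G                           # entries stored vs. full MHA
```

The `np.repeat` here materializes the expanded heads only for clarity; the stored cache stays at $G$ heads, which is where the $H/G$ memory saving comes from. Setting $G = 1$ gives MQA and $G = H$ recovers standard MHA.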
Summary
| Concept | Description |
|---|---|
| Naive autoregressive | Recompute full forward pass every step. Cost: $O(T^3)$ |
| KV cache | Store K and V for all previous tokens. Cost: $O(T^2)$ |
| Cache memory | $2 \times L \times H \times d_{\text{head}} \times n \times$ bytes |
| LLaMA-7B cache | $\approx 0.5$ MB per token in FP16 |
| Prefill phase | Process full prompt in parallel; compute-bound |
| Decode phase | Generate one token at a time; memory-bandwidth-bound |
| Continuous batching | Dynamically swap finished sequences out; 5 - 10× throughput gain |
| PagedAttention | Page-table memory management for KV cache; near-100% memory utilization |
| Speculative decoding | Draft model proposes $k$ tokens; target verifies in parallel; 2 - 3× speedup |
| KV quantization | Compress cached KV to INT8/INT4; 2 - 4× memory reduction |
| GQA | Share KV heads across query head groups; $H/G\times$ cache reduction (4× when $G = H/4$), minimal quality loss |
Efficient LLM inference is a systems problem as much as a modeling problem. The KV cache is the central data structure - everything else (PagedAttention, continuous batching, quantization, GQA) is infrastructure for managing it at scale.