KV-Cache

The single trick behind fast inference.

Hard · ~15 min read · lesson 9 of 10

Picture a waiter at a diner booth. A customer orders a coffee. The waiter scribbles it on a ticket, pins the ticket to the booth wall, and walks off. Five minutes later the customer adds a side of toast. A good waiter glances at the pinned ticket and writes toast underneath. A bad waiter re-interviews the customer from the top — name, party size, coffee, toast — every single time a new item shows up.

Generate text with a transformer the naive way and you're the bad waiter. At step 1 of generation, self-attention runs over a single token. At step 100, it runs over 100 tokens. At step 1000, over 1000. And at each step the model recomputes the keys and values for every past token — tokens whose embeddings, whose positions, whose weight matrices have not changed since the last step. It's the software equivalent of re-proving that 2 + 2 = 4 every time you want to add a three to it. The customer hasn't changed their name in the last four seconds. Stop asking.

The total cost of generating T tokens from scratch this way is O(T²). Generate a 1k-token response and you do about a million units of attention work. Generate 10k and you do a hundred million. This is the wall that, if you leave it standing, makes long-context inference economically impossible.

The fix is the simplest idea in the book: pin the ticket. Or, less colorfully: cache the thing that doesn't change. That's the KV cache. Every token's K and V is a pinned ticket on the booth wall; decoding the next token reads every pinned ticket and only writes the new one. It turns the dominant cost of serving large language models from quadratic into linear, and it's the single most important optimization between a nanoGPT toy and a production inference stack. Everything else in this lesson is a footnote on that one move.

Autoregressive generation (without cache) (personified)
Every single step I compute K and V for every single past token. I know the answer hasn't changed. I have no memory between forward passes. I'm very fast at small context, and then I'm very slow, and then I'm impossible.

Let's see the waste on paper. At decoding step t, causal attention inside a transformer block needs the queries, keys, and values for all tokens up through t. Without a cache, you run the entire forward pass on the full sequence of length t and throw away everything but the last token's logits. You interviewed the whole booth to learn one new word.

what a naive decoder does at each step
step t=1:   K₁, V₁       from token 1                    — compute 1 pair
step t=2:   K₁, V₁, K₂, V₂                                 — compute 2 pairs (1 redundant)
step t=3:   K₁, V₁, K₂, V₂, K₃, V₃                         — compute 3 pairs (2 redundant)
...
step t=T:   K₁ … K_T,  V₁ … V_T                            — compute T pairs (T-1 redundant)

total work  =  1 + 2 + 3 + … + T   =   T(T+1)/2   =   O(T²)

Every line after the first is mostly redundant. At step 100 you re-derive the first 99 keys and values — bit-for-bit identical to what you computed on step 99, and step 98, and step 97. In a serving workload you're burning GPU cycles to produce numbers you already have. Every ticket on the wall gets rewritten from scratch when a customer asks for one more refill.
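A few lines of arithmetic make the tally concrete — this counts the (K, V) pairs a naive decoder computes over a 1k-token generation versus the distinct pairs that actually exist:

```python
# naive decoding computes this many K,V pairs over a T-token generation
T = 1000
computed = sum(t for t in range(1, T + 1))   # 1 + 2 + ... + T, the triangular number
needed = T                                   # distinct (K, V) pairs that actually exist
print(computed, needed)                      # 500500 vs 1000
print(f"{100 * (1 - needed / computed):.1f}% of the work is redundant")   # 99.8%
```

Over 99% of the naive work reproduces numbers that already exist — and the fraction only climbs as T grows.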

Here is the observation that cracks the whole thing open: K and V for past tokens are a function of (token, position, model weights) only. None of those change during generation. The token at position 7 is the token at position 7 forever; the weights don't move during inference; the positional encoding at position 7 is a constant. So the key and value at position 7 are a constant too. Compute them once, pin the ticket, glance at it next time. The query for the new token is all that's actually new at each step.
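The constancy claim is easy to demonstrate (a minimal NumPy sketch — the weight matrix and token embeddings are random stand-ins for a trained model): project keys for the whole sequence at once, then one token at a time, and compare.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
Wk = rng.standard_normal((d, d))    # frozen key projection (stand-in for trained weights)
X = rng.standard_normal((5, d))     # 5 token embeddings, position info already baked in

# naive step 5: re-project the keys for all five tokens from scratch
K_recomputed = X @ Wk

# cached path: each key computed once, on the step its token appeared, never revisited
K_cached = np.vstack([x @ Wk for x in X])

print(np.allclose(K_recomputed, K_cached))   # True — past keys are constants
```

Nothing in the inputs to a past token's key or value changes between steps, so the recomputation buys literally nothing.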

KV cache growth — bytes vs sequence length
Llama-7B class · L=32 layers · H=32 heads · d_h=128

formula:
bytes = 2 · L · H · d_h · S · fp · B

where S is the sequence length, fp the bytes per value, and B the batch size; the factor 2 is K and V, both stored. Milestones at this config (fp16, batch 1):

per token         524.3 KB
1k tokens         536.9 MB
4k tokens         2.15 GB
32k tokens        17.18 GB
128k tokens       68.72 GB
max seq @ 80 GB   281,920

The per-token cost scales linearly, and halving precision halves the bill. Batch amplifies the total — 32 parallel requests at 4k tokens can tip a 7B model off an 80GB card.
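The formula is easy to sanity-check (a minimal sketch using the same Llama-7B-class numbers; sizes printed in decimal GB, matching the milestones above):

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, fp_bytes=2, batch=1):
    # factor 2: both K and V are stored, for every layer, head, and position
    return 2 * n_layers * n_heads * d_head * seq_len * fp_bytes * batch

cfg = dict(n_layers=32, n_heads=32, d_head=128)          # Llama-7B class
for s in (1024, 4096, 32_768, 131_072):
    print(f"{s:>7} tokens: {kv_cache_bytes(seq_len=s, **cfg) / 1e9:6.2f} GB")
```

This reproduces the milestones: 0.54 GB (i.e. 536.9 MB) at 1k, 2.15 GB at 4k, 17.18 GB at 32k, 68.72 GB at 128k.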

Watch the cache grow one row at a time. Each generation step pins one more ticket: a single (K_t, V_t) pair appended to the stack. The attention query at step t is just Q_t — the one new token — attending against the full [K₁ … K_t] and [V₁ … V_t] already on the wall. The memory readout in MB climbs linearly; so does the per-step work. The booth wall is filling up exactly as fast as the conversation runs.

KV cache (personified)
I'm the booth wall. Every time you generate a new token, you hand me its K and V and I pin the ticket next to the others. Next step, all you need to compute is one Q, one K, one V — and then you attend against my whole wall for free. I get bigger, never smaller. Budget for me accordingly.

Put numbers on the savings. Let d be the model width, L the number of layers, and T the current context length. The per-step cost of attention-plus-projection is roughly:

per-step FLOPs, naive vs cached
without cache (step t):    recompute K,V for all t tokens       ≈ 4 · L · t · d²
                            plus attention itself                 ≈ 2 · L · t · d

total across T steps:       Σ (4·L·t·d²)   =   2 · L · T² · d²   →   O(T²)


with KV cache (step t):    compute K,V for 1 new token           ≈ 4 · L · d²
                            attention vs cached K,V               ≈ 2 · L · t · d

total across T steps:       Σ (4·L·d² + 2·L·t·d)  ≈  L·T·(4d² + T·d)

The cached version drops the T² in the projection term entirely. The only residual T²-like cost is the attention score itself — and that one is unavoidable, because the new query must read every pinned ticket on the wall. But the matrix-multiply cost of producing those keys and values — which dominates at realistic widths (d² >> d) — is gone. The waiter stopped re-asking.
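The two totals drop straight into code (a sketch; the constants follow the per-step formulas above, with d the model width and L the layer count):

```python
def naive_flops(T, d, L):
    # step t recomputes K,V projections for all t tokens: sum_t 4·L·t·d²  ≈  2·L·T²·d²
    return sum(4 * L * t * d * d for t in range(1, T + 1))

def cached_flops(T, d, L):
    # step t projects K,V for one new token, then attends over the cache:
    # sum_t (4·L·d² + 2·L·t·d)  ≈  L·T·(4d² + T·d)
    return sum(4 * L * d * d + 2 * L * t * d for t in range(1, T + 1))

d, L = 4096, 32
for T in (128, 512, 1024):
    print(f"T={T:>5}: naive/cached = {naive_flops(T, d, L) / cached_flops(T, d, L):6.1f}x")
```

At T = 1024 the ratio is already in the hundreds, and it keeps growing roughly linearly in T — the d² projection term dominates the cached total until T approaches 4d.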

KV cache speedup — FLOPs to generate N tokens
d=4096 · L=32 · reference: 100 TFLOP/s hardware

at N = 512:
no cache      68.85 GFLOP  →  688.5 µs wall clock
with cache    268.4 MFLOP  →  2.7 µs wall clock (685.9 µs saved)
speedup       256.5×

What's actually going on: without a cache, every new token re-projects K and V for every previous token — for token t that's O(t·d) work, and the sum over the full run is the triangular number, hence N². With a cache you pay that work exactly once per position, never again. The attention matmul at each step still grows with the current context length, but K and V never get recomputed, so the total collapses to N per-layer matmuls.

One curve is a quadratic. The other is barely sloped. At a 1024-token generation, the naive cost is roughly five hundred times the cached cost for a realistic model — and the gap keeps widening. This is the difference between a chatbot that responds in three seconds and one that responds in five minutes.

Prefill vs decode (personified)
We're the two halves of inference and we do not want the same thing. I, prefill, am the first forward pass on the whole prompt at once — matmul-heavy, compute-bound, happy to live on an A100's tensor cores. I, decode, am every step after that: one tiny Q against a big wall of pinned K and V. I'm memory-bandwidth bound. I want the HBM lanes, not the FLOPs. Optimizing for me without distinguishing between us will burn your money.
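Decode's bandwidth complaint is checkable on the back of an envelope (a sketch counting only the attention math and the cache reads; the Q/K/V/O projections and MLP are left out, since they add FLOPs and weight reads in similar proportion):

```python
n_layers, n_heads, d_head, fp_bytes = 32, 32, 128, 2   # Llama-7B class, fp16

def decode_step(S):
    # one new query attends over S cached positions, in every head of every layer
    flops = 4 * n_layers * n_heads * d_head * S               # 2·S·d_h for Q·Kᵀ + 2·S·d_h for weights·V
    bytes_read = 2 * n_layers * n_heads * d_head * S * fp_bytes   # the whole K,V cache streams through HBM once
    return flops, bytes_read

flops, nbytes = decode_step(S=4096)
print(f"arithmetic intensity = {flops / nbytes:.1f} FLOPs per byte read")   # 1.0
```

At fp16 this works out to about one FLOP per byte read. An A100-class GPU needs on the order of a hundred-plus FLOPs per byte to keep its tensor cores fed, so decode sits firmly on the memory-bandwidth side of the roofline — exactly what the dialogue above is complaining about.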

Three layers, one algorithm. A tiny pure-Python causal attention where we pin tickets to the wall by hand, a NumPy version that shows the wall as an explicit (T, d) array, and the PyTorch version that delegates the pinning to a library flag and gets on with its life.

layer 1 — pure python · attention_with_cache.py
python
import math

# Scalar-ish, single-head causal attention with a hand-managed KV cache.
# At each step we compute K_t, V_t for just the new token and append.

def attention_step(q_t, K_cache, V_cache):
    # q_t is a single query vector; K_cache and V_cache are lists of past K,V rows.
    scores = [sum(qi * ki for qi, ki in zip(q_t, k_row)) / math.sqrt(len(q_t))
              for k_row in K_cache]
    # softmax
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    Z = sum(exps)
    weights = [e / Z for e in exps]
    # weighted sum of values
    d = len(V_cache[0])
    out = [sum(w * v[i] for w, v in zip(weights, V_cache)) for i in range(d)]
    return out

# Toy: Wq, Wk, Wv are the identity — just to show the caching logic.
K_cache, V_cache = [], []
tokens = [[0.2, 0.5], [0.9, -0.1], [0.3, 0.4], [-0.5, 0.6]]

for t, x in enumerate(tokens, 1):
    K_cache.append(x)          # append new K
    V_cache.append(x)          # append new V
    q_t = x                    # and the new Q
    out = attention_step(q_t, K_cache, V_cache)
    print(f"step {t}: cache_len={len(K_cache)}  out[0]={out[0]:.4f}")
stdout
step 1: cache_len=1  out[0]=0.2000
step 2: cache_len=2  out[0]=0.6337
step 3: cache_len=3  out[0]=0.4640
step 4: cache_len=4  out[0]=0.0916

Each loop iteration pins exactly one new ticket (K_cache.append, V_cache.append) and then lets the new query attend to the whole stack. Now with NumPy, so the cache shape (T, d) — the wall itself — stays visible as it grows. This is the version you mentally simulate when reading production inference code.

layer 2 — numpy · attention_cache_numpy.py
python
import numpy as np

d = 64
Wq = np.random.randn(d, d) * 0.02
Wk = np.random.randn(d, d) * 0.02
Wv = np.random.randn(d, d) * 0.02

K_cache = np.zeros((0, d))     # (T, d) — grows by 1 row per step
V_cache = np.zeros((0, d))

def step(x_t):
    global K_cache, V_cache
    q = x_t @ Wq                                    # (d,)
    k = x_t @ Wk                                    # (d,)
    v = x_t @ Wv                                    # (d,)
    K_cache = np.vstack([K_cache, k])               # (T+1, d)
    V_cache = np.vstack([V_cache, v])               # (T+1, d)
    scores = (K_cache @ q) / np.sqrt(d)             # (T+1,)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V_cache                              # (d,)

for t in range(1, 9):
    x = np.random.randn(d)
    _ = step(x)
    print(f"after step {t}: K_cache.shape={K_cache.shape}  V_cache.shape={V_cache.shape}")
stdout
after step 1: K_cache.shape=(1, 64)  V_cache.shape=(1, 64)
after step 2: K_cache.shape=(2, 64)  V_cache.shape=(2, 64)
after step 3: K_cache.shape=(3, 64)  V_cache.shape=(3, 64)
...
after step 8: K_cache.shape=(8, 64)  V_cache.shape=(8, 64)
pure python → numpy

K_cache = []; K_cache.append(k)  ←→  K_cache = np.vstack([K_cache, k])
— list-of-rows becomes a contiguous (T, d) array

for k in K_cache: dot(q, k)  ←→  K_cache @ q   # (T, d) @ (d,) = (T,)
— the whole attention-score row in one matmul

manual softmax over a list  ←→  np.exp(scores - scores.max()); normalize
— numerically-stable softmax over the whole T at once

PyTorch: in the transformers library, pinning tickets is a single keyword argument. The model returns a past_key_values tuple — the entire booth wall, packaged — that you pass back in on the next step, and it handles all the stacking for you across every layer and head.

layer 3 — pytorch · generate_with_cache.py
python
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval().cuda()

prompt = tok("The KV cache is important because", return_tensors="pt").input_ids.cuda()

# --- cached generation ---
t0 = time.time()
out = model.generate(prompt, max_new_tokens=100, use_cache=True, do_sample=False)
t_cached = time.time() - t0

# --- naive (no-cache) generation — for comparison only ---
t0 = time.time()
out = model.generate(prompt, max_new_tokens=100, use_cache=False, do_sample=False)
t_naive = time.time() - t0

print(f"prompt length: {prompt.shape[1]}")
print(f"tokens generated: 100")
print(f"naive generation: {t_naive:.1f}s")
print(f"cached generation: {t_cached:.1f}s")
print(f"speedup: {t_naive / t_cached:.1f}x")
stdout
prompt length: 7
tokens generated: 100
naive generation: 8.4s
cached generation: 0.7s
speedup: 12.0x
numpy → pytorch

K_cache = np.vstack([K_cache, k])  ←→  past_key_values (library-managed tuple)
— a tuple of (K, V) per layer, stacked across heads and batch

recompute everything  ←→  use_cache=True
— one flag. the entire O(T²) → O(T) conversion.

hand-roll softmax + matmul  ←→  model.generate(...)
— library handles prefill, decode, stopping, sampling

Gotchas

Cache invalidation on prompt change: if you edit the prompt mid-conversation (e.g. system prompt change, retrieval injection), every pinned ticket from the edit point onward is stale. You have to rip them off the wall and re-prefill. Frameworks that claim to “reuse” the cache across requests use prefix matching on the raw token IDs — change one token at position 3 and everything past position 3 is garbage.
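Prefix reuse really is a longest-common-prefix over token IDs (a hypothetical sketch — production frameworks do this block-wise over paged caches, but the rule is the same):

```python
def reusable_prefix_len(cached_ids, new_ids):
    # walk until the first mismatch; everything from there on must be re-prefilled
    n = 0
    for a, b in zip(cached_ids, new_ids):
        if a != b:
            break
        n += 1
    return n

cached = [101, 7, 42, 9, 13, 55]     # token IDs behind the pinned tickets
edited = [101, 7, 42, 8, 13, 55]     # one token changed at position 3
keep = reusable_prefix_len(cached, edited)
print(f"reuse {keep} cached positions, re-prefill {len(edited) - keep}")   # reuse 3, re-prefill 3
```

Note that positions 4 and 5 are byte-identical in both sequences and still can't be reused — their keys and values attended over the changed token when they were computed.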

Unbounded growth: the wall has no natural stopping point. A long conversation will keep pinning tickets until you OOM. Always set a max_seq_len, and decide up front whether you'll sliding-window it, compress it, or just refuse the request.
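The sliding-window option looks like this in miniature (a sketch that evicts the oldest tickets once the wall is full; real models pair this with careful positional-encoding handling, which is out of scope here):

```python
from collections import deque

class SlidingKVCache:
    def __init__(self, max_len):
        self.K = deque(maxlen=max_len)   # deque drops the oldest row automatically
        self.V = deque(maxlen=max_len)

    def append(self, k_t, v_t):
        self.K.append(k_t)
        self.V.append(v_t)

    def __len__(self):
        return len(self.K)

cache = SlidingKVCache(max_len=4)
for t in range(7):
    cache.append([t], [t])
print(len(cache), list(cache.K))   # 4 [[3], [4], [5], [6]]
```

Memory is now bounded by max_len instead of the conversation length; the price is that tokens older than the window are simply invisible to attention.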

Batch dimension mismatch: the cache is shaped with a batch dimension — one wall per booth. If your prompt tensor comes in as (1, T) but your cache was allocated for batch 4, the library will either broadcast silently (wrong results) or throw a shape error. Always match them.

Training vs eval mode: pinning tickets is an inference-only optimization. If you accidentally leave the model in model.train() mode and enable use_cache=True, dropout fires on every forward pass and your cached K/V become inconsistent with the Q you're comparing them against. Always model.eval() first.

Retrofit nanoGPT with a KV cache

Take Karpathy's nanoGPT and modify the generate() loop to pin tickets by hand. In each decoder block, maintain a running (K_cache, V_cache) of shape (B, n_heads, T, d_head). At each generation step, compute Q, K, V for only the newest token, concatenate K and V onto the wall along the time dimension, and run attention with the new Q against the full pinned stack.
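Before touching nanoGPT itself, the cache plumbing can be prototyped standalone (a NumPy sketch with the (B, n_heads, T, d_head) layout described above; random tensors stand in for the real Q, K, V projections):

```python
import numpy as np

B, n_heads, d_head = 1, 4, 16
rng = np.random.default_rng(0)

K_cache = np.zeros((B, n_heads, 0, d_head))   # time dimension starts empty
V_cache = np.zeros((B, n_heads, 0, d_head))

def cached_attention_step(q, k, v):
    """q, k, v: (B, n_heads, 1, d_head), for the newest token only."""
    global K_cache, V_cache
    K_cache = np.concatenate([K_cache, k], axis=2)   # pin the new ticket along time
    V_cache = np.concatenate([V_cache, v], axis=2)
    scores = q @ K_cache.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # (B, H, 1, T)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))        # stable softmax
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V_cache                                             # (B, H, 1, d_head)

for t in range(3):
    q = rng.standard_normal((B, n_heads, 1, d_head))
    k = rng.standard_normal((B, n_heads, 1, d_head))
    v = rng.standard_normal((B, n_heads, 1, d_head))
    out = cached_attention_step(q, k, v)

print(K_cache.shape, out.shape)   # (1, 4, 3, 16) (1, 4, 1, 16)
```

No causal mask is needed in the decode step: the single query is the newest token, so everything in the cache is legitimately attendable. That's one of the quiet simplifications a KV cache buys you.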

Generate 100 tokens from a 50-token prompt. Measure wall-clock time with and without your cache on both CPU and a GPU if you have one. Expected speedup: 5–20× depending on model size and hardware.

Bonus: plot per-step latency against step number. Without a cache it should slope up linearly (each step attends over more tokens and re-projects them all — the waiter re-asking the whole order). With a cache it should be nearly flat, with a small linear creep from the attention score itself (one glance at the wall, getting slightly longer).

What to carry forward. Past keys and values don't change between generation steps, so pinning them to the booth wall collapses the total cost from O(T²) to something close to linear in T. The savings come from not recomputing the K and V projections; the attention scores themselves still have to glance at every pinned ticket, which is why decode-time is memory-bandwidth bound. The cache is the reason long-context LLM serving is feasible at all, and its memory cost — tens to hundreds of GB at realistic scales — is the reason modern inference stacks work so hard at paging, quantizing, and compressing the wall.

Next up — Train Your GPT. Everything so far has been an inference-time trick: the weights are frozen, the cache just memoizes what a trained model already knows. But where do the weights come from? In train-your-gpt we stop serving a GPT and start making one: AdamW, learning-rate warmup, cosine decay, gradient clipping — the recipe that turns a pile of initialized tensors into a model worth caching in the first place. Without training, the booth wall is just a wall.
