Self Attention
Queries, keys, values — derived and animated.
Picture a dinner party. Twenty guests around a long table, everyone talking at once. Now picture the same party the way an RNN would have to run it: one guest speaks, the next guest listens to a compressed summary of everything that's been said so far, then whispers their own summary to the guest on their right. By the time the rumor has crossed the room, the first joke has been rewritten, forgotten, and re-forgotten through nineteen hops of a very tired game of telephone. Long-range conversation is possible in principle and miserable in practice.
Self-attention runs the party differently. Every guest turns to every other guest at the same instant and silently asks the same question: how relevant are you to me right now? No passing notes down the chain. No summarizing. Guest 100 hears guest 1 with the same clarity as guest 99 — the shortest path between any two people in the room is always one look across the table. That single architectural decision — “everyone reads everyone, in parallel” — is the engine underneath BERT, GPT, Claude, every modern LLM you've touched.
This lesson builds the room from the ground up. We'll meet the three things each guest carries, derive the Q/K/V projection trick, unpack the scaled dot-product formula a factor at a time, see why the √d_k in the denominator is mathematically non-negotiable, watch a live attention matrix light up over a real sentence, and implement causal attention from scratch in three flavors.
tokens X (N × d) — "The cat sat on the mat"
        │
        │  linear projections (learned)
   ┌────┼─────────┐
   ▼    ▼         ▼
Q = X·Wq   K = X·Wk   V = X·Wv        (N × d_k, N × d_k, N × d_v)
   │           │          │
   └─────┬─────┘          │
         ▼                │
scores = Q · Kᵀ  (N × N)  │
         │                │
      ÷ √d_k — keeps softmax un-saturated
         │                │
(optional mask — causal, pad)
         │                │
softmax along keys — rows sum to 1
         │                │
attn weights A (N × N)    │
         └───────┬────────┘
                 ▼
        output Z = A · V  (N × d_v) — each row = a context-aware token

Three learned matrices. One matmul, one scale, one softmax, one more matmul. That is all of attention. Every paper you'll read in this section is a variation, optimization, or re-packaging of the diagram above — the same dinner party, re-catered.
Now the reveal. Each guest at the table arrives carrying three things — not two, not five, three — and once you see them separately the whole mechanism clicks. Given a sequence of N token embeddings (our word embeddings for the guests) X ∈ ℝ^(N × d), attention projects each guest three ways:
Q = X · W_Q   queries  (N × d_k)
K = X · W_K   keys     (N × d_k)
V = X · W_V   values   (N × d_v)
Same guest, three different pitches. A guest's query is “what am I interested in right now?” — the topic they're trying to follow. Their key is “here's what I have to offer” — the label they're broadcasting across the room, the nametag that tells everyone else what they're about. Their value is “here's what I'll actually say if you call on me” — the contribution they make once they're chosen. The weights W_Q, W_K, W_V are the only learnable parameters in a single-head attention block, and they're what distinguish a great party host from a random one — they've learned which questions to ask, which nametags to print, and which things to say when called on.
Why three roles and not one? Because the question a guest is asking (“who here is talking about music?”) is genuinely different from the nametag they're wearing (“I play jazz”) which is different from the story they'll tell when called on (a fifteen-minute anecdote about a bass player in 1974). Collapse any two of those roles into one and the mechanism stops working — the room loses the ability to look for something different from what it's offering.
Four guests at the table, fully projected. Highlight any query row and the widget draws a dot product with every key — that's one guest turning to the room and sizing up every nametag. The values don't enter the picture yet; they're waiting patiently on the right, ready to be weighted and summed once softmax has decided who gets the floor. Notice the shapes: Q and K must share a width (d_k) so their dot product is defined; V can be a different width (d_v) — in practice they're usually equal, but nothing in the math requires it.
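To make those shapes concrete, here's a minimal sketch (toy sizes and random weights, chosen purely for illustration) with d_v deliberately different from d_k:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_k, d_v = 6, 8, 4, 5   # toy sizes: d_v != d_k on purpose

X = rng.standard_normal((N, d))       # one embedding per guest
W_Q = rng.standard_normal((d, d_k))   # learned in a real model; random here
W_K = rng.standard_normal((d, d_k))
W_V = rng.standard_normal((d, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)      # (6, 4) (6, 4) (6, 5)
```

Q and K come out the same width, so every query–key dot product is defined; V's width only has to match the desired output width.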
Attention(Q, K, V) = softmax( Q·Kᵀ / √d_k ) · V

One line. Four operations. Let's unpack it piece by piece, because every factor in there is doing work.
I'm the question each token asks on every forward pass. I don't carry the answer — I just score how relevant every key is to what I'm looking for. Dot me against every key, normalize with softmax, and the resulting distribution tells the layer above which values to weight and which to ignore. I am replaced at every layer; the network gets a fresh set of questions at every depth.
(1) scores  = Q · Kᵀ                       (N × N)   — raw similarity
(2) scaled  = scores / √d_k                          — variance fix
(3) weights = softmax(scaled, axis=keys)             — row-wise distribution
(4) output  = weights · V                  (N × d_v) — context-mixed values
Step 1 — raw scores. Q · Kᵀ is an N × N matrix. Entry (i, j) is guest i's query dotted with guest j's key — how interested this guest is in what that guest is offering. Every pair in the room scored at once, in a single matmul. This is the payoff of the matrix formulation: the GPU does the whole thing with no Python loops in sight.
Step 2 — the √d_k scaling. This looks arbitrary. It isn't. Suppose q and k are vectors of dimension d_k with unit-variance entries. Their dot product is a sum of d_k zero-mean unit-variance products, so by the central limit theorem the dot product has variance d_k and standard deviation √d_k. For d_k = 64, scores routinely land at ±8. Feed that into softmax and one entry dominates so hard that the gradient on every other entry goes to ~0 — the softmax has saturated. A single guest is shouting so loudly that nobody else in the room can be heard. Dividing by √d_k pulls the variance back to 1, keeping the distribution soft and gradients alive.
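You can check the variance argument empirically with a quick sketch that draws many random unit-variance queries and keys:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((100_000, d_k))   # 100k unit-variance queries
k = rng.standard_normal((100_000, d_k))   # ... and keys

raw = (q * k).sum(axis=-1)                # one dot product per row
scaled = raw / np.sqrt(d_k)

print(raw.std())      # ≈ √64 = 8, so raw scores routinely land at ±8
print(scaled.std())   # ≈ 1, back in softmax's comfortable range
```

The sample standard deviation of the raw dot products comes out close to √d_k = 8, and dividing by √d_k pulls it back to roughly 1, exactly as the central-limit argument predicts.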
Step 3 — softmax. Applied along the keys axis (rows of the score matrix). Row i becomes a probability distribution over keys: exactly how guest i splits their attention budget across the room. They have one unit of listening to spend, and softmax tells them how to spend it. The “wrong axis” mistake — softmaxing along queries instead of keys — is one of the most common bugs in from-scratch attention. Read Step 3 twice and tape it to your monitor.
Step 4 — weighted sum of values. Multiply the N × N attention weights by the N × d_v value matrix. Output row i is a convex combination of every guest's value vector, weighted by how much guest i was listening to each of them. Each output is a context-aware impression of that position — a single guest's final takeaway from the room, assembled out of what everyone else said, pre-weighted by how much they mattered.
That's an actual attention matrix from a small model run on a real sentence. Scrub the query selector and watch one row at a time light up — that row is a single guest showing you who they ended up listening to. The bright diagonal is a hint that guests attend to themselves a lot (sensible — a token's own embedding is usually the best single source of information about it, in the same way that knowing what you came to the party to talk about is a decent prior on what you'll say next). The off-diagonal heat is where the interesting learning lives: a pronoun listening to its antecedent, a verb listening to its subject, a closing bracket listening to its opener. Every pattern you see here is emergent — no one programmed “link pronoun to antecedent,” the network learned to eavesdrop that way because doing so lowered the loss.
I'm the attention distribution — the who listens to whom of this layer. Every row of me is a probability distribution; I sum to 1 across keys by construction. When I'm sharp, I pick one source. When I'm flat, I average everything. The network tunes me by moving the Q/K weights that produced me; I don't have parameters of my own, I'm just the shape their interaction takes.
Three implementations of the same block. Pure Python on a tiny 4-guest party so you can see every index. NumPy with einsum to do the whole thing in one expression. PyTorch's F.scaled_dot_product_attention — the call you actually use in production, which dispatches to FlashAttention kernels under the hood.
import math
import random

def matmul(A, B):
    m, k = len(A), len(A[0])
    k2, n = len(B), len(B[0])
    assert k == k2
    return [[sum(A[i][t] * B[t][j] for t in range(k)) for j in range(n)] for i in range(m)]

def transpose(A):
    return [[A[i][j] for i in range(len(A))] for j in range(len(A[0]))]

def softmax_row(row):
    m = max(row)  # stability — subtract max before exp
    ex = [math.exp(v - m) for v in row]
    s = sum(ex)
    return [v / s for v in ex]

def attention(X, Wq, Wk, Wv):
    Q = matmul(X, Wq)
    K = matmul(X, Wk)
    V = matmul(X, Wv)
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))                # (N × N) raw similarities
    scale = math.sqrt(d_k)
    scaled = [[s / scale for s in row] for row in scores]
    weights = [softmax_row(row) for row in scaled]  # softmax along keys
    return matmul(weights, V), weights

# 4 tokens, embedding dim 4, projection dim 4
random.seed(0)
X  = [[random.gauss(0, 1)   for _ in range(4)] for _ in range(4)]
Wq = [[random.gauss(0, 0.5) for _ in range(4)] for _ in range(4)]
Wk = [[random.gauss(0, 0.5) for _ in range(4)] for _ in range(4)]
Wv = [[random.gauss(0, 0.5) for _ in range(4)] for _ in range(4)]

out, A = attention(X, Wq, Wk, Wv)
for row in A:
    print([round(v, 2) for v in row])
print("output[0] =", [round(v, 2) for v in out[0]])

attention weights (row-stochastic):
[0.37, 0.28, 0.19, 0.16]
[0.22, 0.41, 0.24, 0.13]
[0.18, 0.27, 0.38, 0.17]
[0.15, 0.20, 0.28, 0.37]
output[0] = [0.41, -0.12, 0.33, 0.07]
Now with NumPy. The whole block collapses to five lines. einsum makes the axis contractions explicit — "nd,md->nm" is literally “for each (n, m) pair, sum over the d axis,” which is the definition of Q · Kᵀ. Every guest scoring every other guest, one index at a time, but all at once.
import numpy as np

def scaled_dot_product_attention(X, Wq, Wk, Wv, mask=None):
    Q = X @ Wq  # (N, d_k)
    K = X @ Wk  # (N, d_k)
    V = X @ Wv  # (N, d_v)
    d_k = Q.shape[-1]
    scores = np.einsum("nd,md->nm", Q, K) / np.sqrt(d_k)  # (N, N) scaled
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)          # mask BEFORE softmax
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)        # softmax along keys
    return weights @ V, weights

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4))
Wq, Wk, Wv = [rng.standard_normal((4, 4)) * 0.5 for _ in range(3)]
out, A = scaled_dot_product_attention(X, Wq, Wk, Wv)
print("rows of A sum to 1:", np.allclose(A.sum(axis=-1), 1.0))
print("A shape:", A.shape, " out shape:", out.shape)

Pure Python → NumPy, side by side:
- triple-nested loop for Q · Kᵀ  →  np.einsum("nd,md->nm", Q, K) — the axis labels spell out the contraction, with no indexing off-by-ones
- per-row softmax with manual max subtract  →  exp(scores - scores.max(-1, keepdims=True)) — same numerical-stability trick, broadcast over all rows at once
- math.sqrt(d_k) applied row by row  →  np.sqrt(d_k) — identical scalar, applied to the whole score matrix by broadcast
And now PyTorch. In real code you never handwrite attention — you call F.scaled_dot_product_attention, which dispatches to FlashAttention or memory-efficient kernels depending on your hardware. Or for a full multi-head block, nn.MultiheadAttention. What used to be 20 lines of NumPy is a single call with a causal flag.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, d_k = 2, 4, 4  # batch, sequence, head-dim
X = torch.randn(B, N, d_k)
Wq = torch.randn(d_k, d_k) * 0.5
Wk = torch.randn(d_k, d_k) * 0.5
Wv = torch.randn(d_k, d_k) * 0.5
Q, K, V = X @ Wq, X @ Wk, X @ Wv

# (1) Handwritten, for comparison
scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
A = F.softmax(scores, dim=-1)  # along keys
manual_out = A @ V

# (2) Built-in — dispatches to FlashAttention when available
sdpa_out = F.scaled_dot_product_attention(Q, K, V, is_causal=False)
print("manual == sdpa:", torch.allclose(manual_out, sdpa_out, atol=1e-6))
print("out shape:", sdpa_out.shape)

manual == sdpa: True
out shape: torch.Size([2, 4, 4])

NumPy → PyTorch, side by side:
- Q = X @ Wq; K = X @ Wk; V = X @ Wv  →  same @ operator, but the tensors are tracked for autograd and run on GPU
- np.einsum + manual softmax + mask  →  F.scaled_dot_product_attention(Q, K, V, is_causal=True) — one call that dispatches to FlashAttention, memory-efficient kernels, or the math implementation
- custom block per head, then concatenate  →  nn.MultiheadAttention(embed_dim, num_heads) — production-ready multi-head with all the bookkeeping done for you
One more piece — causal masking. A language model generating text can't be allowed to peek at future tokens while predicting the next one. At our dinner table: guest t is only allowed to listen to guests who've already spoken. Letting them eavesdrop on guests who haven't arrived yet would be training-time cheating of the most embarrassing kind. The fix is a mask applied to the score matrix before softmax: set every entry above the diagonal to -∞, so after softmax those weights are exactly 0 and never leak into the output.
keys →
        k0    k1    k2    k3
q0  [  s00   -∞    -∞    -∞  ]
q1  [  s10   s11   -∞    -∞  ]   ← token t can only attend to tokens ≤ t
q2  [  s20   s21   s22   -∞  ]
q3  [  s30   s31   s32   s33 ]

after softmax along keys:
        k0    k1    k2    k3
q0  [ 1.00    0     0     0  ]
q1  [ 0.43  0.57    0     0  ]
q2  [ 0.29  0.35  0.36    0  ]
q3  [ 0.21  0.26  0.27  0.26 ]

This is what is_causal=True does inside PyTorch's SDPA — it fills the upper triangle with -∞ before the softmax runs. In a decoder-only model like GPT, every attention layer is causal. In an encoder like BERT, none of them are — everyone at the table speaks at once and hears everyone else. In an encoder-decoder like T5, the encoder is bidirectional and the decoder is causal with an extra cross-attention step.
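A minimal NumPy sketch of the mask-then-softmax trick, using arbitrary stand-in scores, shows the upper triangle come out exactly zero while every row still sums to 1:

```python
import numpy as np

scores = np.arange(16, dtype=float).reshape(4, 4)   # stand-in raw scores
mask = np.triu(np.ones((4, 4), dtype=bool), k=1)    # True above the diagonal
scores[mask] = -np.inf                              # mask BEFORE softmax

w = np.exp(scores - scores.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)                  # softmax along keys

print(np.allclose(w.sum(axis=-1), 1.0))  # True — every row is still a distribution
print((w[mask] == 0).all())              # True — masked weights are exactly 0
```

Because exp(-∞) is exactly 0, the masked positions contribute nothing to the normalization, and the surviving weights in each row renormalize among themselves.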
Softmax on the wrong axis: the attention softmax goes over keys (the last dim of a Q Kᵀ matrix). Softmax over queries instead, and rows of A no longer sum to 1 — the output is silently nonsense and the loss still decreases a little, so you won't notice until validation tanks. Always dim=-1, always triple-checked.
Forgetting √d_k: the network will still train, slowly, and with a softmax so peaked that gradients flow through one key at a time. Everything converges worse. This bug is invisible on toy d_k = 4 sequences and devastating at d_k = 64+.
Masking after softmax: if you zero out masked positions after the softmax, the remaining weights no longer sum to 1 and you've re-introduced a tiny leak from the masked side through the normalization. The mask must be applied by setting scores to -∞ before softmax. Every time. No exceptions.
Mask shape mismatch: SDPA expects the mask to broadcast against (B, H, N, N). A (N, N) bool mask works; a (B, N) padding mask does not (that's key_padding_mask on nn.MultiheadAttention). Read the docstring every time — the conventions shift between APIs.
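The wrong-axis pitfall is easy to demonstrate in a few lines, using random scores as a stand-in:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 4))

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

right = softmax(scores, axis=-1)  # over keys: rows sum to 1
wrong = softmax(scores, axis=0)   # over queries: columns sum to 1 instead

print(np.allclose(right.sum(axis=-1), 1.0))  # True
print(np.allclose(wrong.sum(axis=-1), 1.0))  # False — silently broken
```

Nothing crashes in the wrong case, which is exactly why the bug survives until validation: the output is still finite, just no longer a per-query distribution over keys.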
Implement a single-head causal self-attention block in PyTorch using only @, F.softmax, and torch.triu — no calls to F.scaled_dot_product_attention or nn.MultiheadAttention. Your signature: attention(x, Wq, Wk, Wv) -> (out, weights) for input x of shape (B, N, d).
Build a causal mask with torch.triu(torch.ones(N, N), diagonal=1).bool(). Apply it with scores.masked_fill_(mask, float('-inf')) before the softmax. Divide by √d_k.
Now verify. Run the same (x, Wq, Wk, Wv) through F.scaled_dot_product_attention(Q, K, V, is_causal=True) and assert that your output matches to 1e-6. If it doesn't, the culprit is almost always (a) wrong softmax axis, (b) mask applied after softmax, or (c) missing the √d_k.
Bonus: print the attention weights for a small N = 6 sequence. Confirm the upper triangle is exactly 0 and every row sums to exactly 1.
What to carry forward. Attention is a dinner party rendered in linear algebra: every guest turns to every other guest with a query, a key, and a value — three learned projections, one scaled dot product, one softmax, one weighted sum of values. The √d_k keeps the softmax from saturating at reasonable d_k. Softmax goes along keys. Masks go in before softmax. The whole room is permutation-equivariant, which is why we'll staple positional encodings to the input. And the O(N²) memory cost — every guest scoring every other guest — is the single biggest bottleneck in modern transformer scaling, which is why your next six papers will all be about sidestepping it.
Next up — Multi-Headed Self Attention. One conversation at one table is fine — but the room would be richer if every guest could have several conversations at once, each specializing in a different angle. A pronoun probably wants to find its antecedent and its syntactic role and its verb, simultaneously. The answer is multi-headed self attention: run several attention heads in parallel with different W_Q / W_K / W_V, concatenate their outputs, and project down. We'll derive why n_heads matters, what each head tends to specialize in, and how a 12-head attention block is still a single matmul if you squint at it right.