Self Attention

Queries, keys, values — derived and animated.

Hard
~15 min read
lesson 1 of 3

Picture a dinner party. Twenty guests around a long table, everyone talking at once. Now picture the same party the way an RNN would have to run it: one guest speaks, the next guest listens to a compressed summary of everything that's been said so far, then whispers their own summary to the guest on their right. By the time the rumor has crossed the room, the first joke has been rewritten, forgotten, and re-forgotten through nineteen hops of a very tired game of telephone. Long-range conversation is possible in principle and miserable in practice.

Self-attention runs the party differently. Every guest turns to every other guest at the same instant and silently asks the same question: how relevant are you to me right now? No passing notes down the chain. No summarizing. Guest 20 hears guest 1 with the same clarity as guest 19 — the shortest path between any two people in the room is always one look across the table. That single architectural decision — “everyone reads everyone, in parallel” — is the engine underneath BERT, GPT, Claude, every modern LLM you've touched.

This lesson builds the room from the ground up. We'll meet the three things each guest carries, derive the Q/K/V projection trick, unpack the scaled dot-product formula a factor at a time, see why the √d_k in the denominator is mathematically non-negotiable, watch a live attention matrix light up over a real sentence, and implement causal attention from scratch in three flavors.

        tokens X                  (N × d)   — "The cat sat on the mat"
            │
            │  linear projections (learned)
            ├──────────┬──────────┐
            ▼          ▼          ▼
         Q = X·Wq   K = X·Wk   V = X·Wv      (N × d_k, N × d_k, N × d_v)
            │          │          │
            └────┐     │     ┌────┘
                 ▼     ▼     │
              scores = Q · Kᵀ                  (N × N)
                 │
                 ÷  √d_k                        keeps softmax un-saturated
                 │
                 (optional mask — causal, pad)
                 │
              softmax along  KEYS               rows sum to 1
                 │
              attn weights  A                   (N × N)
                 │
                 ·  V                           (N × d_v)
                 │
                 ▼
             output  Z                          each row = context-aware token
self-attention block — the shape you will see a thousand times

Three learned matrices. One matmul, one scale, one softmax, one more matmul. That is all of attention. Every paper you'll read in this section is a variation, optimization, or re-packaging of the diagram above — the same dinner party, re-catered.

Now the reveal. Each guest at the table arrives carrying three things — not two, not five, three — and once you see them separately the whole mechanism clicks. Given a sequence of N token embeddings (our word embeddings for the guests) X ∈ ℝ^(N × d), attention projects each guest three ways:

three learned projections per token
Q  =  X · W_Q          queries   (N × d_k)
K  =  X · W_K          keys      (N × d_k)
V  =  X · W_V          values    (N × d_v)

Same guest, three different pitches. A guest's query is “what am I interested in right now?” — the topic they're trying to follow. Their key is “here's what I have to offer” — the label they're broadcasting across the room, the nametag that tells everyone else what they're about. Their value is “here's what I'll actually say if you call on me” — the contribution they make once they're chosen. The weights W_Q, W_K, W_V are the only learnable parameters in a single-head attention block, and they're what distinguish a great party host from a random one — they've learned which questions to ask, which nametags to print, and which things to say when called on.

Why three roles and not one? Because the question a guest is asking (“who here is talking about music?”) is genuinely different from the nametag they're wearing (“I play jazz”) which is different from the story they'll tell when called on (a fifteen-minute anecdote about a bass player in 1974). Collapse any two of those roles into one and the mechanism stops working — the room loses the ability to look for something different from what it's offering.
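A minimal sketch of the three projections, with random placeholder weights standing in for learned ones; the shapes are the point here — d_v is deliberately different from d_k to show that nothing requires them to be equal:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_k, d_v = 5, 8, 4, 6          # 5 guests, embedding width 8

X   = rng.standard_normal((N, d))    # one row per guest
W_Q = rng.standard_normal((d, d_k))  # what am I looking for?
W_K = rng.standard_normal((d, d_k))  # what's on my nametag?
W_V = rng.standard_normal((d, d_v))  # what will I say if called on?

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)     # (5, 4) (5, 4) (5, 6)
```

Q and K must share a width so their dot products are scalars; V only has to share the sequence length.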

self-attention computed by hand — click through each stage
d_model = 4 · head_dim = 3 · 3 tokens

input X — 3 tokens × d_model = 4, hand-picked embeddings:

X (3×4)      d0     d1     d2     d3
the         1.00   0.00   0.50   0.20
cat         0.20   1.00   0.10   0.30
sat         0.30   0.40   1.00   0.10

W_Q (4×3)    q0     q1     q2
d0          0.50   0.10  -0.20
d1          0.30   0.70   0.10
d2         -0.10   0.20   0.60
d3          0.40  -0.30   0.20

Q  =  X · W_Q      (3×3)   — stage 0 of 6

Three guests at the table, fully projected. Highlight any query row and the widget draws a dot product with every key — that's one guest turning to the room and sizing up every nametag. The values don't enter the picture yet; they're waiting patiently on the right, ready to be weighted and summed once softmax has decided who gets the floor. Notice the shapes: Q and K are the same width (d_k), so their dot product is a scalar; V can be a different width (d_v) — in practice they're usually equal, but nothing in the math requires it.
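If you want to check the first stage off-screen, the X and W_Q values above multiply out as follows (a NumPy sketch of the same arithmetic):

```python
import numpy as np

# the hand-picked embeddings and W_Q from the widget above
X = np.array([[1.00, 0.00, 0.50, 0.20],    # "the"
              [0.20, 1.00, 0.10, 0.30],    # "cat"
              [0.30, 0.40, 1.00, 0.10]])   # "sat"
W_Q = np.array([[ 0.50,  0.10, -0.20],
                [ 0.30,  0.70,  0.10],
                [-0.10,  0.20,  0.60],
                [ 0.40, -0.30,  0.20]])

Q = X @ W_Q                                # (3 × 3) query matrix
print(Q.round(2))
# [[0.53 0.14 0.14]
#  [0.51 0.65 0.18]
#  [0.21 0.48 0.6 ]]
```

K and V fall out of the same matmul pattern with W_K and W_V.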

scaled dot-product attention — the entire formula
                                ┌  Q · Kᵀ  ┐
Attention(Q, K, V)  =  softmax  │ ──────── │  · V
                                └   √d_k   ┘

One line. Four operations. Let's unpack it piece by piece, because every factor in there is doing work.

Query (personified)
I'm the question each token asks on every forward pass. I don't carry the answer — I just score how relevant every key is to what I'm looking for. Dot me against every key, normalize with softmax, and the resulting distribution tells the layer above which values to weight and which to ignore. I am replaced at every layer; the network gets a fresh set of questions at every depth.
step-by-step, one factor at a time
(1)  scores   =  Q · Kᵀ                    (N × N)   — raw similarity
(2)  scaled   =  scores / √d_k                       — variance fix
(3)  weights  =  softmax(scaled, axis=keys)          — row-wise distribution
(4)  output   =  weights · V               (N × d_v) — context-mixed values

Step 1 — raw scores. Q · Kᵀ is an N × N matrix. Entry (i, j) is guest i's query dotted with guest j's key — how interested this guest is in what that guest is offering. Every pair in the room scored at once, in a single matmul. This is the payoff of the matrix formulation: the GPU does the whole thing with no Python loops in sight.

Step 2 — the √d_k scaling. This looks arbitrary. It isn't. Suppose q and k are vectors of dimension d_k with zero-mean, unit-variance entries. Their dot product is a sum of d_k zero-mean unit-variance products, so the dot product has variance d_k (variances of independent terms add) and standard deviation √d_k. For d_k = 64, scores routinely land at ±8. Feed that into softmax and one entry dominates so hard that the gradient on every other entry goes to ~0 — the softmax has saturated. A single guest is shouting so loudly that nobody else in the room can be heard. Dividing by √d_k pulls the variance back to 1, keeping the distribution soft and gradients alive.
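The variance claim is easy to simulate — draw unit-variance vectors at d_k = 64 and measure the spread of raw versus scaled dot products:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((10_000, d_k))   # 10k unit-variance queries
k = rng.standard_normal((10_000, d_k))   # 10k unit-variance keys

raw    = (q * k).sum(axis=1)             # 10k raw dot products
scaled = raw / np.sqrt(d_k)

print(raw.std())                         # ≈ 8 — that's √64; scores at ±8
print(scaled.std())                      # ≈ 1 — back to softmax-friendly scale
```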

Step 3 — softmax. Applied along the keys axis (rows of the score matrix). Row i becomes a probability distribution over keys: exactly how guest i splits their attention budget across the room. They have one unit of listening to spend, and softmax tells them how to spend it. The “wrong axis” mistake — softmaxing along queries instead of keys — is one of the most common bugs in from-scratch attention. Read Step 3 twice and tape it to your monitor.
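A quick sketch of the axis bug: only the keys-axis softmax leaves every row summing to 1.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 4))       # an N × N score matrix

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

right = softmax(scores, axis=-1)           # along keys — correct
wrong = softmax(scores, axis=0)            # along queries — the classic bug

print(np.allclose(right.sum(axis=-1), 1))  # True  — each query spends 1 unit
print(np.allclose(wrong.sum(axis=-1), 1))  # False — rows no longer normalized
```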

Step 4 — weighted sum of values. Multiply the N × N attention weights by the N × d_v value matrix. Output row i is a convex combination of every guest's value vector, weighted by how much guest i was listening to each of them. Each output is a context-aware impression of that position — a single guest's final takeaway from the room, assembled out of what everyone else said, pre-weighted by how much they mattered.
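Step 4 is checkable index by index — with illustrative random weights, each output row really is a convex combination of the value rows:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_v = 4, 3
A = rng.random((N, N))
A /= A.sum(axis=-1, keepdims=True)         # row-stochastic attention weights
V = rng.standard_normal((N, d_v))

out = A @ V                                # (N, d_v)

# row i really is  sum_j A[i, j] · V[j]
i = 2
manual = sum(A[i, j] * V[j] for j in range(N))
print(np.allclose(out[i], manual))         # True
```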

attention heatmap — hover rows to inspect a query
"the cat sat on the mat yesterday afternoon"
[interactive figure: an 8 × 8 causal attention matrix over the sentence above — lower-triangular, so tokens can only see the past (and themselves); hovering a row shows that query's attention distribution over the keys, and every row sums to 1]
That's an actual attention matrix from a small model run on a real sentence. Scrub the query selector and watch one row at a time light up — that row is a single guest showing you who they ended up listening to. The bright diagonal is a hint that guests attend to themselves a lot (sensible — a token's own embedding is usually the best single source of information about it, in the same way that knowing what you came to the party to talk about is a decent prior on what you'll say next). The off-diagonal heat is where the interesting learning lives: a pronoun listening to its antecedent, a verb listening to its subject, a closing bracket listening to its opener. Every pattern you see here is emergent — no one programmed “link pronoun to antecedent,” the network learned to eavesdrop that way because doing so lowered the loss.

Softmax(QK^T / √d_k) (personified)
I'm the attention distribution — the who listens to whom of this layer. Every row of me is a probability distribution; I sum to 1 across keys by construction. When I'm sharp, I pick one source. When I'm flat, I average everything. The network tunes me by moving the Q/K weights that produced me; I don't have parameters of my own, I'm just the shape their interaction takes.

Three implementations of the same block. Pure Python on a tiny 4-guest party so you can see every index. NumPy with einsum to do the whole thing in one expression. PyTorch's F.scaled_dot_product_attention — the call you actually use in production, which dispatches to FlashAttention kernels under the hood.

layer 1 — pure python · attention_scratch.py
python
import math
import random

def matmul(A, B):
    m, k = len(A), len(A[0])
    k2, n = len(B), len(B[0])
    assert k == k2
    return [[sum(A[i][t] * B[t][j] for t in range(k)) for j in range(n)] for i in range(m)]

def transpose(A):
    return [[A[i][j] for i in range(len(A))] for j in range(len(A[0]))]

def softmax_row(row):
    m = max(row)                              # stability — subtract max before exp
    ex = [math.exp(v - m) for v in row]
    s = sum(ex)
    return [v / s for v in ex]

def attention(X, Wq, Wk, Wv):
    Q = matmul(X, Wq)
    K = matmul(X, Wk)
    V = matmul(X, Wv)
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))          # (N × N) raw similarities
    scale = math.sqrt(d_k)
    scaled = [[s / scale for s in row] for row in scores]
    weights = [softmax_row(row) for row in scaled]   # softmax along keys
    return matmul(weights, V), weights

# 4 tokens, embedding dim 4, projection dim 4
random.seed(0)
X  = [[random.gauss(0, 1) for _ in range(4)] for _ in range(4)]
Wq = [[random.gauss(0, 0.5) for _ in range(4)] for _ in range(4)]
Wk = [[random.gauss(0, 0.5) for _ in range(4)] for _ in range(4)]
Wv = [[random.gauss(0, 0.5) for _ in range(4)] for _ in range(4)]

out, A = attention(X, Wq, Wk, Wv)
print("attention weights (row-stochastic):")
for row in A:
    print("", [round(v, 2) for v in row])
print("output[0] =", [round(v, 2) for v in out[0]])
stdout
attention weights (row-stochastic):
 [0.37, 0.28, 0.19, 0.16]
 [0.22, 0.41, 0.24, 0.13]
 [0.18, 0.27, 0.38, 0.17]
 [0.15, 0.20, 0.28, 0.37]
output[0] = [0.41, -0.12, 0.33, 0.07]

Now with NumPy. The whole block collapses to five lines. einsum makes the axis contractions explicit — "nd,md->nm" is literally “for each (n, m) pair, sum over the d axis,” which is the definition of Q · Kᵀ. Every guest scoring every other guest, one index at a time, but all at once.

layer 2 — numpy · attention_numpy.py
python
import numpy as np

def scaled_dot_product_attention(X, Wq, Wk, Wv, mask=None):
    Q = X @ Wq                                              # (N, d_k)
    K = X @ Wk                                              # (N, d_k)
    V = X @ Wv                                              # (N, d_v)
    d_k = Q.shape[-1]

    scores = np.einsum("nd,md->nm", Q, K) / np.sqrt(d_k)    # (N, N) scaled
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)            # mask BEFORE softmax

    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)          # softmax along keys

    return weights @ V, weights

rng = np.random.default_rng(0)
X  = rng.standard_normal((4, 4))
Wq, Wk, Wv = [rng.standard_normal((4, 4)) * 0.5 for _ in range(3)]

out, A = scaled_dot_product_attention(X, Wq, Wk, Wv)
print("rows of A sum to 1:", np.allclose(A.sum(axis=-1), 1.0))
print("A shape:", A.shape, " out shape:", out.shape)
pure python → numpy

triple-nested loop for Q · Kᵀ  ←→  np.einsum("nd,md->nm", Q, K)
axis labels spell out the contraction — no indexing off-by-ones

per-row softmax with manual max subtract  ←→  exp(scores - scores.max(-1, keepdims=True))
same numerical-stability trick, broadcast over all rows at once

math.sqrt(d_k) divided out row by row  ←→  np.sqrt(d_k)
identical scalar, applied to the whole score matrix by broadcast

And now PyTorch. In real code you never handwrite attention — you call F.scaled_dot_product_attention, which dispatches to FlashAttention or memory-efficient kernels depending on your hardware. Or for a full multi-head block, nn.MultiheadAttention. What used to be 20 lines of NumPy is a single call with a causal flag.

layer 3 — pytorch · attention_pytorch.py
python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, d_k = 2, 4, 4                                # batch, sequence, head-dim

X = torch.randn(B, N, d_k)
Wq = torch.randn(d_k, d_k) * 0.5
Wk = torch.randn(d_k, d_k) * 0.5
Wv = torch.randn(d_k, d_k) * 0.5

Q, K, V = X @ Wq, X @ Wk, X @ Wv

# (1) Handwritten, for comparison
scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
A = F.softmax(scores, dim=-1)                      # along keys
manual_out = A @ V

# (2) Built-in — dispatches to FlashAttention when available
sdpa_out = F.scaled_dot_product_attention(Q, K, V, is_causal=False)

print("manual == sdpa:", torch.allclose(manual_out, sdpa_out, atol=1e-6))
print("out shape:", sdpa_out.shape)
stdout
manual == sdpa: True
out shape: torch.Size([2, 4, 4])
numpy → pytorch

Q = X @ Wq; K = X @ Wk; V = X @ Wv  ←→  same, tracked for autograd, runs on GPU
@ is identical; tensors carry grad

np.einsum + manual softmax + mask  ←→  F.scaled_dot_product_attention(Q, K, V, is_causal=True)
one call — dispatches to FlashAttention, memory-efficient kernels, or a math fallback

custom block per head, concatenate  ←→  nn.MultiheadAttention(embed_dim, num_heads)
production-ready multi-head with all the bookkeeping done for you

One more piece — causal masking. A language model generating text can't be allowed to peek at future tokens while predicting the next one. At our dinner table: guest t is only allowed to listen to guests who've already spoken. Letting them eavesdrop on guests who haven't arrived yet would be training-time cheating of the most embarrassing kind. The fix is a mask applied to the score matrix before softmax: set every entry above the diagonal to -∞, so after softmax those weights are exactly 0 and never leak into the output.

causal mask — prevent queries from seeing future keys
         keys →
         k0   k1   k2   k3
q0  [   s00  -∞   -∞   -∞  ]
q1  [   s10  s11  -∞   -∞  ]     ← token t can only attend to tokens ≤ t
q2  [   s20  s21  s22  -∞  ]
q3  [   s30  s31  s32  s33 ]

after softmax along keys:
         k0   k1   k2   k3
q0  [  1.00  0    0    0   ]
q1  [  0.43  0.57 0    0   ]
q2  [  0.29  0.35 0.36 0   ]
q3  [  0.21  0.26 0.27 0.26]
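A NumPy sketch of the same mask (illustrative random scores, not the ones above): build the upper triangle with triu, set it to -∞, softmax, and the future weights come out exactly 0.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4
scores = rng.standard_normal((N, N))

mask = np.triu(np.ones((N, N), dtype=bool), k=1)   # True strictly above the diagonal
scores = np.where(mask, -np.inf, scores)           # -∞ BEFORE softmax

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(np.allclose(np.triu(weights, k=1), 0))       # True — no future leakage
print(np.allclose(weights.sum(axis=-1), 1))        # True — rows still normalized
```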

This is what is_causal=True does inside PyTorch's SDPA — it fills the upper triangle with -∞ before the softmax runs. In a decoder-only model like GPT, every attention layer is causal. In an encoder like BERT, none of them are — everyone at the table speaks at once and hears everyone else. In an encoder-decoder like T5, the encoder is bidirectional and the decoder is causal with an extra cross-attention step.

Gotchas

Softmax on the wrong axis: the attention softmax goes over keys (the last dim of the Q · Kᵀ score matrix). Softmax over queries instead, and rows of A no longer sum to 1 — the output is silently nonsense, and the loss still decreases a little, so you won't notice until validation tanks. Always dim=-1, always triple-checked.

Forgetting √d_k: the network will still train, slowly, and with a softmax so peaked that gradients flow through one key at a time. Everything converges worse. This bug is invisible on toy d_k = 4 sequences and devastating at d_k = 64+.

Masking after softmax: if you zero out masked positions after the softmax, the remaining weights no longer sum to 1 and you've re-introduced a tiny leak from the masked side through the normalization. The mask must be applied by setting scores to -∞ before softmax. Every time. No exceptions.
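The leak is easy to demonstrate — compare masking after softmax with masking before, on illustrative random scores:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4
scores = rng.standard_normal((N, N))
mask = np.triu(np.ones((N, N), dtype=bool), k=1)   # future positions

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

wrong = softmax(scores)
wrong[mask] = 0.0                                  # zeroed AFTER softmax — broken
right = softmax(np.where(mask, -np.inf, scores))   # -∞ BEFORE softmax — correct

print(np.allclose(wrong.sum(axis=-1), 1))          # False — mass leaked away
print(np.allclose(right.sum(axis=-1), 1))          # True
```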

Mask shape mismatch: SDPA expects the mask to broadcast against (B, H, N, N). A (N, N) bool mask works; a (B, N) padding mask does not (that's key_padding_mask on nn.MultiheadAttention). Read the docstring every time — the conventions shift between APIs.

Build causal self-attention from scratch, verify against PyTorch

Implement a single-head causal self-attention block in PyTorch using only @, F.softmax, and torch.triu — no calls to F.scaled_dot_product_attention or nn.MultiheadAttention. Your signature: attention(x, Wq, Wk, Wv) -> (out, weights) for input x of shape (B, N, d).

Build a causal mask with torch.triu(torch.ones(N, N), diagonal=1).bool(). Apply it with scores.masked_fill_(mask, float('-inf')) before the softmax. Divide by √d_k.

Now verify. Run the same (x, Wq, Wk, Wv) through F.scaled_dot_product_attention(Q, K, V, is_causal=True) and assert that your output matches to 1e-6. If it doesn't, the culprit is almost always (a) wrong softmax axis, (b) mask applied after softmax, or (c) missing the √d_k.

Bonus: print the attention weights for a small N = 6 sequence. Confirm the upper triangle is exactly 0 and every row sums to exactly 1.

What to carry forward. Attention is a dinner party rendered in linear algebra: every guest turns to every other guest with a query, a key, and a value — three learned projections, one scaled dot product, one softmax, one weighted sum of values. The √d_k keeps the softmax from saturating at reasonable d_k. Softmax goes along keys. Masks go in before softmax. The whole room is permutation-equivariant, which is why we'll staple positional encodings to the input. And the O(N²) memory cost — every guest scoring every other guest — is the single biggest bottleneck in modern transformer scaling, which is why your next six papers will all be about sidestepping it.
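The permutation-equivariance claim is checkable in a few lines for the unmasked block — shuffle the guests and the output rows shuffle with them, otherwise unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 4
X = rng.standard_normal((N, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) * 0.5 for _ in range(3))

def attention(X):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

perm = rng.permutation(N)
out, out_perm = attention(X), attention(X[perm])

print(np.allclose(out[perm], out_perm))    # True — same rows, just reordered
```

A causal block breaks this symmetry on purpose — order matters once the mask is in — which is the other reason positional information has to come from somewhere.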

Next up — Multi-Headed Self Attention. One conversation at one table is fine — but the room would be richer if every guest could have several conversations at once, each specializing in a different angle. A pronoun probably wants to find its antecedent and its syntactic role and its verb, simultaneously. The answer is multi-headed self attention: run several attention heads in parallel with different W_Q / W_K / W_V, concatenate their outputs, and project down. We'll derive why n_heads matters, what each head tends to specialize in, and how a 12-head attention block is still a single matmul if you squint at it right.

quiz

Why does scaled dot-product attention divide by √d_k before the softmax?
