Multi-Head Self-Attention
Parallel attention heads specializing on different patterns.
Last lesson we sat every token at one long table. That's self-attention — one big dinner party, every guest talking to every other guest, one conversation for the whole room. It's a beautiful mechanism. It's also, on any real sentence, one table too few.
Think about what's actually happening at a single table. Every guest asks one question. One. The query vector for each token runs through a single W_Q, every other token offers itself through a single W_K, and the score q · k captures exactly one notion of “relevant.” Whatever flavor of relevance that table happens to land on, every token is stuck with it.
Take the sentence “the black cat sat on the mat.” The token cat wants to talk to sat because that's its verb. It also wants to talk to black because that's its adjective. It also wants to talk to mat because that's where the action lands. Three different relationships — syntactic, modifier, argument — and one table can only host one conversation at a time. Force a single pair of Q/K weights to capture all three and you get a compromise that captures none of them cleanly.
So: throw a bigger party. Set up several tables in parallel, each with its own rules for who finds whom interesting. Table 1 cares about subject-verb links. Table 2 cares about coreference — who does “it” refer to? Table 3 cares about positional patterns — the token two places to my left. Every guest sits briefly at every table, whispers to the others present, and collects what they heard. At the end, each guest concatenates their eight whispers into one coherent update and walks back to the main floor. That is multi-head attention, and this lesson is the mechanics of how many tables, what each table listens for, and how the whispers get sewn back together.
MultiHead(X) = Concat(head₁, head₂, …, head_H) · W_O
where  head_i = Attention(X·W_Q^i, X·W_K^i, X·W_V^i)
and    Attention(Q, K, V) = softmax( Q·Kᵀ / √d_head ) · V
Here's the part most tutorials skip. You do not give each table its own full-width d_model subspace to play in — that would multiply the parameter count by H and nobody wants that. You split the existing d_model into H slices of width d_head = d_model / H, and each head gets one slice as its private subspace. A head sees a narrower view; there are just H of them working in parallel, so the total dimensionality is preserved. Same parameter budget as a single-head model of the same width. The whole cost of having multiple tables is bookkeeping.
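The parameter arithmetic is worth checking once by hand. A quick sanity check, ignoring biases and leaving out W_O since both variants carry the same one:

```python
d_model, H = 64, 8
d_head = d_model // H  # 8

# One full-width head: W_Q, W_K, W_V each of shape (d_model, d_model).
single_head_params = 3 * d_model * d_model

# H narrow heads: per head, W_Q^i, W_K^i, W_V^i of shape (d_model, d_head).
multi_head_params = 3 * H * (d_model * d_head)

print(single_head_params, multi_head_params)  # 12288 12288, identical
```

Splitting changes where the parameters live, not how many there are.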
The widget is the head-split reveal in pictures. Start with d_model = 64 and spin it up into H = 8 heads of d_head = 8. Input of shape (B, N, 64) runs through the Q/K/V projections — still (B, N, 64) coming out — then reshapes and transposes into (B, H, N, d_head) = (B, 8, N, 8). That reshape is the split: each head now owns its own 8-dimensional subspace, carved out of the original 64. From there, eight tables run attention in parallel — eight separate N × N score matrices, each computed only over its head's slice of the representation.
That (B, H, N, d_head) layout is the whole game. The H axis is a batch dimension from attention's point of view: the op has no idea it's doing eight things at once, it just sees B·H independent attention problems and dispatches them to the GPU. Multi-head is embarrassingly parallel by design — which is why you can crank H up without the clock time tracking it linearly.
I am one of eight at my own table. I see 8 dimensions out of 64, and I only have to be good at one kind of relationship — maybe subject-verb, maybe coreference, maybe “the token two positions to my left.” I don't negotiate with my siblings. We all specialize on the same input in our own subspaces and hand our whispers to the integrator. I am a specialist, on purpose.
For i = 1 … H:
Qᵢ = X · W_Qⁱ shape (B, N, d_head)
Kᵢ = X · W_Kⁱ shape (B, N, d_head)
Vᵢ = X · W_Vⁱ shape (B, N, d_head)
Aᵢ = softmax( Qᵢ · Kᵢᵀ / √d_head ) shape (B, N, N)
Hᵢ = Aᵢ · Vᵢ shape (B, N, d_head)
Concat along the last axis:
H = [H₁ | H₂ | … | H_H] shape (B, N, d_model)
Final projection:
out = H · W_O shape (B, N, d_model)

The √d_head inside the softmax is the same scaled-dot-product trick from last lesson — but notice the number under the square root. It's d_head, not d_model. Each head lives in its own small subspace, so its dot products concentrate around a variance of d_head, not d_model. Scaling by √d_model here would over-shrink them and collapse the softmax into a near-uniform mush; √d_head is the right whistle.
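You can watch the variance argument play out numerically; a Monte Carlo sketch with unit-variance q and k entries:

```python
import numpy as np

rng = np.random.default_rng(0)
d_head, d_model, n = 8, 64, 100_000

q = rng.standard_normal((n, d_head))
k = rng.standard_normal((n, d_head))
dots = (q * k).sum(-1)  # n independent d_head-dimensional dot products

print(round(dots.var(), 2))                       # ~ d_head = 8
print(round((dots / np.sqrt(d_head)).var(), 2))   # ~ 1, the scale softmax wants
print(round((dots / np.sqrt(d_model)).var(), 2))  # ~ 0.125, over-shrunk
```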
Now the payoff. Four tables at the same party, same sentence on the menu, and each table's attention map lights up a completely different conversation. One head locks onto syntactic dependency — nouns leaning toward their verbs. Another tracks position — every token whispering to the token immediately before it. A third runs coreference — pronouns reaching back to their antecedents. The fourth is doing something harder to name but clearly structured. Each head specializes in its own relational primitive.
Nobody told head 3 to handle coreference. Nobody hand-programmed table 1 to follow subject-verb. The only training signal is “predict the next token,” and the heads settle into a division of labor because that's the cheapest way to drive the loss down. Eight small experts, each focused on one slice of relational structure, turn out to be strictly easier for gradient descent to sculpt than one giant generalist. Specialization is what gradient descent wants to do, given the chance.
This is also why a single head of width d_model, despite matching multi-head's parameter budget, is optimizationally worse. Same parameter count on both sides of the ledger — but only the multi-head version gives gradient descent clean separate subspaces to carve specialists into. Cram everything into one table and the shared W_Q, W_K, W_V has to encode every pattern at once, which it can, just badly.
I am the integrator at the end of the night. Eight specialists walk back from their tables and hand me their whispers — syntax from table 1, coreference from table 2, position from table 3, five more. I concatenate them end-to-end into a single long vector and project the whole thing back into model space with W_O. My job is to fuse eight views of each token into one coherent representation the next layer can consume. Without me you have eight experts shouting past each other instead of a conversation.

Three implementations, same mechanism, three layers of abstraction. First NumPy with every reshape written out so you can see the split happen. Then PyTorch's nn.MultiheadAttention one-liner with a causal mask — the prototype-speed option. Then a custom module — the kind you actually ship when you need fine control over shapes and projections.
import numpy as np
B, N, d_model, H = 2, 5, 64, 8
d_head = d_model // H # 64 / 8 = 8
rng = np.random.default_rng(0)
X = rng.standard_normal((B, N, d_model))
W_Q = rng.standard_normal((d_model, d_model)) * 0.1
W_K = rng.standard_normal((d_model, d_model)) * 0.1
W_V = rng.standard_normal((d_model, d_model)) * 0.1
W_O = rng.standard_normal((d_model, d_model)) * 0.1
# 1. Linear projections — still (B, N, d_model)
Q = X @ W_Q
K = X @ W_K
V = X @ W_V
# 2. Reshape + transpose into (B, H, N, d_head) — the critical step
def split_heads(x):
    x = x.reshape(B, N, H, d_head)  # (B, N, H, d_head)
    return x.transpose(0, 2, 1, 3)  # (B, H, N, d_head)
Q_h, K_h, V_h = split_heads(Q), split_heads(K), split_heads(V)
# 3. Scaled dot-product attention, per head, in parallel via broadcasting
scores = Q_h @ K_h.transpose(0, 1, 3, 2) / np.sqrt(d_head) # (B, H, N, N)
attn = np.exp(scores - scores.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True) # softmax
head_out = attn @ V_h # (B, H, N, d_head)
# 4. Merge heads back — transpose and reshape to (B, N, d_model)
merged = head_out.transpose(0, 2, 1, 3).reshape(B, N, d_model)
# 5. Final output projection
output = merged @ W_O
print(f"X : {X.shape}")
print(f"Q_h : {Q_h.shape} # (B, H, N, d_head)")
print(f"scores : {scores.shape} # one N×N attention matrix per head")
print(f"output : {output.shape} # back to (B, N, d_model)")

X      : (2, 5, 64)
Q_h    : (2, 8, 5, 8)  # (B, H, N, d_head)
scores : (2, 8, 5, 5)  # one N×N attention matrix per head
output : (2, 5, 64)    # back to (B, N, d_model)
The split_heads function — reshape, then transpose — is the load-bearing trick. You can't just reshape (B, N, d_model) straight into (B, H, N, d_head), because the memory layout would interleave heads across tokens in the wrong order and you'd silently compute attention on the wrong subspace slices. The correct order is reshape to (B, N, H, d_head), then transpose H with N. Flip those two and every table at the party is suddenly talking to the wrong guests.
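Here's a tiny demonstration of that failure mode, using arange values and toy shapes (not the lesson's 64/8) so you can see exactly which dimensions land where:

```python
import numpy as np

B, N, H, d_head = 1, 3, 2, 4  # toy sizes so the numbers stay readable
d_model = H * d_head
x = np.arange(B * N * d_model).reshape(B, N, d_model)

# Correct: reshape to (B, N, H, d_head), THEN swap the H and N axes.
good = x.reshape(B, N, H, d_head).transpose(0, 2, 1, 3)

# Wrong: reshape straight to (B, H, N, d_head). Same shape, scrambled contents.
bad = x.reshape(B, H, N, d_head)

print(good[0, 0])  # head 0 owns dims 0..3 of every token: rows [0..3], [8..11], [16..19]
print(bad[0, 0])   # rows [0..3], [4..7], [8..11]: token 0's dims smeared across "tokens"
assert not np.array_equal(good, bad)
```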
import torch
import torch.nn as nn
B, N, d_model, H = 2, 5, 64, 8
x = torch.randn(B, N, d_model)
# batch_first=True so shapes match everything else we've been writing.
mha = nn.MultiheadAttention(
    embed_dim=d_model,
    num_heads=H,
    batch_first=True,  # IMPORTANT — default is (N, B, d_model), which is insane
)
# Causal mask: upper-triangular of -inf means "don't look forward".
causal = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
out, attn = mha(x, x, x, attn_mask=causal, need_weights=True)
print(f"out shape : {out.shape}")
print(f"attn shape: {attn.shape} # averaged over heads by default")
print(f"causal check — row 0 attends only to col 0: {(attn[0, 0, 1:] == 0).all()}")

out shape : torch.Size([2, 5, 64])
attn shape: torch.Size([2, 5, 5])  # averaged over heads by default
causal check — row 0 attends only to col 0: tensor(True)
PyTorch's built-in compresses everything above into one call. Watch the flags though. batch_first=True is non-negotiable — the default is (N, B, d_model), which no sane code path actually wants, and every transformer student has been bitten by this at least once. The mask is additive: -inf positions become exactly zero after softmax.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """The version you'd actually write — one big QKV projection, explicit split."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.H = num_heads
        self.d_head = d_model // num_heads
        # Fused projection — one matmul is faster than three on GPU.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        # (B, N, 3*d_model) → split into 3 × (B, N, d_model)
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # (B, N, d_model) → (B, H, N, d_head)
        def split(t):
            return t.view(B, N, self.H, self.d_head).transpose(1, 2)

        q, k, v = split(q), split(k), split(v)
        # Scaled dot-product attention — F.scaled_dot_product_attention is fastest
        # in PyTorch 2.x (uses Flash-Attention kernels where available).
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0
        )  # (B, H, N, d_head)
        # Merge heads → (B, N, d_model)
        out = attn_out.transpose(1, 2).contiguous().view(B, N, self.d_model)
        return self.out_proj(self.dropout(out))
# Sanity check: does our MHA match nn.MultiheadAttention on the same weights?
torch.manual_seed(0)
mha = MultiHeadAttention(d_model=64, num_heads=8)
ref = nn.MultiheadAttention(64, 8, batch_first=True, bias=False)
# Copy weights across so the two should produce identical output.
with torch.no_grad():
    ref.in_proj_weight.copy_(mha.qkv_proj.weight)
    ref.out_proj.weight.copy_(mha.out_proj.weight)
x = torch.randn(2, 5, 64)
a = mha(x)
b, _ = ref(x, x, x, need_weights=False)
print(f"custom : {a.shape}")
print(f"reference: {b.shape}")
print(f"max abs diff (after weight sync): {(a - b).abs().max().item():.1e}")

custom   : torch.Size([2, 5, 64])
reference: torch.Size([2, 5, 64])
max abs diff (after weight sync): 3.2e-07
Q = X @ W_Q; K = X @ W_K; V = X @ W_V ←→ qkv = self.qkv_proj(x); q, k, v = qkv.chunk(3, -1) — one fused matmul replaces three; faster on GPU, same math
x.reshape(B, N, H, d_head).transpose(0, 2, 1, 3) ←→ t.view(B, N, H, d_head).transpose(1, 2) — same reshape-then-transpose dance, different API
manual softmax(QKᵀ/√d_head) @ V ←→ F.scaled_dot_product_attention(q, k, v, mask) — PyTorch 2's fused kernel; Flash-Attention when hardware allows
output = merged @ W_O ←→ self.out_proj(out) — final projection; concatenate-and-project-back in one line
d_model must be divisible by num_heads. d_model=64, H=7 gives d_head=9.14…, which is not a tensor shape. PyTorch and most frameworks throw an assertion. Pick H ∈ {1, 2, 4, 8, 16, 32, …} and move on.
(B, N, d_model) vs (B, H, N, d_head). Every shape bug in transformer code is really a confusion between these two layouts. Keep a mental note: anything before the reshape or after the merge is (B, N, d_model); everything in between is (B, H, N, d_head). Print .shape liberally while debugging.
Mask shape broadcasting. A causal mask is shape (N, N) but attention scores are (B, H, N, N). The (N, N) mask broadcasts over the leading (B, H) dimensions automatically. Beware hand-built batch masks, though: broadcasting aligns trailing axes, so a (B, N, N) mask lines its leading axis up with H, not B — it errors out when B ≠ H and silently masks the wrong positions when B == H. Stick to (N, N), or go fully explicit with (B, 1, N, N) or (B, H, N, N).
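A minimal broadcasting check, using the lesson's B=2, H=8, N=5:

```python
import torch

B, H, N = 2, 8, 5
scores = torch.randn(B, H, N, N)

# (N, N) broadcasts over the leading (B, H) axes: one causal pattern for all.
mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
attn = (scores + mask).softmax(-1)    # (B, H, N, N)
assert (attn[..., 0, 1:] == 0).all()  # row 0 sees only col 0, every batch, every head

# A (B, N, N) mask is the trap: broadcasting aligns trailing axes, so its
# leading axis tries to match H. Here B != H, so it at least fails loudly.
try:
    scores + torch.zeros(B, N, N)
except RuntimeError:
    print("shape error, as expected")
```

When B == H the same mistake broadcasts without complaint and masks per-head instead of per-batch, which is the silent version of this bug.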
nn.MultiheadAttention is NOT batch-first by default. It expects (N, B, d_model). This is a PyTorch historical artifact that has broken more transformers than any other single API choice. Always pass batch_first=True.
out_proj is not optional. Dropping W_O looks harmless — you're already back at d_model after concat — but without it the head outputs are just concatenated, never mixed. W_O is how information from table 3 ends up influencing the component table 7 produced. It's what turns eight parallel monologues into one integrated representation.
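A small experiment makes the point concrete: zero out one head's output and see which output dimensions move, before and after W_O (random weights here, purely illustrative):

```python
import torch

torch.manual_seed(0)
B, N, d_model, H = 1, 4, 64, 8
d_head = d_model // H                    # 8
head_out = torch.randn(B, H, N, d_head)  # stand-in for per-head outputs

# Concatenation alone: zeroing head 3 only touches dims 24..31 of the merge.
merged = head_out.transpose(1, 2).reshape(B, N, d_model)
zeroed = head_out.clone()
zeroed[:, 3] = 0
merged_z = zeroed.transpose(1, 2).reshape(B, N, d_model)

changed = (merged != merged_z).any(0).any(0)  # which output dims moved?
print(changed.nonzero().flatten())            # dims 24..31 and nothing else

# After W_O, head 3's information reaches essentially every dimension.
W_O = torch.randn(d_model, d_model)
changed_after = ((merged @ W_O) != (merged_z @ W_O)).any(0).any(0)
print(int(changed_after.sum()))               # almost surely all 64
```

Without W_O, each head's contribution stays locked in its own d_head-wide slice; the projection is the only place the slices mix.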
Write a MultiHeadAttention module from scratch with d_model=32, num_heads=4. Use separate W_Q, W_K, W_V, W_O linear layers (no fused QKV). Initialize all four with torch.manual_seed(42).
Then instantiate nn.MultiheadAttention(32, 4, batch_first=True, bias=False), copy your weights into its in_proj_weight (stacked Q/K/V) and out_proj.weight, and run both modules on the same random input of shape (4, 10, 32).
Assert torch.allclose(your_output, reference_output, atol=1e-5). If it fails, the bug is almost always in your reshape order or the way you stacked the weights.
Bonus: add a causal mask and re-verify. Bonus 2: swap your manual softmax for F.scaled_dot_product_attention and confirm the output is still bit-identical to nn.MultiheadAttention.
What to carry forward. Multi-head attention isn't more parameters than single-head of the same width — it's the same parameters reorganized so gradient descent can carve out specialists on separate tables. Each head projects Q/K/V into its own narrow subspace, runs attention there, and the integrator concatenates every head's output and projects it back into model space with W_O. The shape dance — (B, N, d_model) → (B, H, N, d_head) → (B, N, d_model) — is the single most load-bearing piece of bookkeeping in the entire transformer stack. Once you've written it yourself, transformer code stops being intimidating.
Next up — the Transformer Block. Multi-head attention gets information moving between positions — that's one half of a transformer layer. But a layer that only mixes information without also processing each position is half a machine. Attention tells each token who to listen to; it doesn't give it space to think about what it heard. The transformer-block lesson wraps our multi-head attention in the other half — a per-token feedforward MLP — plus the residual connections and layer normalization that keep the stack trainable at depth. Few new ideas, one very famous diagram, and a whole lot of load-bearing + x.