Layer Normalization

Per-sample normalization for transformers.

Medium · ~15 min read · lesson 2 of 4

You spent the last section hand-tuning weight initialization so activations wouldn't blow up or collapse as signals propagated through your multi-layer perceptron. Good init buys you a clean starting distribution. The problem: the moment gradient descent starts moving the weights, the statistics start drifting again. Layer by layer, the signal gets louder on some channels, quieter on others, and by layer forty it sounds like every instrument is fighting for the same three dB.

Think of the activations flowing through your network as tracks on a mixing console. Every input sample is its own track. A per-track mixer sits after each layer, takes one track at a time, and evens it out: it pulls the track to the same reference loudness (mean zero), tames its dynamic range (unit variance), and hands it back. That mixer is LayerNorm.

LayerNorm, or a close variant of it, sits inside essentially every transformer block. GPT uses it. BERT uses it. Llama uses a close cousin (RMSNorm, two lessons out) that strips one of the knobs off. Before we get to attention, you need to understand what LayerNorm does, which axis it averages over, and why it dethroned its older sibling BatchNorm in sequence models.

The whole operation is three lines of arithmetic, then a fourth line that puts the color back. For a single track — a vector of features x ∈ ℝᴰ belonging to one example:

LayerNorm — the full operation
μ    =   (1/D) · Σⱼ xⱼ                             # mean over features
σ²   =   (1/D) · Σⱼ (xⱼ − μ)²                       # variance over features

x̂ⱼ  =   (xⱼ − μ) / √(σ² + ε)                      # normalize to mean 0, var 1
yⱼ   =   γⱼ · x̂ⱼ  +  βⱼ                            # learned per-feature scale + shift

Read it top to bottom and the mixer does exactly what the metaphor promised. Lines one and two compute this track's mean and variance (its loudness and its spread) using only its own features. Line three pulls the loudness to zero and divides by the standard deviation, so the spread becomes one. Nothing in any of that glances at another track in the batch. Your batch size could be ten thousand, or one; the numbers would be identical.
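To make the three lines of arithmetic concrete, here is a small NumPy sketch that runs them step by step on one hypothetical four-feature track, x = [2, 4, 6, 8]. The ε = 1e-5 is an assumed value (it is a common default, but frameworks let you choose it).

```python
import numpy as np

x = np.array([2.0, 4.0, 6.0, 8.0])     # one hypothetical track, D = 4 features
eps = 1e-5                             # assumed epsilon

mu = x.mean()                          # line 1: (2 + 4 + 6 + 8) / 4 = 5.0
var = ((x - mu) ** 2).mean()           # line 2: (9 + 1 + 1 + 9) / 4 = 5.0
x_hat = (x - mu) / np.sqrt(var + eps)  # line 3: (x - 5) / sqrt(5.00001)

print(mu, var)                         # 5.0 5.0
print(np.round(x_hat, 3))              # [-1.342 -0.447  0.447  1.342]
```

Note that no other example appears anywhere in the computation; the track is normalized purely against itself.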

Line four is where the engineer gets to EQ. γ (gain) and β (bias) are two learned parameters per feature that let the network scale and shift each normalized channel back up. The normalization itself is parameter-free — the per-feature EQ is the only part training touches.
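Here is a small sketch of line four on its own, assuming NumPy and the same hypothetical track as before. It uses γ = 1 and β = 0 first, which is how PyTorch's nn.LayerNorm initializes its affine parameters, so at initialization line four changes nothing. It then turns the knobs: a uniform gain and bias move every feature together, and a zero gain in one slot mutes that feature. The helper names normalize and affine are made up for illustration.

```python
import numpy as np

def normalize(x, eps=1e-5):
    """Lines 1-3: parameter-free normalization over the last axis."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def affine(x_hat, gamma, beta):
    """Line 4: per-feature gain (gamma) and bias (beta), broadcast over the last axis."""
    return gamma * x_hat + beta

x_hat = normalize(np.array([2.0, 4.0, 6.0, 8.0]))
D = x_hat.shape[-1]

# Initialization-style knobs: gamma = 1, beta = 0 leaves x_hat untouched.
y_init = affine(x_hat, np.ones(D), np.zeros(D))

# Uniform EQ: every feature gets gain 2 and bias 0.5.
y_eq = affine(x_hat, np.full(D, 2.0), np.full(D, 0.5))

# Per-feature EQ: a zero gain silences features 2 and 4 entirely.
y_mute = affine(x_hat, np.array([1.0, 0.0, 1.0, 0.0]), np.zeros(D))

print(np.allclose(y_init, x_hat))                    # True
print(round(y_eq.mean(), 3), round(y_eq.std(), 3))   # 0.5 2.0
```

Because γ and β are per feature and shared across all examples, they set each channel's output scale and offset; they do not restore any particular example's original loudness.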

The figure below starts with a batch of six tracks, each recorded at a different loudness and dynamic range, which is exactly the mess that drifts through a deep network. With LayerNorm applied, each row snaps to mean 0, std 1, one row at a time.

[Figure: layer normalization, each example normalized independently. A batch of shape (6, 8) normalized along axis=-1; rows are examples, columns are features.]

The operation applied to row 2: μ = (1/D) · Σⱼ xⱼ = −0.543, σ² = (1/D) · Σⱼ (xⱼ − μ)² = 1.303 (so the row's std is √1.303 ≈ 1.141), and x̂ⱼ = (xⱼ − μ) / √(σ² + ε).

The row's mean is subtracted and the row is rescaled by its std. After LayerNorm, every example has mean ≈ 0 and std ≈ 1, independently of the batch. A batch-of-1 inference call and a batch-of-1000 training call get identical normalization.
LayerNorm (personified)
I work one sample at a time. I do not care how many examples you gave me, or what other examples in the batch look like. I take your feature vector, compute its mean and variance across the feature axis, subtract and divide, scale and shift with my learned γ and β, and hand you a feature vector whose statistics are stable layer after layer. A batch of one works. A batch of ten thousand works. I do not discriminate.
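The batch-independence claim is easy to check. The sketch below is a NumPy approximation of the figure, not its data: it assumes a random (6, 8) batch whose rows get hypothetical offsets and scales (standing in for different loudness and dynamic range), normalizes the whole batch, then normalizes row 2 alone as a batch of one and compares.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each row (last axis) to mean 0, variance 1; no gamma/beta."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(42)
offsets = rng.uniform(-3, 3, size=(6, 1))     # hypothetical per-row "loudness"
scales = rng.uniform(0.2, 4.0, size=(6, 1))   # hypothetical per-row "dynamic range"
batch = offsets + scales * rng.normal(size=(6, 8))

out = layer_norm(batch)
print("row means before:", np.round(batch.mean(axis=-1), 2))
print("row stds before: ", np.round(batch.std(axis=-1), 2))
print("row stds after:  ", np.round(out.std(axis=-1), 2))

# Batch of one: row 2 on its own, shape (1, 8).
alone = layer_norm(batch[2:3])
print("row 2 alone matches row 2 in the batch:", np.allclose(alone[0], out[2]))   # True
```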

“Which axis is being averaged” is the single thing most people get wrong about LayerNorm, BatchNorm, and their cousins. Keep the mixer picture in mind: LayerNorm runs down one track at a time, averaging across that track's own features. BatchNorm does the opposite — it picks one channel and averages across every track in the studio. Completely different sets of numbers.

[Figure: which axis does each normalization average over? A tensor of shape (B=4, S=3, F=6), drawn as one grid per batch element with the averaged cells highlighted. For nn.LayerNorm(features): average across features (axis=-1) for each (batch, seq) position, independently per example.]

normalization   axes averaged         stats per        needs batch?
LayerNorm       features              (b, s)           no, per sample
BatchNorm       batch + seq           feature          yes; separate train/eval modes
RMSNorm         features (no mean)    (b, s)           no
GroupNorm       features in a group   (b, s, group)    no

In the figure, the highlighted cells are the ones pooled into a single mean and variance. LayerNorm's pool is always one example's feature vector at one position. BatchNorm's pool is every example in the batch (plus the sequence dimension, for 3D tensors) for a single feature. Same-feature-across-examples versus same-example-across-features: the whole argument comes down to that one line.
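The same comparison in NumPy, as a sketch on a random tensor with the figure's (B=4, S=3, F=6) shape. It only computes training-mode batch statistics for BatchNorm (no running averages) and leaves GroupNorm out. The shapes of the statistics show the pools, and making one example louder shows which statistics depend on the rest of the batch.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 6))                        # (B=4, S=3, F=6)

ln_mean = x.mean(axis=-1, keepdims=True)              # LayerNorm pool: 6 values per (b, s)
bn_mean = x.mean(axis=(0, 1), keepdims=True)          # BatchNorm pool: 4 * 3 = 12 values per feature
rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True))  # RMSNorm: LayerNorm's pool, no mean subtraction
print(ln_mean.shape, bn_mean.shape, rms.shape)        # (4, 3, 1) (1, 1, 6) (4, 3, 1)

# Make example 0 much louder and see whose statistics move.
x_loud = x.copy()
x_loud[0] += 10.0
ln_same = np.allclose(x_loud[1:].mean(axis=-1), x[1:].mean(axis=-1))
bn_same = np.allclose(x_loud.mean(axis=(0, 1)), x.mean(axis=(0, 1)))
print(ln_same, bn_same)                               # True False
```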

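There are two places in a transformer block where the mixer can sit, and the choice is not cosmetic. The original 2017 paper used post-norm: normalize after the residual add. Most modern implementations (GPT-2 onwards) use pre-norm: normalize the sublayer's input and leave the residual path untouched. Because the skip connection then stays a clean identity, gradients reach early layers without passing through a LayerNorm at every block, so pre-norm trains more stably (post-norm typically needs a careful learning-rate warmup) and is the default in most transformer code you'll read today.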

post-norm vs pre-norm: same pieces, different placement
# original transformer (Vaswani et al., 2017): post-norm
x = LayerNorm(x + Sublayer(x))          # normalize after the residual add

# modern implementation: pre-norm
x = x + Sublayer(LayerNorm(x))          # normalize the sublayer input, then add the residual
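Here is a sketch of both placements as PyTorch modules. The class names PostNormBlock and PreNormBlock are hypothetical, a small MLP stands in for attention to keep it short, and dropout sits where the standard block puts it: on the sublayer output, just before the residual add. The printout shows only the structural difference (post-norm re-normalizes the residual stream at every block, pre-norm lets the stream keep its own scale); it does not demonstrate training stability by itself.

```python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """2017-style ordering: add the residual, then normalize the sum."""
    def __init__(self, d_model, sublayer, p_drop=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x):
        return self.norm(x + self.drop(self.sublayer(x)))

class PreNormBlock(nn.Module):
    """Modern ordering: normalize the sublayer's input; the residual path stays untouched."""
    def __init__(self, d_model, sublayer, p_drop=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x):
        # Dropout hits the sublayer output just before it rejoins the residual stream.
        return x + self.drop(self.sublayer(self.norm(x)))

def make_mlp(d_model):
    """Hypothetical stand-in sublayer (a small MLP instead of attention)."""
    return nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

def per_position_std(t):
    """Biased std over the feature axis for every (batch, seq) position."""
    return (t - t.mean(dim=-1, keepdim=True)).pow(2).mean(dim=-1).sqrt()

torch.manual_seed(0)
d_model = 16
x = 3 * torch.randn(2, 5, d_model) + 1                   # (batch, seq, features), deliberately off-scale

post = PostNormBlock(d_model, make_mlp(d_model)).eval()  # eval() disables dropout
pre = PreNormBlock(d_model, make_mlp(d_model)).eval()

with torch.no_grad():
    y_post, y_pre = post(x), pre(x)

print(f"post-norm output std per position ~ {per_position_std(y_post).mean():.2f}")   # ~1.00
print(f"pre-norm  output std per position ~ {per_position_std(y_pre).mean():.2f}")
```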
Gotchas

nn.LayerNorm(normalized_shape) takes the shape to normalize over, not the batch size. For a tensor of shape (batch, seq, features), you want nn.LayerNorm(features). It normalizes over the last len(normalized_shape) dimensions, and those trailing dimensions of the input must match normalized_shape exactly.
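A short PyTorch check of that gotcha, with illustrative shapes. The (12, 768) variant is included only to show what a multi-dimensional normalized_shape does: it pools a whole sequence per example, which is usually not what you want in a transformer, especially when sequence lengths vary.

```python
import torch
import torch.nn as nn

x = torch.randn(32, 12, 768)             # (batch, seq, features)

per_token = nn.LayerNorm(768)            # last dim only: one mean/var per (b, s)
print(per_token.weight.shape)            # torch.Size([768])

per_sequence = nn.LayerNorm((12, 768))   # last TWO dims: one mean/var per b
print(per_sequence.weight.shape)         # torch.Size([12, 768])

# The mistake from the gotcha: passing the batch size. The trailing dim is 768, not 32.
try:
    nn.LayerNorm(32)(x)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", str(err).splitlines()[0])
```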

Don't insert dropout between a linear layer, its LayerNorm, and its activation unless you mean to: it changes the variance LayerNorm sees and can destabilize training. The standard pre-norm transformer block is LN → attention → residual add → LN → MLP → residual add, with dropout applied to each sublayer's output just before it is added back to the residual stream, not between the layers inside it.

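In mixed-precision training, LayerNorm is computed in higher precision than the rest of the model. In fp16 (largest finite value 65,504), the running sums behind the mean and variance can overflow to infinity, and very small variances lose precision, so the reciprocal square root can blow up. PyTorch's autocast runs layer_norm in fp32, and its fused kernels accumulate the statistics in fp32 before casting the result back down; most mainstream frameworks do something similar. Don't try to “optimize” by forcing LayerNorm into fp16: it's one of the few ops where precision really matters.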
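A toy NumPy model of why this matters, not a description of any real GPU kernel. It assumes a hypothetical 768-feature fp16 activation whose values (250, 300, 350, repeated) are chosen so the sums exceed fp16's range. The helper layer_norm_acc is made up for illustration: it does every sum in a chosen accumulator dtype, then casts the result back to the input dtype, mimicking "compute in fp32, cast down".

```python
import numpy as np

def layer_norm_acc(x, acc_dtype, eps=1e-5):
    """Hypothetical helper: LayerNorm over the last axis with every sum done in
    acc_dtype, then cast back to x.dtype."""
    xa = x.astype(acc_dtype)
    d = acc_dtype(x.shape[-1])
    mean = np.sum(xa, axis=-1, keepdims=True, dtype=acc_dtype) / d
    dev = xa - mean
    var = np.sum(dev * dev, axis=-1, keepdims=True, dtype=acc_dtype) / d
    x_hat = dev / np.sqrt(var + acc_dtype(eps))
    return x_hat.astype(x.dtype), var

print(float(np.finfo(np.float16).max))   # 65504.0, the largest finite fp16 value

# Hypothetical fp16 activation, D = 768: the sum of the values is 230,400.
x = np.array([250.0, 300.0, 350.0] * 256, dtype=np.float16)

with np.errstate(over="ignore", invalid="ignore"):
    y16, var16 = layer_norm_acc(x, np.float16)   # sums overflow fp16
y32, var32 = layer_norm_acc(x, np.float32)       # same math, sums in fp32, output cast to fp16

print(var16.item())                                      # inf
print(round(var32.item(), 2))                            # 1666.67
print(np.isfinite(y16).all(), np.isfinite(y32).all())    # False True
```

The inputs and outputs are fp16 in both runs; only the accumulator differs, and that alone decides whether the row comes back as numbers or as NaNs.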

From scratch in four lines of NumPy — the mixer, with a soldering iron. PyTorch ships it as a one-liner with the EQ knobs already wired to autograd. Same operation either way; only one runs on a GPU.

layer 1 — pure numpy · layer_norm_numpy.py
python
import numpy as np

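def layer_norm(x, eps=1e-5, gamma=None, beta=None):
    """x: (..., D). Normalize over the last axis; gamma, beta: optional (D,) arrays."""
    mean = x.mean(axis=-1, keepdims=True)      # (..., 1): keepdims so it broadcasts
    var = x.var(axis=-1, keepdims=True)        # biased variance (ddof=0), i.e. divide by D
    x_hat = (x - mean) / np.sqrt(var + eps)    # eps guards against division by zero
    y = x_hat
    if gamma is not None:
        y = y * gamma                          # learned per-feature gain
    if beta is not None:
        y = y + beta                           # learned per-feature bias
    return y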

rng = np.random.default_rng(0)
x = rng.normal(loc=[1, -1, 2], scale=[1, 1, 1], size=(8, 3)).T  # rows have different means
print("before: mean per row =", np.round(x.mean(-1), 2), " std per row =", np.round(x.std(-1), 2))
y = layer_norm(x)
print(" after: mean per row =", np.round(y.mean(-1), 2), " std per row =", np.round(y.std(-1), 2))
stdout
before: mean per row = [ 1.38 -0.97  2.41]  std per row = [0.87 0.95 1.01]
 after: mean per row = [-0.00 -0.00  0.00]  std per row = [1.00 1.00 1.00]
the four lines of math ↔ the four lines of code
μ = (1/D) · Σⱼ xⱼ  ↔  x.mean(axis=-1, keepdims=True)  (keepdims so broadcasting works)
σ² = (1/D) · Σⱼ (xⱼ − μ)²  ↔  x.var(axis=-1, keepdims=True)  (NumPy's default ddof=0 divides by D, matching the formula)
x̂ = (x − μ) / √(σ² + ε)  ↔  (x - mean) / np.sqrt(var + eps)  (the ε avoids division by zero)
y = γ · x̂ + β  ↔  x_hat * gamma + beta  (learned affine restoration)

layer 2 — pytorch · layer_norm_pytorch.py
python
import torch
import torch.nn as nn

# Shape conventions for a transformer tensor:
# batch × sequence × features.   LayerNorm is applied per (batch, seq).
x = torch.randn(32, 12, 768)

norm = nn.LayerNorm(normalized_shape=768)   # γ, β have shape (768,)
y = norm(x)

print(y.shape)
print(f"mean per (b, s) = ~{y.mean(dim=-1).abs().mean():.2f},  "
      f"std per (b, s) = ~{y.std(dim=-1).mean():.2f}")
stdout
torch.Size([32, 12, 768])
mean per (b, s) = ~0.00,  std per (b, s) = ~1.00
numpy → pytorch
hand-rolled 4 lines  ↔  nn.LayerNorm(features)  (fused kernel, mixed-precision aware, gradient-ready)
manual gamma, beta  ↔  learnable via autograd  (they register as nn.Parameter)
runs in whatever dtype you pass (NumPy defaults to float64)  ↔  statistics accumulated in fp32 even for fp16/bfloat16 inputs  (PyTorch handles the precision hack for you)
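To see that nn.LayerNorm really is the same four lines, here is a sketch that checks it against a hand-rolled version on random data (the shapes are arbitrary). It also confirms that γ and β are registered nn.Parameter objects initialized to ones and zeros, and that the variance is the biased one (divide by D), matching the formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 10, 64)
norm = nn.LayerNorm(64)                   # eps defaults to 1e-5

# gamma (weight) and beta (bias) are learnable parameters, initialized to 1 and 0.
print(type(norm.weight).__name__, norm.weight.requires_grad)                                 # Parameter True
print(torch.equal(norm.weight, torch.ones(64)), torch.equal(norm.bias, torch.zeros(64)))   # True True

# The same four lines by hand, with the biased variance (divide by D).
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + norm.eps) * norm.weight + norm.bias

print(torch.allclose(norm(x), manual, atol=1e-5))   # True
```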

Compare training with and without LayerNorm

Take the MLP from the previous section. Build two versions: one plain, one with nn.LayerNorm after every linear layer (before the activation). Train both on MNIST for 5 epochs. Plot validation accuracy vs epoch for both.
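A minimal sketch of the three configurations, assuming a 784 → 256 → 256 → 10 MLP with ReLU (the previous section's MLP may differ). The helper make_mlp is hypothetical, the norm goes only after hidden layers (not on the output logits), and MNIST loading and the training loop are omitted. The last few lines preview the batch-size issue: in training mode, BatchNorm1d refuses a batch of one, while LayerNorm does not care.

```python
import torch
import torch.nn as nn

def make_mlp(norm=None, sizes=(784, 256, 256, 10)):
    """Hypothetical MNIST MLP. norm is None, "layer", or "batch"; when set, a norm layer
    goes after each hidden Linear and before its ReLU (the output logits are left alone)."""
    layers = []
    n_linear = len(sizes) - 1
    for i, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if i < n_linear - 1:                  # hidden layer
            if norm == "layer":
                layers.append(nn.LayerNorm(d_out))
            elif norm == "batch":
                layers.append(nn.BatchNorm1d(d_out))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

torch.manual_seed(0)
models = {name: make_mlp(name) for name in (None, "layer", "batch")}   # all in training mode

x = torch.randn(64, 784)                      # stand-in for a flattened MNIST batch
for name, model in models.items():
    print(name, tuple(model(x).shape))        # (64, 10) for each

# A batch of one, still in training mode.
one = torch.randn(1, 784)
print(tuple(models["layer"](one).shape))      # (1, 10)
try:
    models["batch"](one)
    bn_failed = False
except ValueError as err:                     # BatchNorm has no batch to average over
    bn_failed = True
    print("BatchNorm1d:", err)
```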

Bonus: swap in nn.BatchNorm1d instead of nn.LayerNorm for a third configuration. At standard MLP sizes all three should converge; you will likely see BatchNorm converge fastest at medium batch sizes and LayerNorm hold up best as the batch size changes. The real differences show up in transformers, which is the whole point of the upcoming attention section.

What to carry forward. LayerNorm is the per-track mixer. It standardizes each example's feature vector to mean 0, std 1, independently of the batch, then lets the network EQ the result with a learned γ and β. Two parameters per feature; three lines of arithmetic plus a learnable rescale. Because it never glances across tracks, batch size and sequence length don't enter the picture, which is precisely why the transformers you've heard of reach for it (or for its stripped-down cousin, RMSNorm).

Next up — Batch Normalization. LayerNorm's older cousin from 2015, which works the opposite way: one feature at a time, averaged across the batch. It dominated CNN training for years and invented a whole new category of “why did the loss spike on batch-size-two” debugging stories. When is averaging across the batch the better call, and when does it fall apart? That's the next lesson.

References