Batch Normalization
The CNN-era normalization trick, fully unpacked.
You just met LayerNorm, which normalizes each example in isolation. Meet its older, rowdier cousin. BatchNorm looks at a whole batch at once and asks a different question: not “how does this example's features compare to each other,” but “how does this feature compare to the same feature across everyone else in the batch?”
Picture a teacher grading on the curve. You and thirty-one classmates take an exam. For each question (call that a feature), she computes the class mean and the class spread, then rescales every student's score: subtract the class mean, divide by the class spread. A 74 on a question the class bombed is now great. A 74 on a question everyone aced is now mediocre. That is BatchNorm. The batch is the class. The feature is the exam question. The statistics come from across the class, not from within one student.
Batch Normalization landed in 2015 (Ioffe & Szegedy) and quietly rewrote deep learning. Vision networks doubled in depth the year after. Learning rates went up by an order of magnitude. The whole “go very deep” CNN era is hard to picture without it. Sequence models eventually defected to LayerNorm because BatchNorm has a flaw we'll get to — but every convnet you will ever load still has BN sprinkled through it like salt.
Mechanically it looks like its cousin: subtract mean, divide by standard deviation, restore with a learned affine. The only thing that changes is which axis the mean and variance are computed over. For BatchNorm on a dense layer with feature dim D and a batch of N examples:
For each feature j ∈ {1 … D}, over the batch:
μⱼ = (1/N) · Σᵢ xᵢⱼ # mean across the batch
σⱼ² = (1/N) · Σᵢ (xᵢⱼ − μⱼ)² # variance across the batch
x̂ᵢⱼ = (xᵢⱼ − μⱼ) / √(σⱼ² + ε) # normalize
yᵢⱼ = γⱼ · x̂ᵢⱼ + βⱼ # learned per-feature affine

Read the indices carefully, because this is where BatchNorm and LayerNorm part ways. LayerNorm fixes the example i and sweeps the mean across the feature index j. BatchNorm fixes the feature j and sweeps across the example index i. Same operation, perpendicular axis. Both output a tensor shaped like the input. Both ship with two learned parameters per feature. The axis is the whole fight.
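The perpendicular-axis claim is easy to check numerically. A minimal NumPy sketch (shapes chosen for illustration): normalize the same matrix along axis 0 for BatchNorm's view, along axis 1 for LayerNorm's view.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # N=4 examples (rows), D=3 features (columns)

# BatchNorm view: statistics per feature j, swept over examples i -> axis=0
bn = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 1e-5)

# LayerNorm view: statistics per example i, swept over features j -> axis=1
ln = (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + 1e-5)

print(bn.mean(axis=0))  # each column ~0: every feature curved across the batch
print(ln.mean(axis=1))  # each row ~0: every example curved within itself
```

Same subtraction, same division; only the axis argument changes.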
Back to the classroom. Each student is one row i. Each exam question is one column j. BatchNorm reads column-wise — grade question 1 on the class curve, grade question 2 on the class curve, and so on. LayerNorm reads row-wise — grade each student against their own average. Different teachers, different philosophies.
Now the infamous catch. What happens on exam day when only one student shows up? The curve is meaningless — a single score has no mean to subtract and no spread to divide by. BatchNorm has the same problem: at inference time you might feed it a single example, and a sample of one has no variance. So it keeps a second set of books.
During training, the mean and variance come from the current batch — live class statistics. On the side, BatchNorm maintains an exponential moving average of those per-feature stats across every batch it has seen. Think of it as the historical class curve: what has this feature looked like, on average, across the whole semester? At inference time the live batch is ignored and the historical curve takes over. Grades stay sensible even when one student walks in alone.
running_mean ← (1 − momentum) · running_mean + momentum · batch_mean
running_var ← (1 − momentum) · running_var + momentum · batch_var
# momentum is typically 0.1 (i.e. 10% new info per step)
# at eval time: use running_mean and running_var, do not update them
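A back-of-the-envelope sketch of that update rule, with a made-up constant batch mean of 5.0 so the convergence is easy to see:

```python
momentum = 0.1
running_mean = 0.0          # PyTorch initializes running_mean to zeros
for step in range(100):
    batch_mean = 5.0        # pretend every batch reports the same feature mean
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean

# After 100 steps the stale initial value has decayed by 0.9**100 ≈ 2.7e-5
print(running_mean)  # ≈ 4.99987
```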
Picture it in motion. A training run samples batches from a slowly drifting distribution — the upstream layers, themselves being tuned by gradient descent, steadily shift what each feature looks like. Plot the per-batch stats and they jitter from step to step; plot the exponential average and it tracks the drift smoothly. Flip to eval mode and the average freezes in place. That freeze is model.eval() in PyTorch.
In training I trust the class in the room. I take the mean and variance of the thirty-two exams on my desk, per question, and hand back curved scores. While I do that I'm also updating a running record of the semester. At eval time the room is empty — maybe one student walks in — so I ignore them and use the record. Shrink the class to four students and my curve gets noisy. Shrink it to one and I have nothing to grade on.
That noise problem is the whole reason BatchNorm is not a universal answer. The standard error of a sample mean scales like σ / √N. Halve the batch and your error goes up by √2. Run BatchNorm with a batch of one and the sample mean is literally the sample — zero variance, nothing to divide by except ε. Our teacher with a class of one just turns in a blank gradebook.
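You can watch the σ / √N law directly. A quick Monte-Carlo sketch (trial count is arbitrary): estimate the standard error of the batch mean at a few batch sizes and compare.

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_std_err(n, trials=20000):
    """Std of the sample mean over many simulated batches of size n."""
    batches = rng.normal(size=(trials, n))  # unit-variance feature
    return batches.mean(axis=1).std()

err32, err8, err2 = mean_std_err(32), mean_std_err(8), mean_std_err(2)
print(err32, err8, err2)   # roughly 1/√32, 1/√8, 1/√2
print(err8 / err32)        # ≈ 2: a quarter of the batch, twice the noise
```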
At N ≥ 32 the estimate is trustworthy. Below that, BatchNorm starts injecting batch-specific noise into the forward pass, and the running averages lag whatever distribution the layers above are actually producing. This is why transformer people eventually gave up on BN: per-device batch sizes of 1-4 are routine in that world (sequences are long, memory is short), and BatchNorm simply does not cope. LayerNorm, normalizing within a single example, never cared how many students were in the room.
BatchNorm1d vs BatchNorm2d vs BatchNorm3d. PyTorch ships three variants. BatchNorm1d takes feature vectors, shaped (N, D) or (N, D, L). BatchNorm2d takes images, shaped (N, C, H, W), and averages over (N, H, W) per channel: every pixel of every image in the batch counts as a sample for that channel's curve. BatchNorm3d takes volumetric tensors, shaped (N, C, D, H, W), and likewise averages over everything but the channel axis. Pick the variant that matches your tensor layout; the error message when you don't is unhelpful.
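To confirm which axes BatchNorm2d averages over, here is a small check against a hand-rolled version (at initialization γ = 1 and β = 0, so the module output should equal the bare normalization):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3, 28, 28)       # (N, C, H, W)
bn = nn.BatchNorm2d(num_features=3)  # one curve per channel
y = bn(x)                            # training mode: batch statistics

# Hand-rolled: per-channel mean/var over (N, H, W), i.e. dims 0, 2, 3
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + bn.eps)

print(torch.allclose(y, manual, atol=1e-4))  # True
```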
Don't add a bias to the Linear right before BatchNorm. BatchNorm subtracts the mean, which cancels any upstream constant. The β parameter inside BN already plays the bias role. Pass nn.Linear(.., .., bias=False) when you're feeding into a BN. Same for Conv2d. Otherwise you're training a parameter that gets zeroed on every forward pass, which is a fun thing to explain in code review.
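A quick way to convince yourself the upstream bias is dead weight: zero it out and watch the BN output not change (train mode, so batch statistics are in play):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(8, 4, bias=True)  # the bias we claim is useless here
bn = nn.BatchNorm1d(4)
x = torch.randn(32, 8)

y_with_bias = bn(lin(x))
with torch.no_grad():
    lin.bias.zero_()              # delete the upstream constant
y_without_bias = bn(lin(x))

# Mean subtraction cancels any per-feature constant, so nothing changes
print(torch.allclose(y_with_bias, y_without_bias, atol=1e-5))  # True
```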
SyncBatchNorm for multi-GPU training. A plain BN on 8 GPUs with batch 8 each computes eight separate class curves of size 8 — somehow worse than a single curve of 64. Wrap with nn.SyncBatchNorm to pool stats across devices. Not optional for anything serious on multi-GPU.
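The conversion is one call. This sketch only shows the module swap; actually pooling statistics additionally requires torch.distributed to be initialized and the model wrapped in DistributedDataParallel:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Recursively replaces every BatchNorm*d module with SyncBatchNorm
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)  # SyncBatchNorm
```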
Don't weight-decay γ, β. BN's scale and bias don't control model capacity the way conv weights do. Decaying them toward zero is mild sabotage. Put them in a separate parameter group with zero weight decay — every modern optimizer config does this, and most beginner bugs come from forgetting it.
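One common way to build those groups, shown as a sketch: route every 1-D parameter (BN's γ and β, plus any biases) into a zero-decay group. The `ndim == 1` heuristic is a convention, not a requirement:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, bias=False),
    nn.BatchNorm2d(16),
)

decay, no_decay = [], []
for name, p in model.named_parameters():
    # Conv/linear weight matrices get decayed; 1-D params (BN γ, β, biases) don't
    (no_decay if p.ndim == 1 else decay).append(p)

opt = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 1e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1, momentum=0.9,
)
print(len(decay), len(no_decay))  # 1 conv weight; BN's γ and β
```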
Fusion at inference. At eval time BN is a fixed affine transform (the stats are frozen, the γ/β are frozen). Every production inference stack folds that affine into the preceding Conv or Linear weights, deleting the BN op entirely. Your model shrinks, latency drops, and the output doesn't change. That fusion only works because of the train/eval distinction — another reason the mode switch matters.
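Here is the folding arithmetic, sketched for a Conv2d followed by BatchNorm2d (the running stats below are randomized to stand in for a trained model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, bias=False)
bn = nn.BatchNorm2d(8)
bn.running_mean.uniform_(-1, 1)   # stand-in for trained statistics
bn.running_var.uniform_(0.5, 2.0)
conv.eval(); bn.eval()            # eval mode: BN is a fixed affine

# Fold: y = γ·(conv(x) − μ)/√(σ²+ε) + β  ⇒  W' = (γ/√(σ²+ε))·W,  b' = β − μ·γ/√(σ²+ε)
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, 3, bias=True)
with torch.no_grad():
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    fused.bias.copy_(bn.bias - bn.running_mean * scale)
fused.eval()

x = torch.randn(2, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-4))  # True
```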
From scratch, then PyTorch. The NumPy version has to manage the running stats by hand, which is exactly what PyTorch hides inside nn.BatchNorm1d. Watch the bookkeeping; then watch it disappear.
import numpy as np
class BatchNorm1d:
    def __init__(self, features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(features)
        self.beta = np.zeros(features)
        self.running_mean = np.zeros(features)
        self.running_var = np.ones(features)
        self.momentum, self.eps = momentum, eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mu = x.mean(axis=0)   # per-feature batch mean
            var = x.var(axis=0)   # per-feature batch var
            # Exponential running average — running_mean drifts toward batch_mean
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mu, var = self.running_mean, self.running_var  # frozen at eval
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta
# Demo
rng = np.random.default_rng(0)
bn = BatchNorm1d(features=2)
for step in range(100):
    x = rng.normal(loc=[0.4, 0.2], scale=[1, 1], size=(32, 2))
    bn(x)                  # training mode: running stats drift toward [0.4, 0.2]
bn.training = False        # eval mode: stats frozen
test = rng.normal(loc=[0.4, 0.2], scale=[1, 1], size=(32, 2))
out = bn(test)
print("frozen running_mean ≈", np.round(bn.running_mean, 2))  # ≈ [0.4 0.2]
print("eval output mean ≈", np.round(out.mean(axis=0), 2))    # ≈ [0. 0.], normalized against the frozen stats
self.training = True   # compute from batch    ←→  model.train() — batch stats in, update running stats
self.training = False  # use running stats     ←→  model.eval() — frozen stats, no updates
manual running_mean bookkeeping                ←→  registered as a "buffer" — saved to the checkpoint, restored on load
import torch
import torch.nn as nn
# For a dense feed-forward tensor (batch, features)
bn_1d = nn.BatchNorm1d(num_features=64)
x = torch.randn(32, 64)
print(bn_1d(x).shape)
# For a conv tensor (batch, channels, height, width)
bn_2d = nn.BatchNorm2d(num_features=3) # num_features = channel count
img = torch.randn(16, 3, 28, 28)
print("conv layout:", bn_2d(img).shape)
# In a real model you'd combine them:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, bias=False)  # bias=False → BN takes over
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))

torch.Size([32, 64])
conv layout: torch.Size([16, 3, 28, 28])
Take any classifier with BatchNorm. Put it in .train() mode and feed it one example at a time. The first thing that breaks: the per-feature variance of a one-sample batch is exactly zero, so the normalized value is 0/√ε and the output collapses to β. PyTorch's BatchNorm1d won't even let you try — in train mode it raises "Expected more than 1 value per channel when training". Either way, training goes sideways immediately.
Now switch to .eval() and run the same inputs. The output is fine — BN uses the running stats, which don't care about batch size. Write a small script that asserts the two modes give different outputs at batch size 1, and explain why in one sentence.
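One possible version of that script. Note that PyTorch's BatchNorm1d is stricter than the math requires: in train mode it refuses a (1, D) batch outright rather than dividing by ε, so "different outputs" shows up as an exception versus a normal result:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
for _ in range(10):           # warm up the running stats on real batches
    bn(torch.randn(32, 4))

x = torch.randn(1, 4)         # one student walks in

bn.train()
try:
    bn(x)
except ValueError as e:       # "Expected more than 1 value per channel when training"
    print("train mode:", e)

bn.eval()
print("eval mode:", bn(x))    # fine: frozen stats, batch size irrelevant
```

The one-sentence explanation: train mode needs a batch to compute a curve, while eval mode reads the semester record, so batch size stops mattering.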
Bonus: swap the BatchNorm for LayerNorm and rerun. The batch-of-one case suddenly works in training mode too. That is not magic; it is the whole reason transformers moved to per-sample normalization.
What to carry forward. BatchNorm grades on the curve across the batch, per feature — the perpendicular axis to LayerNorm. Two modes: training reads the live class, updates the running record; eval reads the running record and freezes. Switch with model.train() / model.eval(), or expect confusion. It breaks at small batch sizes because a class of four is a lousy curve and a class of one is no curve at all. It dominated vision from 2015 to 2018, still lives inside every convnet you'll load, and lost the sequence-modeling world specifically because of that small-batch weakness.
Next up — RMS Normalization. One more normalization variant, and this one took a hard look at LayerNorm and decided it could skip a step. No mean subtraction. Just divide by root-mean-square and move on. Does it cost accuracy? In transformers, barely. Does it save compute? Yes — and that is why it is now the default inside Llama, PaLM, and most 2023+ large language models. We'll find out why dropping the mean subtract turned out to be free.
- [01] Ioffe, Szegedy · ICML 2015 · “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” — the original paper
- [02] Santurkar, Tsipras, Ilyas, Madry · NeurIPS 2018 · “How Does Batch Normalization Help Optimization?” — shows the internal-covariate-shift story is mostly wrong
- [03] Wu, He · ECCV 2018 · “Group Normalization” — the batch-size-robust alternative