Weight Initialization
Xavier, He, and the math of exploding gradients.
You built the MLP. Forward pass, backward pass, update rule — every wire, every gradient, by hand. And then, the first time you ran it on a 20-layer network, it did nothing. Loss flat as a desk. Or loss in scientific notation, climbing. Same network. Same data. Same gradient descent. The only thing you changed was the numbers you wrote into the weight tensors before step zero.
Think of every layer as having a volume knob — the scale of its initial weights. Turn every knob too low and the signal fades to silence by the time it reaches the output; activations collapse to zero and gradients follow. Turn every knob too high and the signal clips and blows the speakers; activations saturate, gradients explode. There's a sweet-spot setting, and — here's the part nobody tells you — the setting depends on which activation you picked. Bad init isn't bad math. It's the sound guy fumbling the volume knobs before the band walks on.
Two names carry this entire chapter: Xavier (Glorot & Bengio, 2010) and He (Kaiming He et al., 2015). Same idea, different math, different activation. By the end of the page you'll have turned those knobs yourself, watched a 20-layer network die in real time, and then watched the same network come alive with one line changed. The math is a paragraph. The intuition is a volume knob. The consequence, if you get it wrong, is silent death.
Here's the mechanical story. A forward pass is a stack of matrix multiplies. Each layer takes the previous activations, multiplies by W, adds a bias, and squashes through a nonlinearity. That's it. If the weights are too small on average, Wx is smaller than x, and each layer shrinks the signal a little more than the last. Stack 20 of those and the final layer is looking at numbers so close to zero the float barely represents them. When the gradient flows back, it passes through the same matrices in reverse — so it shrinks too. The chain-rule product decays exponentially with depth, and a dead gradient is a weight that never moves. Congratulations: you have a network, and it is decorative.
Too large is the symmetric disaster. Wx keeps growing, pre-activations shoot past the range where your nonlinearity does anything interesting, and derivatives saturate to zero. Different cause, same outcome — nothing learns. The question isn't should we care about init. It's what does the knob setting need to be.
The knob setting comes from a one-paragraph variance calculation, and this is genuinely the whole lesson — everything after is consequence. Suppose the inputs x and weights W are zero-mean and independent. The pre-activation z = Σⱼ Wⱼ · xⱼ is a sum of fan_in independent products. Variance adds for independent sums; variance of a product of independent zero-mean variables is the product of variances. So:
Var(z) = Σⱼ Var(Wⱼ · xⱼ)
= fan_in · Var(W) · Var(x)
If we want Var(z) = Var(x), we need Var(W) = 1 / fan_in (LeCun init).

Read the last line slowly. Each layer multiplies the variance of the signal by fan_in · Var(W). If that product is 1, the volume holds steady across depth — the knob is set right. If it's 2, variance doubles every layer; after 20 layers you're at 2²⁰ ≈ 10⁶. If it's 0.5, variance halves every layer and you're at 10⁻⁶. A tiny per-layer error, raised to the depth of the network, is a catastrophe. This is why a detail that looks cosmetic is actually load-bearing.
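The compounding is easy to check numerically. A minimal sketch in pure Python (the multiplier values are illustrative, not tied to any particular init):

```python
# The per-layer multiplier on signal variance is fan_in * Var(W).
# Anything other than 1 compounds exponentially with depth.
depth = 20
for multiplier in (0.5, 1.0, 2.0):
    var = 1.0                      # variance of the input signal
    for _ in range(depth):
        var *= multiplier
    print(f"multiplier {multiplier}: variance after {depth} layers = {var:.3g}")
```

Three knob settings, three fates: fade to silence, steady volume, blown speakers.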
That's one knob rule. The nonlinearity gets a vote in the final answer. For tanh (zero-centered, derivative ≈ 1 near zero), the post-activation has roughly the same variance as the pre-activation, so preserving Var(z) preserves Var(activations). Glorot added a symmetry argument — you also want variance to be preserved going backwards through the layer, which uses fan_out — and split the difference with a harmonic-style mean:
Var(W) = 2 / (fan_in + fan_out)
Now the ReLU story. ReLU zeros out every negative pre-activation — half the distribution gets clipped. So the post-activation carries only half the variance of the pre-activation. To hold the volume steady, you have to double the variance of W to pay back what ReLU takes. That's He init:
Var(W) = 2 / fan_in
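You can verify the factor of 2 directly. One nuance worth hedging: since relu(z) isn't zero-mean, the quantity that exactly halves is the second moment E[relu(z)²], which is what the derivation actually tracks through the layers. A quick NumPy check:

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(0, 1, size=1_000_000)   # pre-activations, Var(z) = 1
a = np.maximum(z, 0)                   # ReLU

# ReLU clips the negative half, so the second moment halves:
# E[relu(z)^2] = E[z^2] / 2 -- the factor of 2 that He init pays back.
print(f"E[z^2]       = {np.mean(z**2):.3f}")
print(f"E[relu(z)^2] = {np.mean(a**2):.3f}")
```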
Two knob rules, one per activation. Xavier tunes the knob for tanh/sigmoid; He tunes it for ReLU. One more knob worth knowing is orthogonal init, which initialises W as an orthogonal matrix and preserves variance exactly by construction; it's activation-agnostic and shines for very deep linear stacks. Three rules, all the same idea — don't let depth amplify or squash the signal. Let's actually listen to what happens when you get it wrong.
Switch the activation to ReLU. Naive init — unit variance, no fan_in scaling — either rockets past 10⁴ (knob blown) or collapses under 10⁻⁹ (fade to silence) by layer 20. Xavier is better but drifts; it was tuned for the wrong activation. He sits right on the green dashed line at Var = 1 all the way down. Now flip the toggle to tanh and watch the roles swap — Xavier lands on the line, He tends to saturate. No single init is “good” in the abstract. The knob has to match the activation.
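The same experiment fits in a few lines of NumPy. This is a sketch, not the page's widget: the `propagate` helper is hypothetical, layers are square for simplicity, and with fan_in = fan_out Xavier's 2/(fan_in + fan_out) reduces to 1/fan_in:

```python
import numpy as np

def propagate(init_std, depth=20, width=64, batch=512, seed=0):
    """Push a batch through `depth` ReLU layers; return activation variance per layer."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0, 1, size=(batch, width))
    variances = []
    for _ in range(depth):
        W = rng.normal(0, init_std(width), size=(width, width))
        x = np.maximum(x @ W.T, 0)      # linear layer + ReLU
        variances.append(x.var())
    return variances

naive  = propagate(lambda fan_in: 1.0)                  # no fan_in scaling
xavier = propagate(lambda fan_in: (1 / fan_in) ** 0.5)  # 2/(fan_in+fan_out) for square layers
he     = propagate(lambda fan_in: (2 / fan_in) ** 0.5)

print(f"variance at layer 20 -> naive: {naive[-1]:.2e}, "
      f"xavier: {xavier[-1]:.2e}, he: {he[-1]:.2e}")
```

Naive explodes, Xavier halves its way toward silence under ReLU, He holds steady.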
I decide whether your 100-layer network learns or sits there for two hours pretending to train before the NaNs show up. I am the one hyperparameter the paper's abstract never mentions. Get me right and the cleanest architecture in the world works on the first try; get me wrong and it's a hallucination machine by step one. The knob is me. I am not optional.
Same three strategies, different camera angle. Forget averages — pick a depth and look at the full histogram of activations in a batch. Naive collapses into a spike at zero or a flat explosion you can't even plot on one axis. Xavier is a narrow bump, technically alive but barely. He is a clean half-Gaussian (ReLU ate the negative half; that's the shape you want). Drag the depth k slider and watch each distribution evolve layer-by-layer.
Crank the depth to 15 with ReLU + naive init. Most of the batch has collapsed to zero — millions of dead neurons, each one a weight that will never see a gradient and therefore never train. Flip to He. The distribution at layer 15 looks the way it looked at layer 1. That's the whole point: the right knob preserves the shape of the signal no matter how deep the network.
Four recipes cover essentially every network you'll build. Click through the cards to see the formula, the activation it pairs with, and the PyTorch one-liner. Treat this as your volume-knob cheat sheet.
He init. Designed for ReLU — which kills half the activations, so we compensate by doubling the variance.
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
Bias can be zero. Don't randomize biases. They don't have a symmetry-breaking job (that's what the random weights do), and zeroing them avoids nudging your initial activations in any direction. PyTorch's nn.Linear defaults to a tiny uniform bias, which is nearly zero but not quite; both work fine in practice.
Match the init to the activation. He + tanh is suboptimal — tanh won't saturate, but you're wasting variance for no reason. Xavier + ReLU halves your activations every layer; that's the fade-to-silence failure mode. Match them, or reach for orthogonal init which is activation-agnostic.
PyTorch's default isn't exactly He. nn.Linear ships with Kaiming uniform plus a correction factor (a = √5, a historical default), which works out to a smaller std than He normal. Good enough for most ReLU nets out of the box, suboptimal for tanh/sigmoid, and always worth overriding explicitly with nn.init.xavier_normal_ when you switch activations.
LayerNorm / BatchNorm softens the init requirement. If every linear layer is followed by a normalisation that rescales activations, the norm resets the volume knob for you and sloppy init becomes survivable. That's one reason transformers (which use LayerNorm religiously) get away with simpler init than a plain deep MLP.
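A minimal NumPy sketch of the rescue effect, with a hand-rolled LayerNorm (no learnable scale/shift) and deliberately naive init:

```python
import numpy as np

rng = np.random.default_rng(0)
width, depth, batch = 64, 20, 512

def layer_norm(x, eps=1e-5):
    """Normalise each sample to zero mean, unit variance across features."""
    mu = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

x = rng.normal(0, 1, size=(batch, width))
for _ in range(depth):
    W = rng.normal(0, 1.0, size=(width, width))    # deliberately naive init
    x = np.maximum(layer_norm(x @ W.T), 0)         # norm resets the knob before ReLU

print(f"activation variance at layer {depth}: {x.var():.3f}")
```

Without the layer_norm call this stack explodes; with it, the variance sits at O(1) no matter how badly the weights are scaled.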
Three layers, each shorter than the last. Pure Python generates one random weight matrix and shows the explicit sqrt(2 / fan_in) — no abstraction, no library, just the formula you derived above turned into a list comprehension. NumPy vectorises it in a single call. PyTorch hands you He and Xavier as one-liners and picks a reasonable default if you forget.
import math, random
random.seed(0)
def he_init_matrix(fan_in, fan_out):
    """Gaussian, mean 0, std sqrt(2 / fan_in)."""
    std = math.sqrt(2 / fan_in)
    W = [[random.gauss(0, std) for _ in range(fan_in)] for _ in range(fan_out)]
    return W

def xavier_init_matrix(fan_in, fan_out):
    """Gaussian, mean 0, std sqrt(2 / (fan_in + fan_out))."""
    std = math.sqrt(2 / (fan_in + fan_out))
    W = [[random.gauss(0, std) for _ in range(fan_in)] for _ in range(fan_out)]
    return W
W = he_init_matrix(64, 32)
flat = [w for row in W for w in row]
mean_abs = sum(abs(w) for w in flat) / len(flat)
var = sum(w * w for w in flat) / len(flat)
print(f"fan_in=64 fan_out=32")
print(f"W shape: {len(W)} × {len(W[0])}")
print(f"|W| mean: {mean_abs:.3f} std: {var ** 0.5:.3f}")
print(f"expected std: sqrt(2/64) = {(2/64) ** 0.5:.3f}")

fan_in=64 fan_out=32
W shape: 32 × 64
|W| mean: 0.147 std: 0.176
expected std: sqrt(2/64) = 0.177
import numpy as np
def he_init(fan_in, fan_out, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    return rng.normal(0, np.sqrt(2 / fan_in), size=(fan_out, fan_in))

def xavier_init(fan_in, fan_out, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    return rng.normal(0, np.sqrt(2 / (fan_in + fan_out)), size=(fan_out, fan_in))
# Build a 4-layer MLP with He init
rng = np.random.default_rng(0)
sizes = [784, 256, 128, 64, 10]
Ws = [he_init(sizes[i], sizes[i + 1], rng) for i in range(len(sizes) - 1)]
for i, W in enumerate(Ws):
    print(f"layer {i}: shape {W.shape} std {W.std():.4f} (target {np.sqrt(2 / W.shape[1]):.4f})")

What changed from the pure-Python version:
- nested comprehensions of random.gauss → rng.normal(0, std, size=(fan_out, fan_in)): one call for the whole matrix
- hardcoded per-layer shapes → a list comprehension over sizes: tracks fan_in/fan_out automatically
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, sizes):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)]
        )
        self._init_weights()

    def _init_weights(self):
        for layer in self.layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        return self.layers[-1](x)
model = MLP([784, 256, 128, 64, 10])
for i, layer in enumerate(model.layers[:2]):
    fan_in = layer.weight.shape[1]
    print(f"layer{i}: std={layer.weight.std().item():.4f} (He target {(2/fan_in)**0.5:.4f})")

layer0: std=0.0503 (He target 0.0505)
layer1: std=0.1247 (He target 0.1250)
What changed from the NumPy version:
- rng.normal(0, sqrt(2/fan_in), ...) → nn.init.kaiming_normal_(w): same math, done in-place on any tensor
- np.zeros(...) for the bias → nn.init.zeros_(bias): explicit zero-init for biases
- the Xavier variant → nn.init.xavier_normal_(w): matched to tanh/sigmoid
Build a 20-layer MLP, width 64. Use naive init (W ~ 𝒩(0, 1), no fan_in scaling) and ReLU. Push a batch of 𝒩(0, 1) inputs through a forward pass and record the activation variance at every layer. It will either rocket past 10⁶ or collapse below 10⁻⁶ before it reaches the output. That's the knob set wrong.
Now swap in He. The variance should stay within one order of magnitude of 1 at every layer. Plot both curves on the same axes. This is the difference between a network that trains and a very expensive way to produce zeros.
Bonus: drop a BatchNorm after every linear layer and re-run with naive init. The batchnorm drags variance back toward 1 and rescues most of the damage — which is why practitioners can get away with sloppy init as long as a normalisation layer is never far away.
What to carry forward. Every layer is a volume knob, and the setting of that knob is the variance of its weights. Xavier (2 / (fan_in + fan_out)) tunes the knob for tanh/sigmoid. He (2 / fan_in) tunes it for ReLU. Getting the wrong one for your activation silently kills deep networks — not with an exception, with a flat loss curve and wasted GPU hours. BatchNorm and LayerNorm soften the requirement but don't erase it. The three lines you'll actually type in PyTorch are nn.init.kaiming_normal_, nn.init.xavier_normal_, and nn.init.zeros_ for biases.
End of section. You've built a neural network from scratch. Forward pass. Backward pass. Initialization. The optimizer. The loss. You've turned every knob yourself, weight by weight. Now hand it off to a framework that does all of this — initialization included — in three lines. Meet PyTorch.
- [01] Glorot & Bengio, "Understanding the difficulty of training deep feedforward neural networks", AISTATS 2010 — the Xavier init paper
- [02] He, Zhang, Ren & Sun, "Delving Deep into Rectifiers", ICCV 2015 — the He init paper
- [03] Saxe, McClelland & Ganguli, "Exact solutions to the nonlinear dynamics of learning in deep linear networks", ICLR 2014 — the orthogonal init paper