RMS Normalization
LayerNorm without mean subtraction — why it works.
LayerNorm does two things in one breath. It subtracts the mean of the feature vector (the center step), then divides by the standard deviation (the scale step). Two moves, one op, and every transformer from 2017 onwards took it as gospel.
Then someone stress-tested the first half. Turns out the mean subtraction is the part that nobody bothered to ablate. Keep the scale, drop the center, retrain — and the model is fine. Not equivalent-on-paper fine, but within-noise-of-your-seed fine, which is the only kind of fine that actually matters when you're spending eight-figure sums on training runs.
That's RMSNorm, from a 2019 paper by Zhang and Sennrich. It's LayerNorm minus a step. The step we found we didn't need. It ships in Llama, PaLM, Gemma, Mistral — basically every major large-language model published after 2022. The rest of this lesson is why that step was safe to skip, what the math looks like after you skip it, and why "a step we didn't need" translates into real wall-clock savings at scale.
```
# LayerNorm (what we already know)
μ    = (1/D) · Σⱼ xⱼ                  # ← mean subtracted
x̂ⱼ   = (xⱼ − μ) / √(σ² + ε)          # ← std divided
yⱼ   = γⱼ · x̂ⱼ + βⱼ                  # ← γ scale + β shift

# RMSNorm
rms  = √( (1/D) · Σⱼ xⱼ² + ε )        # ← just RMS (no mean)
yⱼ   = γⱼ · (xⱼ / rms)                # ← γ scale, no β
```
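One consequence worth seeing concretely: when the input happens to be zero-mean, σ equals the RMS, so LayerNorm (with β = 0 and γ = 1) and RMSNorm produce identical outputs. A quick NumPy check of the two formulas:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    var = ((x - mu) ** 2).mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)   # γ = 1, β = 0 for clarity

def rms_norm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

x = np.array([1.0, -2.0, 3.0, -2.0])       # zero-mean on purpose
print(np.allclose(layer_norm(x), rms_norm(x)))   # True — σ == RMS when μ == 0
```

The two ops only diverge when the input carries a nonzero mean, which is exactly the property the next section pokes at.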
Two lines vanish. The mean μ never gets computed, and the learnable shift β goes with it — there's nothing to shift around once you're no longer centering. What survives is the divide-by-RMS and the learnable scale γ. That's the entire op.
What about the property you lose? Add a constant offset to the input and watch what happens. LayerNorm's output barely flinches — the mean subtraction cancels any bulk shift in the input before it hits the scale step. RMSNorm's output moves with the shift, because there's nothing left to cancel it. On paper that sounds like a real property loss. In practice the next Linear layer has a bias, and it learns to eat whatever offset the upstream tensor arrives with. The invariance wasn't doing much; the downstream weights were doing the work anyway — γ absorbs most of it, and empirically RMSNorm trains just as well while skipping the mean computation entirely.
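The shift (non-)invariance is easy to verify directly — a minimal sketch with hand-rolled versions of both norms, no learnable parameters:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def rms_norm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
shift = 3.0

# LayerNorm cancels a constant offset before scaling...
print(np.allclose(layer_norm(x), layer_norm(x + shift)))   # True
# ...RMSNorm does not: the offset changes both the RMS and the direction.
print(np.allclose(rms_norm(x), rms_norm(x + shift)))       # False
```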
RMSNorm skips the mean. If your feature vector is [1, 2, 3, 4], LayerNorm would subtract 2.5 from each element before rescaling; RMSNorm just rescales. The mean subtraction wasn't earning its keep — the network compensates through other parameters — and skipping it saves one full read-and-subtract pass over the activation tensor. At D = 4096, L = 32 layers, and batches of millions of tokens, those saved passes turn into real wall-clock time.

Let's quantify what the skipped step actually buys you. Norm layers aren't compute-bound — they're memory-bound. The clock time you pay is mostly the time spent shuffling activations in and out of HBM, not the arithmetic inside. Drop one pass over the tensor and you drop roughly that fraction of the runtime.
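Before the accounting, here's the [1, 2, 3, 4] example worked in full (ε omitted for clarity):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

# LayerNorm: center on the mean, then scale by the std
mu = x.mean()                      # 2.5
ln = (x - mu) / x.std()            # ≈ [-1.342, -0.447, 0.447, 1.342]

# RMSNorm: just scale by the root mean square
rms = np.sqrt((x ** 2).mean())     # √7.5 ≈ 2.739
rn = x / rms                       # ≈ [0.365, 0.730, 1.095, 1.461]

print(np.round(ln, 3), np.round(rn, 3))
```

Same input, one fewer pass: RMSNorm never touches `mu` at all.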
| sub-op | LayerNorm | RMSNorm |
|---|---|---|
| compute mean | ✓ | — |
| subtract mean | ✓ | — |
| compute mean(x²) | via var | ✓ |
| compute var | ✓ | — (uses mean(x²)) |
| rsqrt + divide | ✓ | ✓ |
| γ · scale | ✓ | ✓ |
| + β shift | ✓ | — (no β) |
In a rough benchmark, LayerNorm over an (8, 2048, 4096) tensor takes about 200 µs; RMSNorm comes in around 170 µs. A 15% per-norm saving, summed over every transformer block, becomes a noticeable speedup at scale.

Crank the dimensions up to pre-training scale (B = 8, S = 2048, D = 4096 is a realistic batch). The op-count delta is around 20% — which tracks with the wall-clock speedups teams report in practice. Multiply that by 32 transformer blocks, each with two norm layers, and the mean-subtract step you deleted just saved billions of operations per forward pass. Per forward pass. You do a lot of forward passes.
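That scale claim is easy to sanity-check. The count below covers only the skipped per-element subtract — the mean reduction and the β add would pile more on top — so it's a lower bound on LayerNorm's extra work:

```python
# Back-of-envelope count of the subtracts RMSNorm never performs,
# at the batch shape quoted above.
B, S, D = 8, 2048, 4096
layers, norms_per_layer = 32, 2

elems_per_norm = B * S * D                           # 67,108,864 activations
total = elems_per_norm * layers * norms_per_layer    # one skipped subtract each
print(f"{total:,} fewer subtracts per forward pass") # 4,294,967,296
```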
When should you pick which? Rough guide:
- New transformer architecture: use RMSNorm. It's the modern default and you'll match the codebase conventions of every recent paper.
- Reproducing an older paper: use whatever it used. GPT-2/3 and BERT used LayerNorm. Llama, PaLM, Gemma, Mistral use RMSNorm.
- Non-transformer networks: stick with LayerNorm or BatchNorm. RMSNorm hasn't been extensively validated outside the transformer setting and the savings matter less without hundreds of norm layers stacked back-to-back.
- Super-long context: the absolute savings grow with sequence length — every extra token is another row through every norm layer. For million-token contexts, the skipped pass is one of the many small wins you stack to keep the run feasible.
No β in RMSNorm. If you're porting a LayerNorm layer to RMSNorm, don't forget the learnable bias is gone. The next Linear or attention projection usually has a bias of its own, so it absorbs whatever shift β would have learned. If you port it naively and leave a dangling β parameter around, it will sit there unused and you'll wonder why your param count is off.
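A porting sketch (layer sizes here are illustrative, and the module is the usual hand-rolled RMSNorm form): γ carries over, β has nowhere to go.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Minimal hand-rolled RMSNorm: a γ parameter only, no β.
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d))

    def forward(self, x):
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms * self.gamma).to(x.dtype)

D = 768
ln = nn.LayerNorm(D)        # has .weight (γ) and .bias (β)
rms = RMSNorm(D)            # γ only

# Port γ; β has no destination and is intentionally dropped.
with torch.no_grad():
    rms.gamma.copy_(ln.weight)

print(sum(p.numel() for p in ln.parameters()))   # 1536  (γ + β)
print(sum(p.numel() for p in rms.parameters()))  # 768   (γ only)
```

The halved parameter count per norm layer is exactly the "missing" β — if your count is off by D per norm after a port, look for a dangling bias.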
Pre-norm placement still matters. RMSNorm is a drop-in for LayerNorm, but you still want pre-norm ordering — normalize before the sublayer, then add the residual. Every modern transformer codebase does it this way. Post-norm will train; it will also fight you the whole way.
Precision matters for the rsqrt. 1 / √(mean(x²) + ε) must be computed in fp32 even in fp16 or bfloat16 models — squaring activations can overflow or underflow in half precision, and the resulting infs or NaNs show up mid-training, followed by a long debugging afternoon. PyTorch's built-in nn.RMSNorm (2.4+) handles this correctly; hand-rolled versions often do not.
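To make the failure mode concrete — in fp16, squaring a moderately large activation already overflows to inf, which poisons the mean and the rsqrt. The same reduction in fp32 is fine:

```python
import torch

x = torch.full((4,), 300.0, dtype=torch.float16)

# fp16 max is 65504, so 300² = 90000 overflows:
print(torch.isinf(x ** 2).all())      # True

# Cast to fp32 before squaring and the statistic is exact:
print((x.float() ** 2).mean())        # tensor(90000.)
```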
nn.RMSNorm needs PyTorch 2.4+. For older versions, write it by hand in five lines. The tradeoff is no CUDA kernel fusion — the built-in is meaningfully faster on modern hardware.
Now the code. Two layers — NumPy and PyTorch. The op is simple enough that the NumPy version reads like the formula, which is exactly the point.
```python
import numpy as np

def rms_norm(x, gamma=None, eps=1e-6):
    """RMSNorm over the last dim."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    x_hat = x / rms
    if gamma is not None:
        x_hat = x_hat * gamma
    return x_hat

rng = np.random.default_rng(0)
x = rng.normal(loc=[0.3, -0.2], scale=[1, 1], size=(8, 2)).T
print("input rms =", np.round(np.sqrt(np.mean(x ** 2, axis=-1)), 4))
y = rms_norm(x)
print("output rms per row ≈", round(np.sqrt(np.mean(y ** 2, axis=-1)).mean(), 4), "for all")
```

```
input rms = [1.1456 0.8732]
output rms per row ≈ 1.0000 for all
```
| LayerNorm | RMSNorm | what changed |
|---|---|---|
| `x.mean(-1, keepdims=True)` | *(removed)* | no mean subtraction |
| `x - mean` | *(removed)* | no centering |
| `x.var(-1)` | `np.mean(x ** 2, -1)` | mean(x²) instead of variance around μ |
| `(x - mean) / sqrt(var + eps)` | `x / sqrt(mean(x**2) + eps)` | RMS divide |
| `gamma * x_hat + beta` | `gamma * x_hat` | no β parameter |
Five lines removed, one line changed. That's the whole simplification. The PyTorch version is the same thing with rsqrt for the divide (faster and more numerically friendly) and an explicit fp32 cast for mixed-precision sanity.
```python
import torch
import torch.nn as nn

# PyTorch 2.4+ ships nn.RMSNorm natively. For older versions:
class RMSNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))

    def forward(self, x):
        # Cast to float for numerical stability, cast back after.
        dtype = x.dtype
        x = x.float()
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * rms * self.gamma).to(dtype)

norm = RMSNorm(features=768)
x = torch.randn(8, 12, 768)
y = norm(x)
print(y.shape)
print(f"mean output rms ≈ {y.pow(2).mean(-1).sqrt().mean().item():.3f}")
```

```
torch.Size([8, 12, 768])
mean output rms ≈ 1.000
```
| before | after | note |
|---|---|---|
| `x = x + Attn(LayerNorm(x))` | `x = x + Attn(RMSNorm(x))` | drop-in replacement — same position, same residual |
| `x = x + MLP(LayerNorm(x))` | `x = x + MLP(RMSNorm(x))` | second norm point, also pre-norm |
Pick any small transformer you have lying around, or spin up a two-block one from nn.TransformerEncoderLayer. Replace every nn.LayerNorm with the RMSNorm module above. Train both on a small language-modeling objective (WikiText-2 is fine) for 500 steps. Plot training loss for each.
Expected observations: (1) the loss curves are basically identical, (2) RMSNorm is ~10% faster end-to-end on a GPU, and (3) the difference in final validation perplexity is smaller than your run-to-run seed variance. If that last one surprises you, welcome to the empirical side of deep learning.
Bonus: benchmark just the norm layer in isolation with torch.cuda.synchronize() and time.perf_counter(). You'll see the 15–25% speedup on the op itself, which is where the end-to-end gains come from.
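A sketch of that isolated benchmark, assuming nothing beyond stock PyTorch. It falls back to CPU (with a smaller tensor) when no GPU is present; absolute numbers will differ there, but the harness is the same:

```python
import time
import torch
import torch.nn as nn

@torch.no_grad()
def bench(layer, x, iters=100):
    """Average seconds per call, synchronizing around the timed region on GPU."""
    for _ in range(10):          # warmup
        layer(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        layer(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
# Pre-training-scale shape on GPU; something small enough to be quick on CPU.
B, S, D = (8, 2048, 4096) if device == "cuda" else (8, 128, 512)
x = torch.randn(B, S, D, device=device)

ln = nn.LayerNorm(D).to(device)
print(f"LayerNorm: {bench(ln, x) * 1e6:.1f} µs/call")

# nn.RMSNorm needs PyTorch 2.4+; swap in the hand-rolled module on older versions.
if hasattr(nn, "RMSNorm"):
    rms = nn.RMSNorm(D).to(device)
    print(f"RMSNorm:   {bench(rms, x) * 1e6:.1f} µs/call")
```

Synchronizing before and after the timed loop matters: CUDA launches are asynchronous, so without it you'd be timing kernel *launches*, not kernel *execution*.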
What to carry forward. RMSNorm is LayerNorm with one step removed — the mean subtraction — and the learnable bias that goes with it. Trains as well as LayerNorm on transformer workloads, runs meaningfully faster because it skips one pass over the activation tensor. Modern LLMs default to it. Old codebases still use LayerNorm. The two are drop-in replacements in either direction; the quality gap is below noise. This is the quiet architectural win — the kind of optimization you only notice because Llama uses it.
End of section. You now have the parts. Tensors, autograd, modules, a normalization layer that ships in every frontier model on the open web. The next section wires them into the four-line ritual that makes a model actually learn — forward, loss, backward, step. Every training run on every GPU in every data center you've heard of is that loop, with a lot of bookkeeping bolted on. We'll build it next.
- [01] Biao Zhang, Rico Sennrich. "Root Mean Square Layer Normalization." NeurIPS 2019 — the original RMSNorm paper.
- [02] Touvron et al. "LLaMA: Open and Efficient Foundation Language Models." Meta, 2023 — first major open LLM to standardize on RMSNorm.
- [03] PyTorch core team. `nn.RMSNorm` documentation, pytorch.org — available from 2.4.