Load Balancing Loss
Preventing expert collapse — keep every expert busy.
Picture your shiny new MoE layer on day one. Eight experts, a top-2 router, clean forward pass, coffee in hand. You kick off training. One step. Ten. A thousand. Then you peek at which expert the router is actually sending tokens to — and the dashboard is a horror show. Expert #1 is seeing 80% of the tokens. Two neighbors are splitting the scraps. The other five are sitting on the bench, randomly initialized, never trained, effectively dead weight. You built a mixture of experts and ended up with one overworked star, two part-timers, and five benchwarmers playing solitaire.
This is the single most famous failure mode of sparse MoE, and it's the subject of this lesson. We'll see how the collapse happens, why it's a natural equilibrium and not a bug, and the auxiliary loss term — the load balancing loss — that keeps the router honest. Small term, tiny coefficient, huge difference. Think of it as the HR manager who walks the floor and fines the router every time the workload gets lopsided. By the end you'll understand why every production MoE stack (Switch, GShard, Mixtral) pays the same ~0.01 HR tax on every gradient step.
Watch the "without aux loss" trace. At step 0 tokens are evenly distributed — the router is random, so it routes randomly. Within a couple hundred steps one expert starts pulling ahead. Within a thousand it's a runaway. The curve never recovers. That's rich-get-richer dynamics: the expert that saw slightly more tokens early gets slightly better at its job, the router's quality signal routes slightly more tokens its way, it gets better still, and so on. The feedback loop is exponential. The star gets overworked, the bench stays idle, and the router's favorite compounds.
The cruel part: the router is doing nothing wrong. Gradient descent is doing exactly what it's supposed to — picking the expert that produces the lowest task loss for each token. The trouble is that the loss depends on how well-trained the expert is, and how well-trained the expert is depends on how often the router picks it. So the system collapses into the first attractor it finds. If your init is anything other than perfectly uniform — and it never is — one expert wins the popularity contest and the other seven are benched forever.
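The feedback loop is easy to reproduce outside a real model. Here is a toy sketch of the dynamics (not from any paper: the batch size, the 0.001 learning rate, and the "quality improves linearly with tokens seen" rule are all illustrative assumptions) showing a near-uniform router collapsing onto one expert:

```python
import numpy as np

def softmax(q):
    z = np.exp(q - q.max())               # subtract max for numerical stability
    return z / z.sum()

rng = np.random.default_rng(0)
N = 8
quality = rng.normal(0.0, 0.01, N)        # near-uniform init, tiny random edge

for step in range(1000):
    probs = softmax(quality)              # router prefers better-trained experts
    tokens = rng.multinomial(128, probs)  # this batch's hard assignments
    quality += 0.001 * tokens             # every token an expert sees trains it a bit

print(np.round(softmax(quality), 3))      # one expert ends up with nearly all the mass
```

The loop has no bug anywhere: each step is locally greedy and reasonable, and the collapse falls out of the coupling between routing share and expert quality.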
I'm the HR manager bolted onto your task loss. I don't care about accuracy. I care that every expert on the payroll is pulling their weight — no favorites, no idle bench, no one drowning. Push me too hard and I flatten your router into noise. Push me too soft and your router picks a favorite, the workload goes lopsided, and the whole point of sparsity evaporates. Tune me with care.
The fix is an auxiliary loss — a small extra term added to the task loss whose entire job is to fine the router whenever the workload goes lopsided. Switch Transformer (Fedus et al., 2021) wrote down the canonical version, and every modern MoE has inherited it with minor tweaks. It is, in one line of math, the HR policy:
L_aux = N · Σ_{i=1}^{N} f_i · P_i
where for each expert i of N total experts:
f_i = fraction of tokens dispatched to expert i (hard, discrete)
P_i = average router probability for expert i (soft, differentiable)
and the total loss is:
L_total = L_task + α · L_aux    (α = 0.01 in Switch)

Two pieces, and the cleverness is in the pairing. f_i measures the actual routing decision — of this batch's tokens, what fraction got shoved onto expert i? This is a hard count: it uses the argmax (or top-k) of the router, so it's not differentiable. You can't backprop through f_i. P_i, on the other hand, is the average of the raw router probabilities coming out of the softmax — the pre-argmax soft score. That is differentiable. It's what the router would have sent if routing were soft.
Multiplying them is the whole trick. The gradient flows through P_i, and f_i rides along as a constant weight that tells the router where the lopsidedness is. If expert 3 is the overworked star right now, f_3 is huge; the gradient on P_3 gets a big coefficient, and the router learns to turn down the dial on expert 3. If expert 5 is the idle bench, f_5 is tiny and P_5 barely moves — but the balance force still quietly favors spreading work its way. Equilibrium lies at f_i = P_i = 1/N for every expert. That is the HR manager's definition of a balanced staff roster.
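To see the pairing mechanically, here is a small PyTorch sketch (illustrative, separate from the lesson's later implementation): f is computed under no_grad so autograd treats it as a constant weight, and the gradient that reaches the logits through P pushes the overloaded expert's probability down.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 4
# A collapsed router (illustrative): a +2 bias makes expert 0 win most argmaxes.
logits = (torch.randn(16, N) + torch.tensor([2.0, 0.0, 0.0, 0.0])).requires_grad_()

probs = F.softmax(logits, dim=-1)
with torch.no_grad():                      # f_i is a constant as far as autograd sees
    idx = probs.argmax(dim=-1)
    f = torch.bincount(idx, minlength=N).float() / idx.numel()
P = probs.mean(dim=0)                      # soft average — gradient flows here
L_aux = N * (f * P).sum()
L_aux.backward()

grad_per_expert = logits.grad.sum(dim=0)
# The overloaded expert's logits get a positive gradient, so a descent step
# turns its probability down; the idle experts get the opposite push.
print(f, grad_per_expert)
```

Note that only P carries gradient; if you accidentally let gradient flow through the hard counts (or detach P instead), the loss stops steering the router at all.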
Walk through the widget column by column. You have a batch of tokens, a router that produces a softmax over experts, and a top-1 assignment. Column f_i counts assignments and divides by batch size. Column P_i averages raw probabilities down the batch. The product column is what gets summed and scaled by N. Drag the skew slider: as one expert's share climbs, f_i and P_i both climb for that expert, the workload tips lopsided, and the loss explodes. Drag it back to uniform and the loss bottoms out at 1.0 — the minimum for a perfectly balanced router, the HR manager's dream shift.
Constraint: Σ f_i = 1, Σ P_i = 1, f_i, P_i ≥ 0
Claim: Σ f_i · P_i ≥ 1/N with equality iff f_i = P_i = 1/N.
Intuition: the sum is a dot product ⟨f, P⟩, and f is the argmax histogram of the same probabilities that P averages, so the two vectors track each other. For vectors tied together like that, the dot product bottoms out when both are uniform, and then N · 1/N = 1. (For two independent simplex vectors the bound would fail; f and P being produced by the same router is what makes it hold.)

I'm the hard cap. The aux loss is the HR manager with a clipboard; I'm the bouncer at the door. If expert 3 has already accepted c tokens this batch, the (c+1)-th one gets turned away — routed to its second choice or skipped entirely. No overflow, no OOM, no single worker drowning while the bench idles. Pick me too tight and most tokens get skipped, so the layer barely computes anything. Pick me too loose and I do nothing, and you're back to praying the aux loss holds.

Expert capacity is the hard, non-differentiable complement to the soft aux loss. You compute it as capacity = (tokens_per_batch / N) · capacity_factor, where capacity_factor is usually between 1.0 (strict) and 1.25 (lenient). A token that picks an expert already at capacity gets either (a) routed to its next-choice expert, or (b) dropped — passed through the residual stream unchanged, no expert computation at all. This is why well-trained MoEs tolerate some capacity pressure: the residual path carries enough signal that occasional skips don't destroy the forward pass.
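A minimal sketch of that capacity rule, assuming top-1 routing and first-come-first-served queueing within the batch (the function name and the arrival-order tie-break are my assumptions; real implementations vary in how they pick which tokens to drop):

```python
import torch

def dispatch_with_capacity(expert_idx, num_experts, capacity_factor=1.25):
    """Return a keep-mask: True where a token fits under its expert's capacity.
    Tokens past capacity are 'dropped' — they ride the residual stream."""
    tokens = expert_idx.numel()
    capacity = int(tokens / num_experts * capacity_factor)
    one_hot = torch.nn.functional.one_hot(expert_idx, num_experts)
    position = one_hot.cumsum(dim=0) * one_hot   # each token's rank in its expert's queue
    keep = position.sum(dim=-1) <= capacity
    return keep, capacity

idx = torch.tensor([0, 0, 0, 0, 0, 1, 2, 3])     # expert 0 heavily overloaded
keep, cap = dispatch_with_capacity(idx, num_experts=4, capacity_factor=1.0)
print(cap, keep.tolist())  # capacity 2: expert 0 keeps its first 2 tokens, drops 3
```

With capacity_factor=1.0 the cap is exactly the uniform share (8 tokens / 4 experts = 2), so three of expert 0's five tokens overflow while the other experts stay under their caps.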
Three phases, same formula. First the pure-NumPy walk, term by term, so the math is legible. Then a PyTorch version using scatter_add — the vectorised trick every real MoE uses on GPU. Then the HF Transformers one-liner you'll actually call in production.
import numpy as np
# Pretend we just ran a router on a batch of 16 tokens over N=4 experts.
N = 4
# Router probabilities BEFORE argmax — shape (batch, N). These are soft; gradient flows here.
router_probs = np.array([
[0.7, 0.2, 0.05, 0.05], # token 0 strongly prefers expert 0
[0.8, 0.1, 0.05, 0.05], # token 1 prefers expert 0
[0.6, 0.3, 0.05, 0.05],
[0.75, 0.15, 0.05, 0.05],
# ...collapsed router — expert 0 is winning everything
] + [[0.7, 0.15, 0.1, 0.05]] * 12)
# Hard assignment — this is f_i (non-differentiable).
assignments = np.argmax(router_probs, axis=1) # shape (batch,)
f = np.bincount(assignments, minlength=N) / len(assignments) # fraction per expert
# Soft average — this is P_i (differentiable).
P = router_probs.mean(axis=0) # shape (N,)
# The aux loss.
L_aux = N * np.sum(f * P)
print("f =", np.round(f, 4))
print("P =", np.round(P, 4))
print("f · P =", np.round(f * P, 6))
print(f"L_aux = {N} * sum(f · P) = {L_aux:.3f}")

f = [1. 0. 0. 0.]
P = [0.7031 0.1594 0.0875 0.05]
f · P = [0.703125 0. 0. 0.]
L_aux = 4 * sum(f · P) = 2.812
Now the same thing on GPU. The idiom is scatter_add: given a tensor of expert IDs per token, bump a counter at each expert index in parallel. You see this pattern in every serious MoE implementation because a Python loop over tokens would be catastrophic at scale.
import torch
import torch.nn.functional as F
def load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
"""
router_logits: (batch, num_experts) — raw scores from the gate
returns: scalar aux loss (no coefficient applied)
"""
# Soft probabilities — gradient flows through these.
probs = F.softmax(router_logits, dim=-1) # (B, N)
# Hard top-1 assignment per token.
expert_idx = probs.argmax(dim=-1) # (B,) long
# f_i: fraction of tokens routed to each expert. scatter_add builds the histogram.
ones = torch.ones_like(expert_idx, dtype=probs.dtype)
tokens_per_expert = torch.zeros(num_experts, device=probs.device, dtype=probs.dtype)
tokens_per_expert.scatter_add_(0, expert_idx, ones) # bump per-expert counter
f = tokens_per_expert / expert_idx.numel() # (N,)
# P_i: mean router probability per expert.
P = probs.mean(dim=0) # (N,)
# Aux loss.
return num_experts * (f * P).sum()
torch.manual_seed(0)
logits = torch.randn(64, 8) # 64 tokens, 8 experts
print(f"L_aux = {load_balancing_loss(logits, num_experts=8).item():.3f}")

L_aux = 1.073 (close to ideal 1.0 because init is roughly uniform)
np.bincount(assignments, minlength=N) ←→ torch.zeros(N).scatter_add_(0, idx, ones) — same histogram, built on-device with no host sync
router_probs.mean(axis=0) ←→ probs.mean(dim=0) — identical; named dim instead of axis
L_aux = N * (f * P).sum() ←→ num_experts * (f * P).sum() — scalar tensor; gradient flows through P automatically
In real code you don't write any of this. transformers ships it as part of its Mixtral / Switch implementations. You pass output_router_logits=True on the forward pass and the aux loss is added to your outputs, ready to scale and combine with the task loss.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
import torch
name = "mistralai/Mixtral-8x7B-v0.1" # 8 experts, top-2 routing
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
ids = tok("Sparse activation is the whole point.", return_tensors="pt").input_ids
out = model(ids, labels=ids, output_router_logits=True)
# The Mixtral model exposes all-layer router logits; the helper walks them.
aux = load_balancing_loss_func(
out.router_logits,
num_experts=8,
top_k=2,
attention_mask=None,
)
total = out.loss + 0.01 * aux # 0.01 is the Switch-style coefficient
print(f"task loss: {out.loss.item():.3f}")
print(f"aux loss: {aux.item():.3f}")
print(f"combined: {total.item():.4f} (task + 0.01 * aux)")

task loss: 2.413
aux loss: 1.082
combined: 2.4238 (task + 0.01 * aux)
load_balancing_loss(logits, N) ←→ load_balancing_loss_func(router_logits, N, top_k, mask) — same algorithm, generalised over layers, top-k, and padding
loss = task_loss + 0.01 * aux ←→ out.loss + 0.01 * aux — you add the coefficient yourself; HF returns them separately
you compute f and P by hand ←→ output_router_logits=True — opt-in flag; adds one more tensor per layer to the forward pass
Aux loss coefficient too high: bump α from 0.01 to 0.1 and the HR manager turns into a micromanager. The balance penalty drowns out the task signal, the router gets fined so aggressively it gives up and routes uniformly to everything, and every expert sees every kind of token. No favorites, sure — but also no specialisation. You've flattened the staff into eight interchangeable generalists, which is just a slow dense layer wearing a MoE costume. The lopsidedness was the symptom; killing it with fire kills the whole point of sparsity.
Capacity factor too tight: set capacity_factor = 1.0 on a noisy router and ~20% of your tokens will overflow and get skipped every batch. Those tokens contribute nothing to the forward pass. Your loss curve flatlines early and you wonder why. Bump it to 1.25.
Not logging per-expert utilisation: the aux loss value alone doesn't tell you if you've actually converged to balance. You need to plot f_i over training — is every expert getting hits, or is the aux loss low only because P_i is flat while f_i is still concentrated on one favorite and the rest still idle? Log both.
Forgetting the N scale factor: the leading N in L_aux is what makes the loss scale-invariant across expert counts. Drop it and your 64-expert model has a much smaller aux loss than your 8-expert model at the same degree of imbalance, and your coefficient sweep stops generalising.
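A quick sanity check of that scale invariance, comparing perfectly balanced routers at two expert counts (aux_loss here is a throwaway helper for the check, not library code):

```python
import numpy as np

def aux_loss(f, P, scale_by_N=True):
    """Switch-style balance loss; scale_by_N toggles the leading N factor."""
    N = len(f)
    return (N if scale_by_N else 1) * float(np.sum(f * P))

for N in (8, 64):
    uniform = np.full(N, 1.0 / N)         # perfectly balanced: f_i = P_i = 1/N
    print(N, aux_loss(uniform, uniform), aux_loss(uniform, uniform, scale_by_N=False))
# prints: 8 1.0 0.125
#         64 1.0 0.015625
```

With the leading N, a balanced router scores exactly 1.0 at any expert count; drop it and the 64-expert loss is 8× smaller than the 8-expert one at the same (zero) imbalance, so a single α can't serve both.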
Build a 4-expert MoE in PyTorch: one linear router into a softmax, four tiny expert MLPs (just nn.Linear(d, d) each), a top-1 dispatch, and a synthetic task — predict a random label from a 32-dim input, 10k steps, batch 128.
Run A (no aux loss): train to convergence. Every few hundred steps, log the histogram of expert assignments and compute the entropy of the empirical distribution: H = -Σ f_i log f_i. Uniform would give log(4) ≈ 1.386. You'll watch it drop toward zero as one expert takes over and the other three end up on the idle bench.
Run B (with aux loss): add the LB loss term with coefficient 0.01. Same task, same seed. The entropy should stay near log(4) for the whole run — the HR manager is doing its job and nobody is getting overworked.
Bonus: sweep α ∈ {0, 0.001, 0.01, 0.1, 1.0} and plot final task accuracy against final routing entropy. You'll see the characteristic U — extremes both hurt (no fine = collapse, huge fine = forced uniformity), the sweet spot is narrow.
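For logging the entropy in Runs A and B, a small helper along these lines will do (routing_entropy is a made-up name; it just implements H = -Σ f_i log f_i over the assignment histogram):

```python
import numpy as np

def routing_entropy(assignments, num_experts):
    """Entropy (nats) of the empirical expert-assignment distribution."""
    f = np.bincount(assignments, minlength=num_experts) / len(assignments)
    nz = f[f > 0]                          # convention: 0 * log(0) = 0
    return float(max(0.0, -(nz * np.log(nz)).sum()))

print(routing_entropy(np.array([0, 1, 2, 3]), 4))  # uniform → log(4) ≈ 1.386
print(routing_entropy(np.array([0, 0, 0, 0]), 4))  # fully collapsed → 0.0
```

Call it every few hundred steps on the batch's argmax assignments and plot the curve; collapse shows up as a steady slide from log(N) toward zero.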
What to carry forward. A vanilla token-choice MoE collapses because the training dynamics favor whichever expert got a head start — one overworked star emerges, the rest of the bench stays idle, and the router never looks back. The load balancing auxiliary loss — N · Σ f_i · P_i — is the small, cheap, differentiable HR regulator that prevents that collapse by pairing a hard assignment fraction with a soft router probability, funneling gradient toward uniform routing without killing the router's ability to specialise. Pair it with a hard expert capacity limit for the extreme cases. Every serious MoE stack pays some version of this balance tax.
Next up — Expert Parallelism. So far we've been computing everything on one device, and everyone on the payroll has a desk in the same office. In reality, if you have 64 experts at 7B parameters each, they don't fit on a single GPU — they don't even fit on a single node. The next lesson is about splitting experts across devices: all-to-all communication, the cost of token shuffling between nodes, why MoE throughput is bounded by interconnect bandwidth, and the tricks (ZeRO, tensor parallelism interleaving) that let Mixtral actually run in the wild.
- [01] Fedus, Zoph, Shazeer — "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity" · JMLR 2022 (originally arXiv 2021)
- [02] Zhou et al. — "Mixture-of-Experts with Expert Choice Routing" · NeurIPS 2022
- [03] Lepikhin et al. — "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding" · ICLR 2021 (originally arXiv 2020)
- [04] Jiang et al. (Mistral AI) — "Mixtral of Experts" · 2024