Top-k Routing

The gating network that picks which experts see each token.

Hard
~15 min read
· lesson 2 of 4

Picture a room with N specialist experts inside and a single door. Every token in your batch is a guest approaching that door one at a time. Who gets in? If you throw the door wide open and let every guest reach every expert, you have N experts doing work per token — a dense model wearing a sparse costume, and you just paid the compute bill for nothing. If you slam the door shut, nobody runs and your model outputs zeros. Neither extreme is what MoE promised you.

You need a bouncer. The bouncer stands at the door with a clipboard, reads each token's ID, scores every expert on the guest list against that ID, and waves exactly k names through. Everyone else stays outside for this token. Next token walks up, gets a fresh scan, different k names get waved through. The room is never crowded and never empty — it's always running at exactly k experts per guest, which is the one setting where sparse MoE actually pays off.

That bouncer is the router: a tiny linear layer that scores every expert for every token, plus a top-k selection step that keeps only the best few names on the list. It's three lines of code and it's the reason sparse models exist. This lesson covers what the bouncer actually is, why we pick k experts instead of 1 or N, the noisy top-k trick that keeps the whole thing from collapsing in training, and how to write it in pure Python, NumPy, and PyTorch.

Router (personified)
I am a Linear(d_model, n_experts). That is all. A softmax on top of me turns my logits into gate probabilities, and top-k picks the winners I wave through the door. I'm a few hundred thousand parameters guarding billions. Treat me carefully.

Start with one token, vector x ∈ ℝ^d — the guest's ID the bouncer reads. The router is a single matrix W_r ∈ ℝ^(d × N). Multiply to get a score per expert, softmax to turn raw scores into a probability distribution over the N names on the list:

router — logits and gate probabilities
logits   =   x · W_r               ∈ ℝᴺ

g(x)ᵢ    =   exp(logitsᵢ)
             ────────────────────    softmax over experts
             Σⱼ exp(logitsⱼ)

Now g(x) is a probability vector over experts — it sums to 1 and every entry is positive. In a dense MoE, the output is just Σᵢ g(x)ᵢ · Eᵢ(x): run every expert, weight each output by its gate probability, sum. That's the “open door” version — correct, differentiable, and financially ruinous. If you have 128 experts you just did 128× the work of a dense model with one FFN.
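The open-door version is worth seeing in code once, if only to appreciate what top-k deletes. A minimal sketch with toy single-matrix experts — all the names and sizes here are invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_experts = 16, 4

x   = rng.standard_normal(d)                  # one token
W_r = rng.standard_normal((d, n_experts))     # router weights
Es  = rng.standard_normal((n_experts, d, d))  # toy experts: one matrix each

logits = x @ W_r                              # one score per expert
g = np.exp(logits - logits.max())
g /= g.sum()                                  # gate probabilities, sum to 1

# dense (soft) combine: EVERY expert runs, weighted by its gate
y = sum(g[i] * (Es[i] @ x) for i in range(n_experts))
print(y.shape)                                # (16,) — correct, but N× the work
```

Correct and fully differentiable, but the loop runs all n_experts matmuls for one token — exactly the compute bill the door is supposed to prevent.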

Sparse MoE makes a different bet: most of g(x) is approximately zero anyway. Only the top few entries carry real weight. Keep those, drop the rest — the bouncer only waves through the names that actually matter. That's top-k.

top-k selection and renormalization
TopK(g, k) = { indices of the k largest entries of g }

g̃(x)ᵢ  =  {  g(x)ᵢ / Σⱼ∈TopK g(x)ⱼ     if i ∈ TopK
          {  0                           otherwise

y      =   Σᵢ∈TopK  g̃(x)ᵢ · Eᵢ(x)

Three things just happened at the door. The bouncer picked the k best-scoring names from the list. Their gates got renormalized so they still sum to 1 (otherwise the total gate mass drops below 1 and every residual connection starts to drift). And only those k experts actually ran — not N. The FLOPs per token are now k · C_expert instead of N · C_expert, independent of how many experts you have. Adding the 128th expert costs parameters but not compute. That's the whole game.
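The three code layers later in this lesson stop at the router; the equations above carry through to the combine. A toy end-to-end sketch for a single token (toy matrix experts, names invented here), which also counts how many experts actually execute:

```python
import numpy as np

def sparse_moe_token(x, W_r, expert_mats, k):
    """Route one token to its top-k toy experts; return output + #experts run."""
    logits = x @ W_r
    g = np.exp(logits - logits.max())
    g /= g.sum()                              # full softmax gate
    top = np.argsort(g)[-k:]                  # TopK(g, k)
    g_top = g[top] / g[top].sum()             # renormalize — sums to 1 again
    y = sum(w * (expert_mats[i] @ x) for w, i in zip(g_top, top))
    return y, len(top)

rng = np.random.default_rng(0)
d, n_experts = 16, 8
x    = rng.standard_normal(d)
W_r  = 0.1 * rng.standard_normal((d, n_experts))
mats = [rng.standard_normal((d, d)) for _ in range(n_experts)]

y, n_run = sparse_moe_token(x, W_r, mats, k=2)
print(n_run)                                  # 2 — independent of n_experts
```

Only k expert matmuls ran; add a hundred more entries to `mats` and that number stays 2.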

Below: one token flowing through a router with 8 experts — the raw logits, the softmax, the top-k cut, and the renormalization. The chosen names on the list keep their gate mass; the rest fall dark at the door.

router → softmax → top-k → renormalize
one token, four stages, N = 8, k = 2

        1 · logit z_i    2 · softmax p_i    3 · top-2 mask    4 · mixture weight
E0      1.40             0.136              ·                 ·
E1      1.60             0.166              0.166             0.525
E2      1.10             0.101              ·                 ·
E3      1.30             0.123              ·                 ·
E4      1.20             0.111              ·                 ·
E5      1.00             0.091              ·                 ·
E6      1.50             0.150              0.150             0.475
E7      1.30             0.123              ·                 ·

stage 1: soft, broad affinity — almost any expert works
stage 2: p_i = exp(z_i) / Σ exp(z_j) — all positive, sums to 1
stage 3: keep top-2, rest → 0 — dropped experts contribute no compute and no weight
stage 4: renormalized weights for the experts that will actually run

active: E1, E6   ·   Σ weights = 1.000
Router (personified)
I don't get gradient through who I pick at the door. I get gradient through how much I weight them once they're inside. Turns out that's plenty — if my top name keeps giving good outputs, the softmax will keep voting for it, and its share of the gate will keep growing.

Worth being explicit about the trade-off. Soft routing is the open-door policy: every name on the list gets the token, weighted by g(x). Fully differentiable, slightly more accurate, and it nukes the entire reason to use MoE — you're doing N× the FLOPs of a single FFN. Hard routing (top-k with k < N) is the bouncer: only the chosen k experts run. The sparsity is the payoff; the non-differentiable selection is the cost. Every production MoE pays that cost gladly.

flops per token: dense vs sparse MoE
Dense FFN              =   C_expert                    (one FFN)

Soft MoE (k = N)       =   N · C_expert                 (every expert runs)

Sparse MoE (top-k)     =   k · C_expert  +  d · N        (k experts + router)
                              ▲                ▲
                              dominant         rounding-error cost

Typical:   d = 4096,  N = 64,  k = 2
           router  ≈  4096 · 64       ≈  0.26M flops
           experts ≈  2 · 50M         ≈  100M flops
           → k dominates; router is free

So the slider that matters is k: it scales FLOPs linearly while quality gains plateau fast.

k sweep — quality saturates, FLOPs keep climbing
N = 8 experts · k = 2 is the sweet spot

FLOPs per token grow linearly in k; quality rises sharply from k=1 to k=2, then flattens — roughly +14.8% quality for +100% FLOPs going k=1 → 2, diminishing returns past that. Routing entropy H = −Σ p log₂ p widens as k adds slots. At the k = 2 sweet spot: 2.0× FLOPs, quality 0.868 (relative), H ≈ 1.06 bits.

Sweep k from 1 to N. The FLOPs line is a ruler — perfectly linear in k. The quality line is not: it rises sharply from k=1 to k=2, creeps up a bit to k=4, and flattens out. Back at the door, this is how k changes the bouncer's tradeoff — and the published choices follow the curve exactly.

  • k=1 — Switch Transformer. The simplest, fastest, and — thanks to Fedus et al. 2021 — the one that showed you can scale MoE to trillions of parameters without the training instability everyone feared. A hard choice: the bouncer reads the ID, picks one name, done. Load balancing is doing a lot of heavy lifting to keep it stable.
  • k=2 — Mixtral, GLaM, ST-MoE. Twice the FLOPs of k=1, a meaningful quality bump, and enough redundancy that a single bad pick at the door doesn't tank a token's forward pass. Top expert plus a backup — a tie-break built in. This is the current default.
  • k=4+ — diminishing returns, free lunch over. You're buying noise. The room gets crowded, the top two names already carry the vast majority of the probability mass on most tokens, and adding a third or fourth barely changes the weighted sum. The sparsity discount shrinks while the quality line has already flattened.
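One way to see the diminishing returns directly: on peaked gate distributions, the top two entries already carry most of the probability mass, so a third expert barely moves the weighted sum. A toy illustration — the logit scale here is an assumption for the demo, not a measured router statistic:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = 2.0 * rng.standard_normal((1000, 8))     # 1000 tokens, 8 experts, peaked-ish
g = np.exp(logits - logits.max(-1, keepdims=True))
g /= g.sum(-1, keepdims=True)
g_desc = np.sort(g, axis=-1)[:, ::-1]             # gates, descending per token

for k in (1, 2, 4, 8):
    mass = g_desc[:, :k].sum(-1).mean()           # mean mass captured by top-k
    print(f"k={k}: mean top-k mass = {mass:.2f}")
```

The jump from k=1 to k=2 captures most of the remaining mass; k=4 and beyond mop up scraps — the shape of the quality curve above.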
Top-k (personified)
I'm the capacity planner at the door. Pick me too small and one lucky expert eats the whole batch. Pick me too big and the room is crowded and you've paid for a dense model in experts' clothing. Two. The answer is usually two.

There's also a train / inference split worth knowing about. Some recipes use dense (soft) routing during training so every expert gets gradient every step, and switch to sparse top-k only at inference for speed. Others use top-k throughout and rely on noisy top-k plus a load balancing loss to keep experts healthy. Reinforce-style gradients — treating expert selection as a policy and getting gradient through the discrete choice via score-function estimators — have also been explored but aren't the standard today. Straight noisy top-k with a balance loss is the current best practice.

Three implementations. The pure-Python version shows the skeleton — what top-k means when you spell out the bouncer line by line. NumPy introduces argpartition, the best-of-both-worlds trick (partial sort for free). PyTorch hands you torch.topk and scatter and you use them like any other op.

layer 1 — pure python · top_k_scratch.py
python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]                   # subtract max for stability
    z = sum(exps)
    return [e / z for e in exps]

def top_k_gate(logits, k):
    gates = softmax(logits)                                # probability per expert

    # argsort descending, take first k indices
    ranked = sorted(range(len(gates)), key=lambda i: gates[i], reverse=True)
    top_idx = ranked[:k]

    # renormalize: gates / sum(top-k gates)
    top_mass = sum(gates[i] for i in top_idx)
    top_gates = [gates[i] / top_mass for i in top_idx]

    return top_idx, top_gates

logits = [0.8, -0.2, 1.5, 0.3, -1.1, 2.1, 0.0, 0.9]
idx, gates = top_k_gate(logits, k=2)
print("top-2 idxs  =", idx)
print("top-2 gates =", [round(g, 2) for g in gates])
stdout
top-2 idxs  = [5, 2]
top-2 gates = [0.65, 0.35]   # renormalized, sums to 1.0

Vectorise. In NumPy we process an entire batch at once — the bouncer checks every guest's ID in parallel — and we use argpartition to avoid paying for a full sort just to find the top-k. Partitioning is O(N) instead of O(N log N).

layer 2 — numpy · top_k_numpy.py
python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)                # log-sum-exp trick
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def top_k_router(x, W_r, k):
    # x: (batch, d_model)        token embeddings
    # W_r: (d_model, n_experts)  router weights
    logits = x @ W_r                                       # (batch, n_experts)
    gates  = softmax(logits, axis=-1)                      # probabilities

    # argpartition on -gates: the k largest gates land in the first k slots, unsorted. O(N).
    top_idx = np.argpartition(-gates, k, axis=-1)[:, :k]   # (batch, k)

    # gather the top-k gate values
    top_gates = np.take_along_axis(gates, top_idx, axis=-1)

    # renormalize across the k chosen experts
    top_gates = top_gates / top_gates.sum(axis=-1, keepdims=True)

    return top_idx, top_gates

np.random.seed(0)
x   = np.random.randn(4, 16)                               # batch of 4 tokens
W_r = np.random.randn(16, 8) * 0.1                         # 8 experts
idx, gates = top_k_router(x, W_r, k=2)
print("chosen experts\n", idx)
print("normalized gates\n", np.round(gates, 3))
pure python → numpy

sorted(range(n), key=...)[:k]   ←→   np.argpartition(-gates, k)[:, :k]
    partial sort — O(N) instead of O(N log N)

loop over tokens, build gates list   ←→   x @ W_r, one matmul for the whole batch
    routing is one matrix multiply

sum(gates[i] for i in idx)   ←→   take_along_axis(gates, idx, axis=-1).sum(-1, keepdims=True)
    batched gather + sum

PyTorch. torch.topk does both the selection and the gather in one call and returns the values sorted. scatter (or a simpler masked softmax) pushes the renormalized gates back into a full-size vector if you need it.

layer 3 — pytorch · top_k_router.py
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRouter(nn.Module):
    def __init__(self, d_model, n_experts, k, noise_std=1.0):
        super().__init__()
        self.w_r     = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False)    # for noisy top-k
        self.k       = k
        self.noise_std = noise_std

    def forward(self, x, train=True):
        # x: (batch, d_model)
        logits = self.w_r(x)                                        # (B, N)

        if train and self.noise_std > 0:
            noise = torch.randn_like(logits) * F.softplus(self.w_noise(x))
            logits = logits + self.noise_std * noise                 # noisy top-k

        # torch.topk: values and indices of the k largest, in one call
        top_logits, top_idx = logits.topk(self.k, dim=-1)            # (B, k) each

        # softmax over just the k chosen logits — equivalent to full softmax
        # followed by renormalization, but numerically nicer and cheaper
        top_gates = F.softmax(top_logits, dim=-1)                    # (B, k)
        return top_idx, top_gates

torch.manual_seed(0)
router = TopKRouter(d_model=16, n_experts=8, k=2, noise_std=0.0)
x = torch.randn(4, 16)
idx, gates = router(x, train=False)
print("top-k idx\n", idx)
print("top-k gates (renormalized)\n", torch.round(gates, decimals=2))
stdout
top-k idx
 tensor([[5, 2],
         [3, 7],
         [5, 1],
         [2, 4]])
top-k gates (renormalized)
 tensor([[0.64, 0.36],
         [0.57, 0.43],
         [0.52, 0.48],
         [0.55, 0.45]])
numpy → pytorch

np.argpartition(-gates, k)[:, :k]   ←→   logits.topk(k, dim=-1)
    returns values + indices in one call; GPU-native

gates / gates.sum(..., keepdims=True)   ←→   F.softmax(top_logits, dim=-1)
    softmax over the k chosen logits ≡ renormalized gates

W_r = np.random.randn(d, N)   ←→   nn.Linear(d, N, bias=False)
    tracked for autograd, weights live on the right device

Gotchas

Forgetting to renormalize: if you mask out the non-top-k gates but don't renormalize, your total gate mass drops below 1 and the MoE layer scales its output down by a random factor for every token. The bouncer waved the right names through but forgot to add up to 1. The residual stream starts drifting and you blame the optimizer.
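This first bug is easy to reproduce. A toy repro using the logits from the pure-Python example — mask without renormalizing and the effective gate mass quietly drops below 1:

```python
import numpy as np

logits = np.array([0.8, -0.2, 1.5, 0.3, -1.1, 2.1, 0.0, 0.9])
g = np.exp(logits - logits.max())
g /= g.sum()                                     # full softmax, sums to 1

top = np.argsort(g)[-2:]                         # top-2 experts (2 and 5)
masked = np.zeros_like(g)
masked[top] = g[top]                             # the bug: mask, no renormalize

print(round(masked.sum(), 2))                    # 0.61 — output scaled down ~40%
print(round((g[top] / g[top].sum()).sum(), 2))   # 1.0  — the fix
```

Every token gets its output shrunk by a different factor depending on how much mass the dropped experts held — exactly the random drift described above.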

Dense routing at inference: a classic copy-paste bug. You train with top-k, evaluate with the whole softmax (because you switched to model.eval() and forgot a branch). Suddenly inference is N/k× slower and the metrics look great because every expert is helping. Check the FLOPs at eval time.

Dropping noisy top-k too early: without noise the bouncer locks onto a single name within a few hundred steps. Keep σ > 0 for at least the first epoch, then decay toward zero. If you see a wildly unbalanced expert-usage histogram in the first 1k steps, the noise is almost always the culprit.

k = N: the soft-routing trap. Someone on the team “just wanted to try dense” and set k = n_experts. Congratulations, you removed the bouncer — every expert runs on every token, you pay the full N · C_expert per token, and the sparsity you built all this machinery for is gone.

Watch an expert eat the batch

Build a Switch-style MoE (k=1) with 8 experts and 64-dim tokens. Use the TopKRouter above but set noise_std=0.0 to disable noisy top-k. Train it on a toy sequence-modeling task (or even random targets) for 1000 steps.

At every step, compute the fraction of tokens in the batch routed to each of the 8 experts — call it load[i]. Plot load over training. Without noise and without a load-balancing loss, you will almost certainly watch one name on the list see its load rise to near 1.0 while the other seven starve. That's router collapse, and it's what the next lesson exists to fix.

Bonus: re-run with noise_std=1.0. The collapse softens. Re-run with noise plus a load-balancing penalty (a preview of the next lesson) and the loads stay roughly uniform across experts for the whole run.
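A sketch of the load measurement for this exercise — numpy for clarity, but torch.bincount works the same way inside the training loop. The helper name and the toy batch are illustrative, not part of the lesson's code:

```python
import numpy as np

def expert_load(top_idx, n_experts):
    """Fraction of routed slots assigned to each expert this step."""
    counts = np.bincount(top_idx.ravel(), minlength=n_experts)
    return counts / counts.sum()

# a collapsing k=1 batch: expert 5 already eating almost everything
idx = np.array([5, 5, 2, 5, 5, 5])
load = expert_load(idx, n_experts=8)
print(load)                                  # load[5] ≈ 0.83, seven experts starving
```

Log this vector every step and stack it into a (steps, 8) array — the collapse shows up as one column marching toward 1.0.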

What to carry forward. The router is the bouncer at the door: a single Linear(d_model, n_experts) plus a softmax. Top-k picks the k names it waves through, renormalizes their gates, and routes each token to only those experts. k=1 is Switch (a hard pick), k=2 is the current default (top expert plus a backup), k>2 and the room gets crowded while the “free lunch” disappears. Noisy top-k is non-optional early in training — without it the bouncer collapses onto a single name. The two most common bugs are forgetting to renormalize the chosen gates, and accidentally running dense routing at inference because you forgot the top-k branch.

Next up — Load Balancing Loss. Noisy top-k buys you exploration, but it doesn't guarantee the names on the list end up with similar workloads. The load-balancing auxiliary loss does: a small penalty that pushes the bouncer toward a roughly uniform distribution over experts. Without it a few experts carry the whole model and the rest are dead weight. With it you get what MoE actually promises — N experts, each learning its own specialty. After that, one more lesson on expert parallelism — how to shard the room itself across GPUs — and this section closes. Then: a whole new room. Diffusion Models, starting with denoising intuition.

References