Expert Parallelism

Distributing experts across GPUs at training time.

Hard · ~15 min read · lesson 4 of 4

You have 64 experts. You have 64 GPUs. Where, exactly, do the experts live?

Picture the cluster as a network of post offices, each hosting a handful of specialists. In ordinary data-parallel training every GPU is the same post office: same staff, same stamps, same everything. That's the naive answer to the question above: put a copy of all 64 experts on every GPU. It works, it needs no new plumbing, and it defeats the entire point of MoE. The whole appeal of sparse experts was that you could scale parameters without scaling the compute each token pays for, but if every post office stocks every specialist, every GPU still pays the full memory bill for all of them. You're right back where you started, only now with a router duct-taped on top.

The other answer: each post office holds different specialists. One expert per GPU. Now the memory math works — each card carries 1/64 of the expert weights. But a token that wants expert 37 while sitting on GPU 12 has a problem. The letter is in the wrong building. Someone has to move it. This lesson is about the shape of that tradeoff, the primitive that routes the mail (all-to-all), and the 2D/3D sharding puzzles that real MoE training solves every day.

Expert parallelism (personified)
I split the experts across your GPUs so none of them have to hold all the weights. The catch is that I move tokens over the network instead — and the network is slower than your GPU, always. Whether I'm worth it depends on how many experts you have and how big the batch is.

Let's make the memory savings concrete. You have E experts, each with P parameters. You have W GPUs (the world size). The two extremes:

expert memory — data parallel vs expert parallel
data parallel:      per-GPU params  =  E · P
                                         (every GPU stores every expert)

expert parallel:    per-GPU params  =  (E / W) · P
                                         (each GPU stores its shard only)

savings ratio:      W-fold reduction in expert-weight memory

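For Mixtral 8×7B on 8 GPUs: despite the name, each expert is not a 7B model. The experts are the feed-forward blocks, about 5.6B parameters per expert summed over all 32 layers, which adds up to roughly 45B of the model's ~47B total. Data parallel wants all ~45B of expert weights on every card. Expert parallel puts expert i of every layer on GPU i, about 5.6B parameters per card. Same model, one-eighth the per-GPU memory for the feed-forward stack. That's the entire reason anyone bothers with any of this, and the reason each post office gets its own short list of specialists instead of the full roster.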
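The counts above can be checked with a few lines of arithmetic. The hyperparameters below (d_model = 4096, a 14,336-wide SwiGLU feed-forward, 32 layers, 8 experts) are taken from Mixtral's public model config as we understand it, so treat them as assumptions. The sketch counts expert weights only and ignores the shared attention and embedding parameters.

```python
# Expert-weight memory for a Mixtral-8x7B-shaped model.
# Assumed config values (Mixtral's published config, as we understand it):
d_model = 4096        # hidden size
d_ff = 14336          # expert FFN intermediate size
n_layers = 32
n_experts = 8
W = 8                 # GPUs in the expert-parallel group

# A SwiGLU expert has three d_model x d_ff matrices (gate, up, down).
params_per_expert_per_layer = 3 * d_model * d_ff
params_per_expert = params_per_expert_per_layer * n_layers   # expert i, summed over all layers
all_expert_params = params_per_expert * n_experts

dp_per_gpu = all_expert_params          # data parallel: every GPU stores every expert
ep_per_gpu = all_expert_params // W     # expert parallel: n_experts / W experts per GPU

bytes_bf16 = 2
print(f"expert params, total:      {all_expert_params / 1e9:.1f}B")
print(f"per GPU, data parallel:    {dp_per_gpu / 1e9:.1f}B ({dp_per_gpu * bytes_bf16 / 2**30:.0f} GiB in bf16)")
print(f"per GPU, expert parallel:  {ep_per_gpu / 1e9:.1f}B ({ep_per_gpu * bytes_bf16 / 2**30:.1f} GiB in bf16)")
```

On these assumptions it reports about 45.1B expert parameters in total: 84 GiB per GPU in bf16 under data parallelism versus 10.5 GiB under expert parallelism.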

[Figure: expert parallelism, experts live on different GPUs and tokens travel. 8 experts on 4 GPUs (2 experts/GPU) exchange tokens through an all-to-all dispatch + combine; thicker arrows carry more tokens. Each GPU sends (G − 1)/G of its tokens elsewhere (75% inter-GPU traffic at G = 4), and communication per token ≈ 2 · d_model words (dispatch + combine).]

Read the diagram left to right. Four GPUs, eight experts, two experts per GPU: four post offices, each holding two specialists. A batch of tokens lands on GPU 0. The router reads each token and writes a destination address on the envelope: maybe token t₀ routes to expert 3 (on GPU 1), token t₁ to expert 7 (on GPU 3), and so on. GPU 0 only hosts experts 0 and 1, so neither letter can be delivered locally, and no single GPU hosts both destinations.

So before the experts can do anything, every GPU must mail its letters to whichever GPU holds their chosen specialist. Every post office receives letters from every other post office, has its resident specialists reply to the mixed-origin stack, and then mails the replies back to their original senders. That round trip (one all-to-all out, one back) is the defining primitive of expert parallelism, and the one you'll hear cursed about in every distributed-training channel.

All-to-all (personified)
I am the mail truck. Every post office hands me a sack of letters; I redistribute them so each letter ends up at the GPU that owns its specialist. Then I do it again, in reverse, to return the replies. Two all-to-alls per MoE layer, per forward pass, per backward pass. If your interconnect is slow, I am why your training job is slow.

Time for the cost model. Let B be the number of tokens per GPU, d the model dimension, and W the world size, with one expert per GPU and evenly balanced routing. In a single all-to-all each GPU sends about B / W tokens (B · d / W words) to each of the other W − 1 GPUs, so its outgoing volume is B · d · (W − 1) / W words: that rises with W at first but saturates near B · d as the world gets wide. Tokens are conserved, so each GPU also receives about B tokens for its local expert.

comm cost vs compute cost per MoE layer
expert compute per GPU:    T_compute   ~  (B · d²) / throughput
                                              (one expert's FFN over the ~B tokens it receives)

all-to-all comm per GPU:   T_comm      ~  (B · d · (W − 1) / W) / bandwidth

ratio:    T_comm / T_compute   ~   ((W − 1) / W) · throughput / (d · bandwidth)

bigger W  →  more of the traffic leaves the node, effective bandwidth falls, comm grows.
bigger d  →  compute grows faster than comm, which favors expert parallel at scale.

The punchline: while the whole expert-parallel group sits inside one NVLink domain, compute dominates and expert parallel is a nearly free memory win. As W grows past a single node, more of every all-to-all crosses the much slower inter-node fabric, and T_comm climbs. Somewhere around W ≈ 64 on typical NVLink + InfiniBand clusters (the exact point depends on d, the batch size, and the interconnect) the two cross over and you're comm-bound: your GPUs sit idle, waiting for packets. Pushing to W = 256 doesn't make training faster; it just makes the mail system hotter.

[Figure: all-to-all vs MLP compute, where does the comm tax win? fp16, 600 GB/s, 400 TFLOP/s. Per-token latency (log axes, 1 ns to 10 µs) against d_model (512 to 32,768): MLP compute (8 · d² FLOPs) grows as d², all-to-all (2d words) grows as d. Breakdown at d_model = 4,096 with G = 8: compute ≈ 8 · d² = 1.34e8 FLOPs → 335.5 ns; comm ≈ 2 · (G − 1)/G · d = 2 · 4,096 × 7/8 words → 23.9 ns; comm/compute ≈ 0.07 (7%). Takeaway: compute-bound; because compute grows as d², bigger models hide the all-to-all well.]

This is where the post-office metaphor earns its keep: the specialists are fast, but the mail trucks are the bottleneck. The stacked bars tell the story. At 8 experts, all inside one NVLink node, the blue compute column towers over the red comm sliver; mail volume is trivial compared to how long the specialists take to reply. By 64, with most letters now crossing InfiniBand between nodes, the sliver is a stripe. By 256 the stripe has eaten the stack: you're spending more wall-clock shuffling letters than running them through an MLP. Past that point, no number of extra post offices will speed you up, because the limit isn't compute, it's the wire between buildings.
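The figure's per-token numbers can be reproduced with a few lines of arithmetic, and swapping in a slower link shows the crossover described above. This sketch keeps the figure's own conventions (8 · d² FLOPs per token, 2 · (G − 1)/G · d words of traffic, fp16). The 50 GB/s inter-node bandwidth is an assumption, roughly one 400 Gb/s NIC per GPU, not a measurement.

```python
# Per-token cost model from the figure (assumed hardware numbers, not measurements).
def per_token_costs(d_model, G, bandwidth_Bps, flops_per_s, bytes_per_word=2):
    """Return (compute_seconds, comm_seconds) for one token through one MoE layer.

    compute: 8 * d^2 FLOPs, the figure's FFN cost convention
    comm:    2 * (G - 1)/G * d words, dispatch + combine, counting only the
             share of tokens that leaves the GPU
    """
    compute_s = 8 * d_model**2 / flops_per_s
    comm_words = 2 * (G - 1) / G * d_model
    comm_s = comm_words * bytes_per_word / bandwidth_Bps
    return compute_s, comm_s

# The figure's cursor: d = 4096, G = 8, fp16 over a 600 GB/s NVLink-class link.
compute, comm = per_token_costs(4096, 8, 600e9, 400e12)
print(f"NVLink-class: compute {compute*1e9:.1f} ns, comm {comm*1e9:.1f} ns, ratio {comm/compute:.2f}")

# Same layer spread over 64 GPUs, assuming traffic rides a 50 GB/s inter-node link.
compute, comm = per_token_costs(4096, 64, 50e9, 400e12)
print(f"inter-node:   compute {compute*1e9:.1f} ns, comm {comm*1e9:.1f} ns, ratio {comm/compute:.2f}")
```

The first line matches the figure (335.5 ns vs 23.9 ns, ratio 0.07). The second prints a comm cost of 322.6 ns against the same 335.5 ns of compute, a ratio of 0.96: the layer has gone from a 7% comm tax to near parity because the link got slower, not because the payload got bigger.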

Sharding strategy (personified)
I'm the 3D puzzle you solve before every training run. One axis is data parallel, one is expert parallel, one is pipeline parallel. The choice is constrained by memory, comm bandwidth, and the physical topology of your cluster. Get me wrong and your $10M run burns a week idle. Get me right and nobody ever thanks me.

The hybrid layout in practice: split your W GPUs into a grid of W = DP · EP · PP, where DP is data parallel (model replicas), EP is expert parallel (experts sharded), and PP is pipeline parallel (layers split across stages). For a 64-GPU cluster you might pick 8 × 8 × 1: DP = 8 replicas, EP = 8 (experts sharded eight ways), no pipeline. For a 1024-GPU cluster you might pick 32 × 16 × 2: DP = 32, EP = 16, PP = 2. Each axis has its own comm primitive: all-reduce for DP, all-to-all for EP, send/recv for PP.
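One way to picture the grid is to assign every rank a (dp, ep, pp) coordinate and collect the ranks that share all but one coordinate into a group. The rank ordering below, with EP as the fastest-varying axis so each EP group is a block of consecutive ranks (typically on the same host), is a hypothetical choice for illustration, not any framework's documented layout.

```python
# Hypothetical rank layout for W = DP * EP * PP, EP varying fastest.
def build_groups(DP, EP, PP):
    def rank(dp, ep, pp):
        return (pp * DP + dp) * EP + ep

    # EP groups: same (dp, pp), all ep values -> the all-to-all runs here
    ep_groups = [[rank(dp, ep, pp) for ep in range(EP)]
                 for pp in range(PP) for dp in range(DP)]
    # DP groups: same (ep, pp), all dp values -> gradient all-reduce runs here
    dp_groups = [[rank(dp, ep, pp) for dp in range(DP)]
                 for pp in range(PP) for ep in range(EP)]
    # PP groups: same (dp, ep), all pp values -> send/recv between stages
    pp_groups = [[rank(dp, ep, pp) for pp in range(PP)]
                 for dp in range(DP) for ep in range(EP)]
    return ep_groups, dp_groups, pp_groups

ep_groups, dp_groups, pp_groups = build_groups(DP=8, EP=8, PP=1)
print("first EP group:", ep_groups[0])
print("first DP group:", dp_groups[0])
print("groups per axis:", len(ep_groups), len(dp_groups), len(pp_groups))
```

For the 8 × 8 × 1 example the first EP group is ranks 0 to 7 and the first DP group is ranks 0, 8, 16, ..., 56. In PyTorch each list would typically become a process group via dist.new_group(ranks=...); Megatron-LM and DeepSpeed build their own groups with their own orderings, so check your framework's convention rather than assuming this one.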

There's no closed-form optimum. Megatron-LM, DeepSpeed-MoE, and Mixtral's internal stack each ship empirical recipes for which combinations work on which hardware. But the pattern is always the same: minimize the slowest comm, keep the chattiest traffic (usually the EP all-to-all) on the fastest links, and overlap communication with computation where possible.

Three versions. First, a single-process Python simulation that acts out the mailroom with lists — each GPU sorts its letters by destination address. Then NumPy, where the experts are real matrix multiplies and we count the bytes the mail truck would have to carry. Then the real thing: PyTorch with torch.distributed, the call that Megatron actually makes.

layer 1 — pure python · expert_parallel_sim.py
python
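# four GPUs, four experts, one per GPU. simulate the mailroom.
W = 4
tokens_per_gpu = [
    [('t0', 1), ('t2', 3)],   # (token, chosen_expert_id) on GPU 0
    [('t3', 0), ('t1', 2)],   # GPU 1
    [('t0', 3), ('t2', 1)],   # GPU 2
    [('t3', 0), ('t1', 2)],   # GPU 3
]

# Step 1: each GPU buckets its tokens by destination expert's GPU.
send_buffers = [{d: [] for d in range(W)} for _ in range(W)]
for src, toks in enumerate(tokens_per_gpu):
    for name, exp_id in toks:
        dest_gpu = exp_id           # one expert per GPU, so expert_id == gpu_id
        send_buffers[src][dest_gpu].append(name)
    print(f"GPU {src} sends:", {k: v for k, v in send_buffers[src].items() if v})

# Step 2: the all-to-all. Every GPU receives whatever was sent to it,
# tagged with the sender's id so the reply can find its way home.
recv_buffers = [[] for _ in range(W)]
for src in range(W):
    for dst in range(W):
        recv_buffers[dst].extend((src, name) for name in send_buffers[src][dst])

# Step 3: each GPU runs its one expert over its received tokens.
# A string like "E1(t0)" stands in for expert 1's output on token t0.
replies = [[(src, name, f"E{g}({name})") for src, name in recv_buffers[g]]
           for g in range(W)]

# Step 4: the reverse all-to-all mails each reply back to its sender,
# which then restores its original token order (the shuffle scrambled it).
returned = [{} for _ in range(W)]
for g in range(W):
    for src, name, out in replies[g]:
        returned[src][name] = out
outputs = [[returned[src][name] for name, _ in toks]
           for src, toks in enumerate(tokens_per_gpu)]
print("after expert compute on each GPU → all-to-all back → outputs reassembled")
print("GPU 2 outputs:", outputs[2])
stdout
GPU 0 sends: {1: ['t0'], 3: ['t2']}
GPU 1 sends: {0: ['t3'], 2: ['t1']}
GPU 2 sends: {1: ['t2'], 3: ['t0']}
GPU 3 sends: {0: ['t3'], 2: ['t1']}
after expert compute on each GPU → all-to-all back → outputs reassembled
GPU 2 outputs: ['E3(t0)', 'E1(t2)']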

Vectorize it. Replace token names with vectors, replace the mailroom with boolean indexing and np.concatenate, and count the bytes that would move across the wire. That's your comm budget: the postage on every letter the all-to-all sends.

layer 2 — numpy · expert_parallel_numpy.py
python
import numpy as np

W, B, d = 4, 128, 512       # 4 GPUs, 128 tokens/GPU, model dim 512
rng = np.random.default_rng(0)
expert_weights = [rng.standard_normal((d, d), dtype=np.float32) for _ in range(W)]
tokens = [rng.standard_normal((B, d), dtype=np.float32) for _ in range(W)]

# Pretend routing, perfectly balanced: each GPU sends exactly B/W tokens
# to every expert, in shuffled order.
assignments = [rng.permutation(np.arange(B) % W) for _ in range(W)]

# ─── all-to-all forward ────────────────────────────────────────
# send[s][g]: the rows of GPU s's tokens routed to expert g (a boolean index)
send = [[tokens[s][assignments[s] == g] for g in range(W)] for s in range(W)]
# recv[g]: everything every GPU sent to GPU g, stacked into one batch
recv = [np.concatenate([send[s][g] for s in range(W)]) for g in range(W)]

# ─── expert compute ────────────────────────────────────────────
outputs_local = [recv[g] @ expert_weights[g] for g in range(W)]

# ─── comm accounting (GPU 0's view) ────────────────────────────
send_bytes = sum(send[0][g].nbytes for g in range(W))               # incl. the local slice
off_gpu_bytes = sum(send[0][g].nbytes for g in range(W) if g != 0)  # actually crosses the wire
n = recv[0].shape[0]
print(f"per-GPU send volume: {send_bytes / 1e6:.3f} MB "
      f"({off_gpu_bytes / 1e6:.3f} MB leaves the GPU)")
print(f"per-GPU expert compute: {recv[0].shape} x ({d}, {d}) = "
      f"2 · {n} · {d*d} = {2 * n * d * d:,} FLOPs")
stdout
per-GPU send volume: 0.262 MB (0.197 MB leaves the GPU)
per-GPU expert compute: (128, 512) x (512, 512) = 2 · 128 · 262144 = 67,108,864 FLOPs
pure python → numpy
send_buffers[src][dst].append(name)  ←→  send[s][g] = tokens[s][mask]
    bucketing becomes a boolean index
recv_buffers[dst].extend(...)  ←→  np.concatenate([send[s][g] for s in range(W)])
    the actual all-to-all: assemble the incoming pieces
expert(token) one at a time  ←→  recv[g] @ expert_weights[g]
    one matmul per GPU over its fused incoming batch

And the real thing. In PyTorch the call is literally dist.all_to_all_single. The token shuffle that took a nested loop in pure Python collapses into one collective (plus a tiny one that exchanges counts), running on NCCL over NVLink and, if you're careful, overlapping with the backward pass of the previous layer. The mail truck becomes a single function call.

layer 3 — pytorch · expert_parallel_dist.py
python
import os
import torch
import torch.distributed as dist

# Each rank = one GPU = one expert. Launched with torchrun --nproc_per_node=W.
dist.init_process_group(backend='nccl')
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))   # GPU index on this host
torch.manual_seed(rank)                               # different tokens/routes per rank

B, d = 128, 512                       # B must be divisible by world_size here
tokens = torch.randn(B, d, device='cuda')
expert = torch.nn.Linear(d, d).cuda()
# Pretend routing, perfectly balanced: exactly B / world_size tokens per rank.
route = torch.randperm(B, device='cuda') % world_size

# 1. Sort tokens by destination rank → packed contiguous buffer.
order = route.argsort()
tokens_sorted = tokens[order]
send_counts = torch.bincount(route, minlength=world_size)

# 2. Tell every rank how many tokens to expect from every other rank.
recv_counts = torch.empty_like(send_counts)
dist.all_to_all_single(recv_counts, send_counts)

# 3. The main event: actually shuffle the tokens.
recv_buf = torch.empty(int(recv_counts.sum()), d, device='cuda')
dist.all_to_all_single(
    recv_buf, tokens_sorted,
    output_split_sizes=recv_counts.tolist(),
    input_split_sizes=send_counts.tolist(),
)
print(f"[rank {rank}] local tokens after dispatch: {recv_buf.shape}")

# 4. Each rank runs its local expert over its received tokens.
#    (forward-only demo: these collectives do not propagate gradients)
with torch.no_grad():
    out = expert(recv_buf)
print(f"[rank {rank}] expert output shape: {out.shape}")

# 5. The reverse all-to-all sends outputs home: same call, split sizes swapped.
combined = torch.empty_like(tokens_sorted)
dist.all_to_all_single(
    combined, out,
    output_split_sizes=send_counts.tolist(),
    input_split_sizes=recv_counts.tolist(),
)

# 6. Undo the argsort so each output lines up with its original token.
result = torch.empty_like(combined)
result[order] = combined
print(f"[rank {rank}] local tokens after combine: {result.shape}")
dist.destroy_process_group()
stdout
[rank 0] local tokens after dispatch: torch.Size([128, 512])
[rank 0] expert output shape: torch.Size([128, 512])
[rank 0] local tokens after combine: torch.Size([128, 512])
numpy → pytorch distributed
np.concatenate([send[s][g] for s in range(W)])  ←→  dist.all_to_all_single(recv, send, ...)
    one NCCL call that runs on the GPU and can overlap with compute (async_op=True)
manual byte counting for comm budget  ←→  torch.profiler / Nsight traces
    measure real wall-clock comm, not hand-waved bandwidth
single-process simulation  ←→  torchrun --nproc_per_node=W
    one OS process per GPU, coordinated by NCCL

Now the inconvenient truth about inference. During training you have big batches — you can amortize comm over thousands of letters at once, one big mail run instead of many small ones. During generation, you're producing one token at a time per user, and the MoE block still has to do both all-to-alls every layer, every token. The mail truck runs on schedule whether it's carrying one letter or a thousand.

The consequence is that MoE models are often surprisingly slow per token at inference compared with a dense model of similar active parameter count. Mixtral 8×7B has about 13B active params but, when served with expert parallelism, routes through two all-to-alls in every block. Summed over its 32 blocks, those network hops can add milliseconds per token that a dense 13B simply doesn't pay. Batched decoding, which packs many users' tokens into each all-to-all, exists partly to claw those milliseconds back.
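The "mail truck runs on schedule" point can be made concrete with a latency-plus-bandwidth (alpha-beta) model: every all-to-all pays a fixed startup cost α no matter how few letters are on the truck, plus payload / β for the bytes themselves. The α and β values below are hypothetical placeholders; plug in numbers measured on your own cluster.

```python
# Alpha-beta (latency + bandwidth) cost model for one all-to-all.
# ALPHA and BETA are illustrative assumptions, not measurements of any real cluster.
ALPHA = 10e-6        # fixed per-collective latency in seconds (hypothetical)
BETA = 50e9          # bytes per second per GPU (hypothetical inter-node link)

def all_to_all_time(n_tokens, d_model=4096, bytes_per_word=2):
    """Seconds for one all-to-all carrying n_tokens activations of width d_model."""
    payload = n_tokens * d_model * bytes_per_word
    return ALPHA + payload / BETA

for n in (1, 16, 1024):
    t = 2 * all_to_all_time(n)          # dispatch + combine, one MoE layer
    print(f"{n:5d} tokens: {t*1e6:8.2f} µs per layer, {t/n*1e9:10.1f} ns per token")
```

With these placeholder numbers a lone decoding token pays close to 60 times more per-token comm than a token in a 1024-token batch, because the fixed α dominates when the payload is tiny.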

Gotchas

Order is not preserved: the all-to-all shuffle scrambles tokens. You must remember the original permutation (we did this with argsort above) and invert it on the return trip, or replies come back attached to the wrong residual — the right specialist wrote back to the wrong sender.

Load imbalance kills you: if one expert is popular and another is dead, the post offices holding them finish at wildly different times. The whole cluster waits for the slowest one. This is exactly why the auxiliary load balancing term is non-negotiable in expert-parallel training — the HR manager keeps every specialist on roughly the same mail volume.

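Mixed precision and comm overlap: sending the all-to-all payload in bfloat16 instead of float32 halves the bytes, and roughly halves the comm time when the transfer is bandwidth-bound. Overlapping the all-to-all with the next layer's compute (launch it with async_op=True and wait on the returned handle only when the result is needed) can hide much of the network bill, but only if you've arranged the graph for it.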
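The load-imbalance gotcha reduces to one line of arithmetic. Assuming one expert per GPU and compute time proportional to the tokens each expert receives (a simplification), the step takes as long as the busiest expert, while perfectly balanced work would take the mean. The routing counts below are hypothetical.

```python
import numpy as np

def step_efficiency(tokens_per_expert):
    """Fraction of the cluster's expert compute doing useful work in one step.

    Every GPU waits for the busiest expert, so step time ~ max(load),
    while perfectly balanced work would take ~ mean(load).
    """
    load = np.asarray(tokens_per_expert, dtype=float)
    return load.mean() / load.max()

balanced = [256, 256, 256, 256]
skewed = [640, 128, 128, 128]     # one popular expert (hypothetical routing counts)
print(f"balanced: {step_efficiency(balanced):.2f}")
print(f"skewed:   {step_efficiency(skewed):.2f}")
```

The skewed case prints 0.40: for that step, 60% of the cluster's expert capacity sits idle waiting on the popular expert, which is the cost the load-balancing term is there to avoid.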

Measure the network tax yourself

Spin up a 2-GPU node (one host, two cards). Create 8 experts each of shape (1024, 1024) in float32. Distribute them 4 per GPU. Send a batch of 1024 tokens through one MoE block, using dist.all_to_all_single for the shuffle.

Time the all-to-all by itself. Time the expert compute by itself (one linear per GPU over the received batch). Print the ratio. On a single-host NVLink setup you'll likely see T_comm / T_compute < 0.1 — the mail truck is almost free when both post offices share a building. Now repeat across two hosts over your cluster's Ethernet/IB, and watch the ratio climb by one or two orders of magnitude.

Bonus: pass async_op=True to the all-to-all, issue a dummy matmul on the GPU while it runs, and measure the overlap win.

What to carry forward. Expert parallelism is the answer to the memory question that plain MoE posed. Each GPU is a post office, each post office stocks a different specialist, and the all-to-all is the mail truck that routes letters to the right address and replies back to the sender. You get a W-fold reduction in per-GPU expert memory, at the cost of two all-to-all shuffles per MoE layer. Below ~64 experts the compute dominates and EP is nearly free; above, the mail system becomes the bottleneck and you must hybridize with data + pipeline parallelism. Modern frameworks (Megatron-LM, DeepSpeed-MoE) wrap all of this, but the underlying primitive is always a well-scheduled all-to-all.

End of the MoE section. You now have the full picture: why MoE exists (a panel of specialists), how the bouncer picks who sees each token (top-k routing), how the HR manager keeps nobody idle (load balancing), and how the post offices mail letters to each other when the specialists live on different GPUs (this lesson). Next — Denoising Intuition. A completely different paradigm. Forget sparse routers for a minute: we're going to train a network to take a blurry, corrupted image and undo one step of the noise. Do that thousands of times and you can start from pure static and end at a photograph. The math looks nothing like what you've seen, but the intuitions — gradients, loss, training loop — all carry over intact.