Training Loop
Forward, loss, backward, step — the four-line core.
Every model that has ever been trained, anywhere, runs on a four-beat metronome. Forward. Loss. Backward. Step. One, two, three, four, repeat. A tiny MLP learning to fit a line. A 70B-parameter LLM burning through a GPU pod for six weeks. Same four beats. Same order. Different scale.
You've already built every one of those beats by hand. The forward pass is the model computing a prediction. The loss is one scalar saying how wrong that prediction is. The backward pass is backpropagation walking the chain rule to fill .grad. The step is gradient descent nudging the parameters against their gradients. This lesson is the ritual that wires them together into something you can run a million times without thinking about it.
The training loop is the heartbeat of modern ML, and not incidentally: structurally, it is the only computation that matters. Schedulers, gradient clipping, mixed precision, distributed backends, checkpointing — all of it is ornamentation on the metronome's four beats. The rhythm is identical across every architecture, every framework, every scale.
```python
for epoch in range(EPOCHS):              # 1. outer loop — passes over the dataset
    for batch in dataloader:             # 2. inner loop — one batch at a time
        optimizer.zero_grad()            #    clear .grad buffers
        yhat = model(batch.x)            #    forward pass (build graph)
        loss = criterion(yhat, batch.y)  #    scalar loss
        loss.backward()                  #    walk the graph, fill .grad
        optimizer.step()                 #    θ ← θ − α·∇L
```
One epoch = one full pass over the training set.
One step = one batch processed = one parameter update.
If the dataset has N examples and the batch size is B, one epoch is ⌈N/B⌉ steps.

Press play below. A tiny linear regression trains itself in your browser, one batch at a time, following the four-beat pattern above exactly. The sparkline shows per-batch loss; the dashed line is the epoch-averaged value. Shift the batch size and the curve smoothness changes — which is the topic of the next widget.
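The step arithmetic fits in one line of Python (the sizes here are illustrative):

```python
import math

N, B = 1000, 64                      # dataset size, batch size (illustrative)
steps_per_epoch = math.ceil(N / B)   # the last, partial batch still counts as a step
print(steps_per_epoch)               # → 16
```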
I do not care what your model is. I do not care whether it has six parameters or seventy billion. I do one thing: draw a batch, run forward, compute loss, run backward, step the optimizer. I will do this a million times without complaint. Give me good data, a sensible learning rate, and a working forward pass, and I will produce a trained model. Get any of those wrong and I will produce garbage, with equal enthusiasm.
Notice what's moving inside each beat of the metronome. The model never sees the whole dataset at once; it sees a batch. Processing batches — rather than computing the exact gradient over the entire training set — comes down to three things, each decisive on its own:
- Memory. A modern LLM's full dataset is trillions of tokens; the activations for even a thousand examples won't fit on a GPU. Batches are what's physically possible.
- Compute throughput. GPUs are matrix processors. Processing 64 examples simultaneously is barely slower than processing one, because the matmul machinery is saturated either way. Full-batch training leaves hardware idle.
- Generalization. The noise introduced by sampling a different batch each step is, empirically, a regulariser. Stochastic gradients help networks escape sharp local minima that generalise poorly.
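The throughput point has a simple algebraic core, sketched here in NumPy with illustrative sizes: a whole-batch matmul computes exactly what a per-example loop computes, in a single call to the matrix machinery.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(10, 784))        # a layer's weights (toy sizes)
X = rng.normal(size=(64, 784))        # a batch of 64 examples

# One example at a time: 64 separate matrix-vector products
looped = np.stack([W @ x for x in X])

# The whole batch at once: a single matrix-matrix product
batched = X @ W.T

assert np.allclose(looped, batched)   # identical result, one kernel call
```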
Full-batch gradient (what we want, can't afford):
∇L(θ) = (1/N) · Σ_{i=1..N} ∇ℓ(θ; xᵢ, yᵢ)
Minibatch gradient (what we actually compute):
∇L_B(θ) = (1/|B|) · Σ_{i ∈ B} ∇ℓ(θ; xᵢ, yᵢ)
Key property: E[∇L_B] = ∇L when B is drawn uniformly at random. So each minibatch step is, in expectation, a step in the true descent direction, plus mean-zero noise whose variance shrinks as the batch size grows.

The variance of the minibatch gradient is proportional to 1 / |B|: double the batch size and the noise halves. The curves below make it visceral: three training runs on the same loss surface, different batch sizes.
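The 1/|B| scaling can be checked numerically. This sketch treats per-example gradients as i.i.d. scalar draws, a simplifying assumption (real per-example gradients are correlated through θ), but the variance-of-a-mean arithmetic is the same:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for per-example gradients: mean 1.0, per-example variance σ² = 4
per_example_grads = rng.normal(loc=1.0, scale=2.0, size=100_000)

for B in (4, 16, 64):
    # The minibatch gradient is a mean of B samples; its variance is σ²/B
    usable = per_example_grads[: (len(per_example_grads) // B) * B]
    batch_means = usable.reshape(-1, B).mean(axis=1)
    print(B, batch_means.var())   # ≈ 4 / B: about 1.0, 0.25, 0.0625
```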
I am the single hyperparameter that changes everything underneath and nothing on top. Pick me too small and you are noise-limited. Pick me too large and you burn memory and lose the regularisation from stochasticity. Pick me a power of two between 32 and 512 for vision, between 1M and 4M tokens for LLMs, and you will be fine.
Now write the metronome from scratch. Three layers, same four beats at each level of the stack you'll meet in practice. The pure Python version does every beat by hand; the NumPy version vectorises the arithmetic; the PyTorch version cedes the bookkeeping to the library and leaves you with the four-line core you'll write from memory for the rest of your career.
```python
import random

random.seed(0)

# Toy dataset: y = 1.7x - 0.4 + noise
data = [(x := random.gauss(0, 1), 1.7 * x - 0.4 + random.gauss(0, 0.3))
        for _ in range(128)]

w, b = 0.0, 0.0
lr = 0.05
BATCH = 16

for epoch in range(5):
    random.shuffle(data)                 # ← critical
    epoch_loss = 0.0
    for start in range(0, len(data), BATCH):
        batch = data[start:start + BATCH]
        # zero_grad + forward + loss + backward + step, all by hand
        gw = 0.0; gb = 0.0; loss = 0.0
        for x, y in batch:
            yhat = w * x + b             # forward
            err = yhat - y               # loss pre-sum
            loss += err * err            # MSE
            gw += 2 * err * x            # backward
            gb += 2 * err
        loss /= len(batch); gw /= len(batch); gb /= len(batch)
        w -= lr * gw                     # step
        b -= lr * gb
        epoch_loss += loss
    epoch_loss /= (len(data) / BATCH)
    print(f"epoch {epoch} batch_avg_loss={epoch_loss:.4f}")
```

```
epoch 0 batch_avg_loss=1.3245
epoch 1 batch_avg_loss=0.7123
epoch 2 batch_avg_loss=0.3881
epoch 3 batch_avg_loss=0.2094
epoch 4 batch_avg_loss=0.1138
```
```python
import numpy as np

rng = np.random.default_rng(0)
N, BATCH, EPOCHS, LR = 128, 16, 5, 0.05
x = rng.normal(size=(N,))
y = 1.7 * x - 0.4 + rng.normal(scale=0.3, size=(N,))

w, b = np.array(0.0), np.array(0.0)

for epoch in range(EPOCHS):
    perm = rng.permutation(N)            # shuffle each epoch
    losses = []
    for start in range(0, N, BATCH):
        idx = perm[start:start + BATCH]
        xb, yb = x[idx], y[idx]
        yhat = w * xb + b                # forward (vectorised)
        err = yhat - yb
        loss = (err * err).mean()        # MSE
        gw = (2 * err * xb).mean()       # backward
        gb = (2 * err).mean()
        w -= LR * gw                     # step
        b -= LR * gb
        losses.append(loss)
    print(f"epoch {epoch} batch_avg_loss={np.mean(losses):.4f}")
```

What changed from the pure-Python version:

- `random.shuffle(data)` ←→ `perm = rng.permutation(N)` — shuffle indices, not the dataset — cache-friendly
- `for x, y in batch: ...` ←→ `xb = x[idx]` (fancy indexing) — whole batch slice, no Python loop
- `gw += 2*err*x` ←→ `(2 * err * xb).mean()` — vectorised gradient, averaged over the batch
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
x = torch.randn(128, 1)
y = 1.7 * x - 0.4 + 0.3 * torch.randn_like(x)

ds = TensorDataset(x, y)
loader = DataLoader(ds, batch_size=16, shuffle=True)   # DataLoader = shuffling + batching

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

for epoch in range(5):
    losses = []
    for xb, yb in loader:                # one batch per iter
        optimizer.zero_grad()            # 1
        yhat = model(xb)                 # 2
        loss = criterion(yhat, yb)       # 3
        loss.backward()                  # 4
        optimizer.step()                 # 5
        losses.append(loss.item())
    print(f"epoch {epoch} avg_loss={sum(losses)/len(losses):.4f}")
```

```
epoch 0 avg_loss=1.2833
epoch 1 avg_loss=0.6942
epoch 2 avg_loss=0.3747
epoch 3 avg_loss=0.2021
epoch 4 avg_loss=0.1096
```

What changed from the NumPy version:

- `for start in range(0, N, BATCH): ... idx = perm[...]` ←→ `for xb, yb in loader:` — DataLoader handles batching + shuffling + worker-parallel loading
- explicit forward + grad + step ←→ the five one-liners — autograd + nn.Module + optim own the bookkeeping
- manual RNG for reproducibility ←→ `torch.manual_seed(0)` plus a DataLoader `generator` — covers dataset shuffling + layer init + dropout + augmentations
Look at the inner loop of layer 3. Six lines: zero_grad, then the four beats (forward, loss, backward, step), then a log. That is the metronome. You will type those lines thousands of times. You will write them in your sleep. Every “production training script” on GitHub is those same lines with a thousand more wrapped around them for logging, checkpointing, and distributed coordination.
Now the ways the metronome breaks. Most training bugs are not subtle errors deep in a model; they're one of the four beats played in the wrong order, or skipped entirely. The first beat has to be zero_grad. Backward has to come before step. Step has to come before the next forward. Miss any of that and the loop still runs — it just doesn't train.
Forgetting zero_grad — the dropped downbeat. PyTorch accumulates gradients by default. Skip the zero and step N does gradient descent on the sum of gradients from step N plus every previous step. By step 10 the effective learning rate is 10× what you asked for. Training blows up immediately, and it looks like an exploding-gradient problem when it's actually a forgotten beat.
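A minimal sketch of that accumulation, on a bare parameter rather than a model:

```python
import torch

w = torch.nn.Parameter(torch.tensor(2.0))

(w * 3.0).backward()
print(w.grad)          # tensor(3.)

(w * 3.0).backward()   # no zero_grad in between
print(w.grad)          # tensor(6.) — gradients accumulated, not replaced

w.grad = None          # what optimizer.zero_grad() does for you
(w * 3.0).backward()
print(w.grad)          # tensor(3.) again — a clean downbeat
```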
Stepping before backward. Calling optimizer.step() before loss.backward() applies whatever gradients are currently in .grad — on iteration one, that's zero, so nothing moves. On later iterations it's last step's gradient, so you train on stale signal. Silent, and wrong.
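The iteration-one case is easy to demonstrate on a bare parameter: with nothing in .grad, plain SGD's step() is a no-op.

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([w], lr=0.1)

opt.zero_grad()
opt.step()                 # stepped before backward: .grad is empty, nothing moves
print(w.item())            # 1.0, unchanged

(w * 5.0).backward()       # now .grad holds 5.0
opt.step()                 # θ ← θ − 0.1·5.0
print(w.item())            # 0.5
```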
Not shuffling between epochs. Without shuffle=True on the DataLoader, every epoch sees the batches in the same order. Gradients become correlated from epoch to epoch, and the model can latch onto the fixed batch order rather than the data. Always shuffle the training set; never shuffle the validation set.
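A sketch of the difference, on a toy dataset of eight points so the batch order is visible:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.arange(8).float().unsqueeze(1))

fixed = DataLoader(ds, batch_size=4, shuffle=False)
epoch1 = [b[0].squeeze().tolist() for b in fixed]
epoch2 = [b[0].squeeze().tolist() for b in fixed]
print(epoch1 == epoch2)    # True — identical batch order every epoch

shuffled = DataLoader(ds, batch_size=4, shuffle=True,
                      generator=torch.Generator().manual_seed(0))
print([b[0].squeeze().tolist() for b in shuffled])   # a fresh permutation per epoch
```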
Reusing a drained iterator. A DataLoader's iterator is exhausted after one full pass. Writing for batch in loader: inside the epoch loop is correct, because each epoch implicitly creates a fresh iterator. Caching one yourself (it = iter(loader) once, outside the epoch loop) means every epoch after the first walks a drained iterator and silently performs zero steps. Don't cache the iterator.
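The anti-pattern, demonstrated on a four-example toy dataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.arange(4).float())
loader = DataLoader(ds, batch_size=2)

it = iter(loader)          # cached iterator: the anti-pattern
print(len(list(it)))       # 2 — the first "epoch" drains it
print(len(list(it)))       # 0 — every later pass silently yields nothing

# Correct: iterate the loader itself; each pass builds a fresh iterator
print(len(list(loader)))   # 2 again
```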
Calling .item() inside the loss accumulator. loss.item() forces a CPU sync, which kills throughput on GPU. Either accumulate the tensor (keep it on GPU) or call .item() only when logging.
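A sketch of the fix, with hypothetical per-batch loss tensors standing in for criterion(yhat, yb):

```python
import torch

# Hypothetical per-batch losses (stand-ins for criterion(yhat, yb))
batch_losses = [torch.tensor(0.5), torch.tensor(0.3), torch.tensor(0.1)]

total = torch.tensor(0.0)
for loss in batch_losses:
    total += loss.detach()     # stays a tensor: no CPU sync, no graph kept alive

avg = (total / len(batch_losses)).item()   # one sync, at logging time
print(f"avg_loss={avg:.4f}")
```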
Close this page. Open a new notebook. Without referring back, write a complete training loop for MNIST — nn.Linear(784, 10), CE loss, SGD at lr=0.1, 5 epochs, batch size 64. Hit > 90% test accuracy.
If you can do this, you've internalised the metronome and everything in the rest of this curriculum is a variation on it. If you can't, come back here and read the layer-3 code one more time. It should feel familiar, not copied.
What to carry forward. The training loop is four beats — forward, loss, backward, step — preceded by zero_grad and wrapped in two loops (epochs over batches). Minibatching is the default because of memory, throughput, and generalisation, all at once. DataLoader handles shuffling, batching, and async I/O for you. Scale everything else around the metronome and leave the metronome itself alone.
Next up — Training Diagnostics. The metronome is running. The loss is dropping — or it isn't, or it's dropping too slowly, or it went to NaN on epoch three. How do you know the model is actually learning? The next lesson is about reading the vital signs — loss curves, gradient norms, parameter statistics — well enough to distinguish “converging slowly” from “silently diverging”, which is a skill that will save you more time than any single algorithmic improvement.
- [01] Goodfellow, Bengio, Courville · Deep Learning · MIT Press, 2016
- [02] Léon Bottou · “Large-Scale Machine Learning with Stochastic Gradient Descent” · COMPSTAT 2010
- [03] Zhang, Lipton, Li, Smola · Dive into Deep Learning · d2l.ai
- [04] Andrej Karpathy · “A Recipe for Training Neural Networks” · karpathy.github.io, 2019