Training Loop

Forward, loss, backward, step — the four-line core.

Easy · ~15 min read · lesson 1 of 4

Every model that has ever been trained, anywhere, runs on a four-beat metronome. Forward. Loss. Backward. Step. One, two, three, four, repeat. A tiny MLP learning to fit a line. A 70B-parameter LLM burning through a GPU pod for six weeks. Same four beats. Same order. Different scale.

You've already built every one of those beats by hand. The forward pass is the model computing a prediction. The loss is one scalar saying how wrong that prediction is. The backward pass is backpropagation walking the chain rule to fill .grad. The step is gradient descent nudging the parameters against their gradients. This lesson is the ritual that wires them together into something you can run a million times without thinking about it.

The training loop is the heartbeat of modern ML, and not incidentally: structurally, it is the only computation that matters. Schedulers, gradient clipping, mixed precision, distributed backends, checkpointing — all of it is ornamentation on the metronome's four beats. The rhythm is identical across every architecture, every framework, every scale.


for epoch in range(EPOCHS):             # 1. outer loop — passes over the dataset
    for batch in dataloader:            # 2. inner loop — one batch at a time
        optimizer.zero_grad()           #      clear .grad buffers
        yhat = model(batch.x)           #      forward pass (build graph)
        loss = criterion(yhat, batch.y) #      scalar loss
        loss.backward()                 #      walk the graph, fill .grad
        optimizer.step()                #      θ ← θ − α·∇L

One epoch  =  one full pass over the training set.
One step   =  one batch processed  =  one parameter update.
If dataset has N examples and batch size is B, one epoch is ⌈N/B⌉ steps.
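That arithmetic is worth pinning down once in code — a tiny sketch (the dataset and batch sizes here are just example numbers):

```python
import math

N, B = 50_000, 64                  # dataset size, batch size (illustrative)
steps_per_epoch = math.ceil(N / B)
print(steps_per_epoch)             # 782 — the last batch is partial (16 examples)
```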
the training loop, in context

Press play below. A tiny linear regression trains itself in your browser, one batch at a time, following the four-beat pattern above exactly. The sparkline shows per-batch loss; the dashed line is the epoch-averaged value. Shift the batch size and the curve smoothness changes — which is the topic of the next widget.

a full training run, batch by batch
8 epochs · 8 batches/epoch · lr α = 0.08
[interactive widget — per-batch loss sparkline (faded = future) with epoch-mean overlay; live readouts of w and b against targets w* = 1.7, b* = -0.4]
Training loop (personified)
I do not care what your model is. I do not care whether it has six parameters or seventy billion. I do one thing: draw a batch, run forward, compute loss, run backward, step the optimizer. I will do this a million times without complaint. Give me good data, a sensible learning rate, and a working forward pass, and I will produce a trained model. Get any of those wrong and I will produce garbage, with equal enthusiasm.

Notice what's moving inside each beat of the metronome. The model never sees the whole dataset at once; it sees a batch. Processing batches — rather than computing the exact gradient over the entire training set — comes down to three reasons, each decisive on its own:

  • Memory. A modern LLM's full dataset is trillions of tokens; the activations for even a thousand examples won't fit on a GPU. Batches are what's physically possible.
  • Compute throughput. GPUs are matrix processors. Processing 64 examples simultaneously is barely slower than processing one, because the matmul machinery is saturated either way. One-example-at-a-time training leaves that hardware idle.
  • Generalization. The noise introduced by sampling a different batch each step is, empirically, a regulariser. Stochastic gradients help networks escape sharp local minima that generalise poorly.
stochastic gradients — unbiased estimators of the full gradient
Full-batch gradient (what we want, can't afford):
  ∇L(θ)   =   (1/N) · Σ_{i=1..N}   ∇ℓ(θ; xᵢ, yᵢ)

Minibatch gradient (what we actually compute):
  ∇L_B(θ) =   (1/|B|) · Σ_{i ∈ B}   ∇ℓ(θ; xᵢ, yᵢ)

Key property: E[∇L_B] = ∇L for B drawn uniformly at random.
So each minibatch step is, on expectation, a step in the true descent direction —
plus mean-zero noise whose variance shrinks as batch size grows.

Variance of the minibatch gradient is proportional to 1 / |B|. Double the batch size and the variance halves — the noise amplitude (its standard deviation) shrinks only by √2. The curves below make it visceral: three training runs on the same loss surface, different batch sizes.
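Both properties are easy to check numerically. A sketch (the toy problem and all names are illustrative): hold θ fixed, sample many minibatch gradients, and compare their mean and spread to the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 10_000
x = rng.normal(size=N)
y = 1.7 * x                               # toy target, no noise
w = 0.0                                   # evaluate all gradients at a fixed θ
per_example = 2 * (w * x - y) * x         # ∇ℓ(θ; xᵢ, yᵢ) for MSE, per example
full_grad = per_example.mean()            # ∇L(θ): what full batch would compute

var = {}
for B in (1, 32, 1024):
    # 2000 independent minibatch gradient estimates at this batch size
    draws = np.array([per_example[rng.choice(N, B, replace=False)].mean()
                      for _ in range(2_000)])
    var[B] = draws.var()
    print(f"B={B:5d}  mean={draws.mean():+.3f}  var={var[B]:.4f}")
print(f"full gradient = {full_grad:+.3f}")
# every mean sits near the full gradient (unbiased); variance falls ≈ 1/|B|
```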

loss curve smoothness — three batch sizes on the same problem
same data · same LR · different batch
full-batch converges smoothly but needs N forward passes per update · SGD is noisy but generalises · minibatch is the compromise every practitioner uses
[interactive widget — three loss curves animated over 300 steps: full batch, minibatch (32), SGD (batch = 1)]
Batch size (personified)
I am the single hyperparameter that changes everything underneath and nothing on top. Pick me too small and you are noise-limited. Pick me too large and you burn memory and lose the regularisation from stochasticity. Pick me a power of two between 32 and 512 for vision, between 1M and 4M tokens for LLMs, and you will be fine.

Now write the metronome from scratch. Three layers, same four beats at each level of the stack you'll meet in practice. The pure Python version does every beat by hand; the NumPy version vectorises the arithmetic; the PyTorch version cedes the bookkeeping to the library and leaves you with the four-line core you'll write from memory for the rest of your career.

layer 1 — pure python · training_loop_scratch.py
python
import random
random.seed(0)

# Toy dataset: y = 1.7x - 0.4 + noise
data = [(x := random.gauss(0, 1), 1.7 * x - 0.4 + random.gauss(0, 0.3)) for _ in range(128)]

w, b = 0.0, 0.0
lr = 0.05
BATCH = 16

for epoch in range(5):
    random.shuffle(data)                                         # ← critical
    epoch_loss = 0.0
    for start in range(0, len(data), BATCH):
        batch = data[start:start + BATCH]
        # zero_grad + forward + loss + backward + step, all by hand
        gw = 0.0; gb = 0.0; loss = 0.0
        for x, y in batch:
            yhat = w * x + b                                      # forward
            err = yhat - y                                        # loss pre-sum
            loss += err * err                                     # MSE
            gw += 2 * err * x                                     # backward
            gb += 2 * err
        loss /= len(batch); gw /= len(batch); gb /= len(batch)
        w -= lr * gw                                              # step
        b -= lr * gb
        epoch_loss += loss
    epoch_loss /= (len(data) / BATCH)
    print(f"epoch {epoch}  batch_avg_loss={epoch_loss:.4f}")
stdout
epoch 0  batch_avg_loss=1.3245
epoch 1  batch_avg_loss=0.7123
epoch 2  batch_avg_loss=0.3881
epoch 3  batch_avg_loss=0.2094
epoch 4  batch_avg_loss=0.1138
layer 2 — numpy · training_loop_numpy.py
python
import numpy as np
rng = np.random.default_rng(0)

N, BATCH, EPOCHS, LR = 128, 16, 5, 0.05
x = rng.normal(size=(N,))
y = 1.7 * x - 0.4 + rng.normal(scale=0.3, size=(N,))
w, b = np.array(0.0), np.array(0.0)

for epoch in range(EPOCHS):
    perm = rng.permutation(N)                                    # shuffle each epoch
    losses = []
    for start in range(0, N, BATCH):
        idx = perm[start:start + BATCH]
        xb, yb = x[idx], y[idx]
        yhat = w * xb + b                                        # forward (vectorised)
        err = yhat - yb
        loss = (err * err).mean()                                # MSE
        gw = (2 * err * xb).mean()                               # backward
        gb = (2 * err).mean()
        w -= LR * gw                                             # step
        b -= LR * gb
        losses.append(loss)
    print(f"epoch {epoch}  batch_avg_loss={np.mean(losses):.4f}")
pure python → numpy
  random.shuffle(data)      ←→  perm = rng.permutation(N)
      shuffle indices, not the dataset — cache-friendly
  for x, y in batch: ...    ←→  xb = x[idx]  # fancy indexing
      whole batch in one slice, no Python loop
  gw += 2*err*x             ←→  gw = (2 * err * xb).mean()
      vectorised gradient, averaged over the batch

layer 3 — pytorch · training_loop_pytorch.py
python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
x = torch.randn(128, 1)
y = 1.7 * x - 0.4 + 0.3 * torch.randn_like(x)

ds = TensorDataset(x, y)
loader = DataLoader(ds, batch_size=16, shuffle=True)              # DataLoader = shuffling + batching

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

for epoch in range(5):
    losses = []
    for xb, yb in loader:                                         # one batch per iter
        optimizer.zero_grad()                                     # 1
        yhat = model(xb)                                          # 2
        loss = criterion(yhat, yb)                                # 3
        loss.backward()                                           # 4
        optimizer.step()                                          # 5
        losses.append(loss.item())
    print(f"epoch {epoch}  avg_loss={sum(losses)/len(losses):.4f}")
stdout
epoch 0  avg_loss=1.2833
epoch 1  avg_loss=0.6942
epoch 2  avg_loss=0.3747
epoch 3  avg_loss=0.2021
epoch 4  avg_loss=0.1096
numpy → pytorch
  for start in range(0, N, BATCH): ... idx = perm[...]  ←→  for xb, yb in loader:
      DataLoader handles batching + shuffling + worker-parallel loading
  explicit forward + grad + step                        ←→  the five one-liners
      autograd + nn.Module + optim own the bookkeeping
  manual RNG for reproducibility                        ←→  torch.manual_seed(0) plus a DataLoader generator
      covers dataset shuffling + layer init + dropout + augmentations

Look at the inner loop of layer 3. Five lines — the four beats plus the zero_grad that clears the slate (the losses.append is just logging). zero_grad, forward, loss, backward, step. That is the metronome. You will type those exact five lines thousands of times. You will write them in your sleep. Every “production training script” on GitHub is those five lines with a thousand more wrapped around them for logging, checkpointing, and distributed coordination.

Now the ways the metronome breaks. Most training bugs are not subtle errors deep in a model; they're one of the four beats played in the wrong order, or skipped entirely. The first beat has to be zero_grad. Backward has to come before step. Step has to come before the next forward. Miss any of that and the loop still runs — it just doesn't train.

Gotchas

Forgetting zero_grad — the dropped downbeat. PyTorch accumulates gradients by default: each backward adds into .grad. Skip the zero and step N descends on the sum of gradients from steps 1 through N. By step 10 the effective step is roughly 10× what you asked for. Training blows up fast, and it looks like an exploding-gradient problem when it's actually a forgotten beat.
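A minimal repro of the accumulation (an assumption of this sketch: same input every step, so each backward adds an identical gradient; no optimizer involved, we just watch .grad grow):

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
for step in range(3):
    loss = (2.0 * w - 1.0) ** 2        # forward — note: no zero_grad anywhere
    loss.backward()                    # backward ADDS into w.grad
    print(f"step {step}: w.grad = {w.grad.item():.1f}")
# 4.0, 8.0, 12.0 — each backward stacks on top of every previous one
```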

Stepping before backward. Calling optimizer.step() before loss.backward() applies whatever is currently sitting in .grad. Depending on where zero_grad falls, that is either all zeros (so nothing moves) or last iteration's stale gradient (so you train on old signal, one step behind). Silent either way, and wrong either way.
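The wrong order is easy to watch with a single parameter (values chosen here for round numbers):

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([w], lr=0.1)

opt.zero_grad()
loss = (w - 3.0) ** 2
opt.step()                   # wrong order: .grad is still empty — nothing moves
print(w.item())              # 1.0

loss.backward()              # fills grad = 2·(1 − 3) = −4
opt.step()                   # right order: w ← 1 − 0.1·(−4) ≈ 1.4
print(w.item())
```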

Not shuffling between epochs. Without shuffle=True on the DataLoader, every epoch sees the batches in the same order. Gradients stay correlated from epoch to epoch, the loss curve picks up a repeating batch-order pattern, and convergence is slower and can settle somewhere worse. Always shuffle the training set; never shuffle the validation set.
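A quick way to see what shuffle=True buys you — a sketch with a dataset tiny enough to print:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.arange(8))
torch.manual_seed(0)

shuffled = DataLoader(ds, batch_size=4, shuffle=True)
for epoch in range(2):                          # a fresh permutation each epoch
    order = [int(v) for (b,) in shuffled for v in b]
    print(f"epoch {epoch}: {order}")

plain = DataLoader(ds, batch_size=4, shuffle=False)
plain_order = [int(v) for (b,) in plain for v in b]
print("unshuffled:", plain_order)               # identical order every epoch
```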

Reusing a drained iterator. A DataLoader iterator is exhausted after one full pass. Writing for batch in loader: inside the epoch loop is fine — each epoch implicitly creates a fresh iterator. The bug is hoisting it = iter(loader) outside the epoch loop and reusing it: every epoch after the first gets an empty iterator and silently trains on nothing. Don't cache the iterator.
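The failure is easy to reproduce:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4)

it = iter(loader)              # the bug: one iterator, cached
first = len(list(it))          # 2 batches — "epoch 1" looks fine
second = len(list(it))         # 0 batches — "epoch 2" silently sees nothing
fresh = len(list(loader))      # `for batch in loader:` builds a new iterator
print(first, second, fresh)    # 2 0 2
```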

Calling .item() inside the loss accumulator. loss.item() forces a GPU-to-CPU synchronisation, which stalls the pipeline and can dominate runtime for small models. Either accumulate the detached tensor (keep it on the GPU) or call .item() only when you actually log.
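The fix is to keep the running total as a tensor and sync once when logging — a sketch (shown on CPU; the sync cost only bites on GPU, and the random loss is a stand-in for a real criterion):

```python
import torch

running = torch.zeros(())              # lives on the same device as the model
n_batches = 100
for _ in range(n_batches):
    loss = torch.rand(())              # stand-in for criterion(yhat, yb)
    running += loss.detach()           # tensor += tensor: no CPU sync here
avg = (running / n_batches).item()     # exactly one sync, at logging time
print(f"avg loss = {avg:.3f}")
```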

Write the loop from memory, no Googling

Close this page. Open a new notebook. Without referring back, write a complete training loop for MNIST — nn.Linear(784, 10), CE loss, SGD at lr=0.1, 5 epochs, batch size 64. Hit > 90% test accuracy.

If you can do this, you've internalised the metronome and everything in the rest of this curriculum is a variation on it. If you can't, come back here and read the layer-3 code one more time. It should feel familiar, not copied.

What to carry forward. The training loop is four beats — forward, loss, backward, step — preceded by zero_grad and wrapped in two loops (epochs over batches). Minibatching is the default because of memory, throughput, and generalisation, all at once. DataLoader handles shuffling, batching, and async I/O for you. Scale everything else around the metronome and leave the metronome itself alone.

Next up — Training Diagnostics. The metronome is running. The loss is dropping — or it isn't, or it's dropping too slowly, or it went to NaN on epoch three. How do you know the model is actually learning? The next lesson is about reading the vital signs — loss curves, gradient norms, parameter statistics — well enough to distinguish “converging slowly” from “silently diverging”, which is a skill that will save you more time than any single algorithmic improvement.
