Digit Classifier
Ship a working MNIST model end-to-end.
Sixteen lessons on the factory floor. Tensors, autograd, linear layers, activations, softmax, cross-entropy, init, three flavors of normalization, the four-beat training loop, the diagnostic suite. Every station built, inspected, and signed off.
Today, the first vehicle rolls off the assembly line. A real model, on real data, that you can hand a pixel grid to and have it tell you what digit you drew. No toys. No placeholders. The whole line running at once.
The test track is MNIST — 70,000 handwritten digits, 28×28 grayscale, ten classes. Small enough to train on a laptop in under a minute. Rich enough that every piece you've built earns its place. Simple enough that if anything is broken, you'll see it immediately instead of blaming the data.
28×28 pixel image
│
▼ flatten
784-D vector (each pixel in [0, 1])
│
▼ Linear(784, 128) + ReLU
128-D hidden layer
│
▼ Linear(128, 10)
10-D logit vector
│
▼ softmax + cross-entropy loss
scalar loss → loss.backward() → optimizer.step()

Walk that diagram station by station. Inputs normalized to [0, 1] — normalization lesson. A Linear layer — MLP lesson. ReLU — activations lesson. A softmax head and cross-entropy loss — loss lessons. A backward pass and an optimizer step — the four-beat loop. You have visited every one of these stations on foot. Today they get wired together, and the output is a working classifier.
Draw a digit below. The widget is not a real MNIST model — it compares your strokes against ten template shapes and runs the softmax. Think of it as a scale model of the drivetrain: pixels in, logits out, probabilities out, argmax picks the class. Every real digit classifier, from 1998 through today, does exactly this.
I am the friendly introduction dataset. 60,000 training digits, 10,000 test digits, handwritten by US Census Bureau employees and high-school students in the 1990s. I am small enough to train a model on a laptop in under a minute, clean enough that a linear model gets 92%, and interesting enough that you can still push to 99.8% by caring more. I am the benchmark you cut your teeth on, and the benchmark you graduate from as soon as possible.
x ∈ ℝ^{B × 784} # batch of flattened images
W₁ ∈ ℝ^{784 × 128} # first layer weights
W₂ ∈ ℝ^{128 × 10} # second layer weights
h = ReLU(x @ W₁ + b₁) ∈ ℝ^{B × 128}
z = h @ W₂ + b₂ ∈ ℝ^{B × 10} (raw logits)
p = softmax(z) ∈ ℝ^{B × 10} (probabilities)
loss = CrossEntropy(z, targets) # uses z, not p — numerical stability

Two Linear layers. One ReLU between them. One softmax folded into the loss for numerical stability (see the cross-entropy lesson). Parameter count: 784 · 128 + 128 · 10 weights plus 138 biases ≈ 100K parameters. That is the whole model. About 0.4 MB of floats. It hits 97%+ test accuracy in five epochs.
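A quick sanity check of those claims, sketched in NumPy on arbitrary random data (the batch size B = 64 is just for illustration): trace the shapes through the two layers and count the parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
B = 64
x = rng.random((B, 784))                  # batch of flattened images in [0, 1]
W1, b1 = rng.normal(size=(784, 128)), np.zeros(128)
W2, b2 = rng.normal(size=(128, 10)), np.zeros(10)

h = np.maximum(0, x @ W1 + b1)            # ReLU hidden layer
z = h @ W2 + b2                           # raw logits
print(h.shape, z.shape)                   # (64, 128) (64, 10)

n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)                           # 101770, i.e. about 100K
```

At 4 bytes per float32 parameter, 101,770 parameters is roughly 0.4 MB, matching the figure above.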
Before you look at the code, look at the report card. A single accuracy number — “97.8%” — hides more than it reveals. You want to know which digits the model misses, and which ones it confuses them for. That is the confusion matrix.
The confusion matrix is the per-class diagnostic for a classifier. Rows are true labels, columns are predictions, diagonals are hits, off-diagonals are misses. A healthy MNIST model has a bright diagonal and a few stubborn hot cells — 4↔9 (same vertical stroke, different tops), 3↔5 (curly loops that look alike), 7↔2 (slanted strokes). These pairs persist across architectures because the handwriting really is ambiguous; even humans miss them.
I am the classifier's report card. The diagonal is where you did the assignment right. Everything off-diagonal is a specific error with a specific explanation. Big row sum / small diagonal = the model is struggling to recall that class. Big column sum off-diagonal = the model is over-predicting that class. I will not tell you how to fix anything, but I will tell you exactly what to fix.
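Building the report card takes only a few lines. A minimal sketch, using small toy label arrays in place of real test labels and predictions:

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes=10):
    # rows = true label, cols = prediction; diagonal = hits
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Toy stand-ins: note the 9->4, 3->5, and 7->2 misses from the lesson
y_true = np.array([4, 9, 4, 3, 5, 7, 2, 7])
y_pred = np.array([4, 4, 4, 5, 5, 7, 2, 2])
cm = confusion_matrix(y_true, y_pred)
print(cm[4, 4])   # 2: hits on the diagonal for class 4
print(cm[9, 4])   # 1: a 9 predicted as 4, the classic confusion pair
```

On a real model you would feed it `y_test` and `preds` from the evaluation loop; the hot off-diagonal cells tell you which pairs to go look at.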
Now the code. Three layers, same as every algorithm in this series. Pure Python is too slow for 60k images, so the first code layer is a toy — ten hand-crafted feature vectors, just to show the shape of a hand-rolled training loop. Layers two and three run on real MNIST.
import math, random
random.seed(0)
# 10 hand-crafted "digit" vectors of length 16 (toy features)
X = [[random.random() for _ in range(16)] for _ in range(10)]
y = list(range(10)) # one of each class
# MLP: 16 → 8 → 10
W1 = [[random.gauss(0, 0.5) for _ in range(16)] for _ in range(8)]
b1 = [0.0] * 8
W2 = [[random.gauss(0, 0.5) for _ in range(8)] for _ in range(10)]
b2 = [0.0] * 10
def forward(x):
    h = [max(0, sum(W1[i][j] * x[j] for j in range(16)) + b1[i]) for i in range(8)]
    z = [sum(W2[i][j] * h[j] for j in range(8)) + b2[i] for i in range(10)]
    return h, z

def softmax(z):
    m = max(z); e = [math.exp(v - m) for v in z]; s = sum(e)
    return [v / s for v in e]
# Handwritten backprop + SGD (mirroring what PyTorch would compute)
# Loop omitted here; the full version would be ~80 lines

step 0: loss=2.3012
step 50: loss=0.4215
step 100: loss=0.1008
accuracy on those 10 examples: 100%
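To give a feel for what the omitted loop contains, here is a hedged, self-contained sketch of a single manual backprop + SGD step on one example, with freshly initialized toy weights. The learning rate and target class are arbitrary choices for illustration.

```python
import math, random

random.seed(0)
IN, HID, OUT, LR = 16, 8, 10, 0.1
W1 = [[random.gauss(0, 0.5) for _ in range(IN)] for _ in range(HID)]
b1 = [0.0] * HID
W2 = [[random.gauss(0, 0.5) for _ in range(HID)] for _ in range(OUT)]
b2 = [0.0] * OUT
x = [random.random() for _ in range(IN)]
y = 3  # arbitrary target class

def forward():
    h = [max(0.0, sum(W1[i][j] * x[j] for j in range(IN)) + b1[i]) for i in range(HID)]
    z = [sum(W2[i][j] * h[j] for j in range(HID)) + b2[i] for i in range(OUT)]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return h, [v / s for v in e]

h, p = forward()
loss_before = -math.log(p[y])

# Backward: softmax + cross-entropy collapses to dz = p - one_hot(y)
dz = [p[i] - (1.0 if i == y else 0.0) for i in range(OUT)]
dh = [sum(dz[i] * W2[i][j] for i in range(OUT)) for j in range(HID)]
dh = [dh[j] if h[j] > 0 else 0.0 for j in range(HID)]  # ReLU gate

# SGD update, layer by layer
for i in range(OUT):
    b2[i] -= LR * dz[i]
    for j in range(HID):
        W2[i][j] -= LR * dz[i] * h[j]
for i in range(HID):
    b1[i] -= LR * dh[i]
    for j in range(IN):
        W1[i][j] -= LR * dh[i] * x[j]

_, p = forward()
loss_after = -math.log(p[y])
print(loss_before > loss_after)  # one step, loss drops on this example
```

The full loop just repeats this over all ten examples for a few hundred steps, which is where the ~80 lines go.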
import numpy as np
from tensorflow.keras.datasets import mnist # or any MNIST loader
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 784) / 255.0
X_test = X_test.reshape(-1, 784) / 255.0
rng = np.random.default_rng(0)
W1 = rng.normal(0, np.sqrt(2 / 784), size=(784, 128))
b1 = np.zeros(128)
W2 = rng.normal(0, np.sqrt(2 / 128), size=(128, 10))
b2 = np.zeros(10)
def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_entropy(probs, y):
    return -np.log(probs[np.arange(len(y)), y] + 1e-12).mean()
BATCH, LR, EPOCHS = 64, 0.1, 5
for epoch in range(1, EPOCHS + 1):
    perm = rng.permutation(len(X_train))
    losses = []
    for i in range(0, len(X_train), BATCH):
        idx = perm[i:i+BATCH]
        x, y = X_train[idx], y_train[idx]
        # Forward
        h = np.maximum(0, x @ W1 + b1)
        z = h @ W2 + b2
        p = softmax(z)
        losses.append(cross_entropy(p, y))
        # Backward (softmax + CE collapses to p - y)
        dz = p.copy(); dz[np.arange(len(y)), y] -= 1; dz /= len(y)
        dW2 = h.T @ dz; db2 = dz.sum(axis=0)
        dh = dz @ W2.T; dh[h <= 0] = 0
        dW1 = x.T @ dh; db1 = dh.sum(axis=0)
        # Update
        W2 -= LR * dW2; b2 -= LR * db2
        W1 -= LR * dW1; b1 -= LR * db1
    # Evaluate
    h = np.maximum(0, X_test @ W1 + b1)
    preds = (h @ W2 + b2).argmax(axis=-1)
    acc = (preds == y_test).mean()
    print(f"epoch {epoch} train={np.mean(losses):.4f} test={acc:.4f}")

epoch 1 train=0.3421 test=0.9612
epoch 2 train=0.1824 test=0.9723
epoch 3 train=0.1247 test=0.9751
epoch 4 train=0.0934 test=0.9768
epoch 5 train=0.0741 test=0.9783
- for x, y in zip(...) per-example update ←→ batched matmul per minibatch — GPU-friendly; 1000× faster
- nested loops for backprop ←→ p.copy(); p[arange, y] -= 1 — the one-hot trick for the softmax+CE gradient
- random.shuffle(data) ←→ rng.permutation(len(X)) — shuffle indices, not the giant data array
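The one-hot trick deserves a trust check before you lean on it. A short sketch comparing the analytic gradient p - one_hot(y) against central finite differences on one example's logits:

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=10)   # one example's raw logits
y = 3                     # its true class

def loss(z):
    zs = z - z.max()                      # stable softmax
    p = np.exp(zs) / np.exp(zs).sum()
    return -np.log(p[y])                  # cross-entropy

# Analytic gradient via the one-hot trick
p = np.exp(z - z.max()); p /= p.sum()
analytic = p.copy(); analytic[y] -= 1

# Numeric gradient via central differences
eps = 1e-6
numeric = np.array([
    (loss(z + eps * np.eye(10)[i]) - loss(z - eps * np.eye(10)[i])) / (2 * eps)
    for i in range(10)
])
print(np.abs(analytic - numeric).max() < 1e-6)   # True: they agree
```

This is the same gradient check you would reach for whenever a hand-rolled backward pass misbehaves.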
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train = datasets.MNIST('.', train=True, download=True, transform=transform)
test = datasets.MNIST('.', train=False, download=True, transform=transform)
train_loader = DataLoader(train, batch_size=64, shuffle=True)
test_loader = DataLoader(test, batch_size=256)
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))  # returns logits

model = MLP()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(1, 6):
    model.train()
    losses = []
    for xb, yb in train_loader:
        optimizer.zero_grad()
        logits = model(xb)
        loss = F.cross_entropy(logits, yb)  # fused logsoftmax + NLL
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Eval
    model.eval()
    correct = 0; total = 0
    with torch.no_grad():
        for xb, yb in test_loader:
            preds = model(xb).argmax(dim=-1)
            correct += (preds == yb).sum().item(); total += len(yb)
    print(f"epoch {epoch} train={sum(losses)/len(losses):.4f} test={correct/total:.4f}")

epoch 1 train=0.3018 test=0.9671
epoch 2 train=0.1543 test=0.9748
epoch 3 train=0.1015 test=0.9782
epoch 4 train=0.0728 test=0.9806
epoch 5 train=0.0555 test=0.9824
- manual download & reshape ←→ torchvision.datasets.MNIST + DataLoader — batching, shuffling, async I/O for free
- softmax + -log p[y].mean() ←→ F.cross_entropy(logits, y) — fused, numerically stable, takes raw logits
- manual backprop + SGD ←→ loss.backward() + optimizer.step() — four lines per epoch, not forty
- manual test-accuracy loop ←→ model.eval() + torch.no_grad() — disables dropout/BN; skips graph construction
Now for the ways this first vehicle rolls off the line and promptly into a ditch. Four failure modes, each cheap to cause and each catchable with the diagnostic suite you already have.
Forgetting to normalize inputs. Raw MNIST pixels are in [0, 255]. Feeding them unnormalized gives gigantic pre-activations, saturates everything, and training fails. Divide by 255 (putting values in [0, 1]) or normalize to mean=0 std=1 using the dataset statistics.
Using MSE loss on a classifier. MSE will train, but slowly, with poorly calibrated probabilities and nastier gradients than cross-entropy. Use CrossEntropyLoss. Always.
Not shuffling between epochs. With DataLoader, pass shuffle=True. Without shuffling, a class-sorted dataset feeds the network the digits in order 0-0-0-...-9-9-9: each long run of identical labels drags the weights toward a single class, a pathological curriculum, and training stalls or oscillates instead of converging.
Evaluating on the training set. “My model is 99.99% accurate!” — said while looking at train accuracy. Always report test accuracy (or held-out val accuracy). Train accuracy going up is necessary but not sufficient.
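The first failure mode is easy to see numerically. A small sketch, using random stand-in pixels and He-initialized weights as in the layer-2 code, comparing pre-activation scale for raw versus normalized inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(0, np.sqrt(2 / 784), size=(784, 128))  # He init, as above

x_raw = rng.integers(0, 256, size=(64, 784)).astype(float)  # stand-in pixels in [0, 255]
x_norm = x_raw / 255.0                                      # the one-line fix

std_raw = (x_raw @ W1).std()
std_norm = (x_norm @ W1).std()
print(round(std_raw / std_norm))   # 255: pre-activations are 255x too big
```

He init assumes unit-scale inputs; feed it [0, 255] pixels and every downstream assumption about activation scale breaks by the same factor.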
One more gotcha that deserves its own paragraph, because it bites people with working models: confidence miscalibration. A trained MNIST MLP will sometimes predict the wrong digit with 0.99 probability. The argmax is wrong and the softmax is certain anyway. Cross-entropy trains the model to be right, not to be honestly unsure when it isn't — it rewards sharper distributions, and in the limit the model learns to shout every answer. Accuracy can be excellent while probabilities are garbage. Temperature scaling fixes this in one line of code; you'll meet it later when we care about deployment.
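For the curious, a hedged sketch of what temperature scaling does: divide the logits by a scalar T > 1 before the softmax. The T here is hand-picked for illustration; in practice it is fitted on a held-out validation set.

```python
import numpy as np

def softmax(z, T=1.0):
    z = z / T                              # temperature: T > 1 softens
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

z = np.array([8.0, 2.0, 0.0, -1.0, 0.5, 0.0, 1.0, -2.0, 0.3, 0.1])  # overconfident logits
p1 = softmax(z)           # T=1: near-certain
p2 = softmax(z, T=2.5)    # T=2.5: same winner, softer confidence
print(p1.argmax() == p2.argmax())   # True: accuracy is untouched
print(p1.max() > p2.max())          # True: confidence is reduced
```

Dividing by a positive constant preserves the logit ordering, which is why calibration can be fixed after training without touching accuracy.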
Using the layer-3 PyTorch code as your starting point, get to ≥ 98.5% test accuracy on MNIST. Levers you can pull: add a second hidden layer, use Adam instead of SGD, add dropout, add weight decay, normalize inputs to mean=0 std=1, train for more epochs. After you hit the target, compute the confusion matrix and identify your top 3 confusion pairs. Are they the same as the 4/9, 3/5, 7/2 set from this lesson? Why do you think those pairs persist across architectures?
What to carry forward. The first vehicle has rolled off the assembly line and driven. Every station you built — tensors, autograd, Linear, ReLU, softmax, cross-entropy, the four-beat loop, normalization, diagnostics — contributed a piece, and the piece worked. A 100K-parameter MLP, trained on a laptop, reading handwriting at 98% accuracy. The same loop, widened a few thousand times and trained a few million times longer, trains GPT-4. Everything else from here is scale and architecture.
End of the Training section. Your classifier treats a digit as 784 unrelated numbers in a bag. The pixel in the top-left corner has no idea the pixel next to it exists — they're just entries 0 and 1 in a flat vector. Flattening threw away where things are, which for an image is almost all of the information. Next section opens with the operation that knows pixels have neighbors: convolution. Same dataset, a new architecture, and accuracy climbs from the MLP's 98% to 99.5% on the exact same digits. Then the same trick scales to ImageNet. That story starts next.
- [01] LeCun, Bottou, Bengio, Haffner · Proc. IEEE 1998 — the paper that introduced MNIST and LeNet-5
- [02] Michael Nielsen · Neural Networks and Deep Learning · neuralnetworksanddeeplearning.com
- [03] PyTorch team · official examples · github.com/pytorch/examples
- [04] Zhang, Lipton, Li, Smola · Dive into Deep Learning · d2l.ai