Image Classifier

CIFAR-10 end-to-end — augmentation, training, evaluation.

Medium · ~15 min read · lesson 4 of 6

Your CNN from scratch works. You trained it on MNIST in the digit classifier lesson, hit 98% on handwritten digits, and felt — reasonably — like you had solved computer vision. You did not. MNIST is handwriting on a clean white background, scaled to the same size, centered, grayscale. It is the friendliest image dataset in existence. Real photos have color, texture, lighting, occlusion, pose, and backgrounds that weren't cleared by a graduate student.

Meet CIFAR-10. Ten classes of labeled color photos, same benchmark shape as MNIST — but the cats are in every pose imaginable, the trucks are photographed from every angle, and the backgrounds are whatever was behind the camera that day. Your 2-layer MLP that aced MNIST lands around 50% here. A plain CNN stalls around 70%. 90%+ is a different conversation entirely — it requires a recipe, not a model.

This lesson is that recipe. By the end you will have a training script that hits 94% test accuracy on CIFAR-10 — 2015 state-of-the-art, perfectly respectable in 2026. But accuracy is the outcome. The point of the lesson is to make you see what the classifier is actually doing.

  32×32×3 RGB image (uint8, [0, 255])
         │
         ▼  augment: RandomCrop(32, pad=4) + RandomHorizontalFlip
  32×32×3 (still uint8, but a different crop every epoch)
         │
         ▼  to float + normalize by CIFAR mean/std
  3×32×32 tensor, channels-first, roughly N(0, 1)
         │
         ▼  ResNet-18 (conv → bn → relu → blocks → global-avg-pool)
  10-D logit vector
         │
         ▼  cross-entropy (+ label smoothing)  →  SGD + momentum  →  cosine LR
the CIFAR-10 training pipeline, top to bottom
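The cosine LR schedule at the bottom of the pipeline has a simple closed form, worth seeing before it appears as a one-liner in the PyTorch script. A minimal sketch — the function name `cosine_lr` is ours, not a library API:

```python
import math

def cosine_lr(epoch, total_epochs, lr_max=0.1, lr_min=0.0):
    """Cosine annealing: decay smoothly from lr_max to lr_min over training.
    This is the curve that CosineAnnealingLR traces in PyTorch."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / total_epochs))

print(cosine_lr(0, 100))     # 0.1  — full rate at the start
print(cosine_lr(50, 100))    # 0.05 — halfway down at mid-training
print(cosine_lr(100, 100))   # 0.0  — annealed to zero at the end
```

The slow start and slow finish are the point: big steps while the loss landscape is rough, tiny steps while the model settles into a minimum.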

The dataset. 60,000 color photos at 32×32 across ten classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck. 50k train, 10k test, perfectly balanced (6,000 of each class). Krizhevsky & Hinton curated it in 2009 from the 80-million-tiny-images dump. It's small enough that an epoch runs in under a minute on a laptop GPU, and hard enough that the distance between a naive model and a good one is several percentage points of embarrassment. That distance is what we're closing.

Click through the test set below. The detective is unambiguously right on some (a big red fire truck), nervously right on others (a cat it gives 52% to, with dog at 38%), and confidently wrong on a few. Cats and dogs trade. Deer and horses trade. Trucks and automobiles live on the same street in feature space. Classes with the most within-class variance — cats strike every pose; horses don't — are the ones the model fumbles.

CIFAR-10 classifier — curated predictions (interactive: pretend ResNet-18 · 94% test acc · click an image to see its top-5 predictions)
example — true class: plane (sky + wings, confident): plane 97.9% · ship 1.8% · car, bird, cat ≈ 0%

Notice the shape of the mistakes. The detective is not wrong randomly — it's wrong sensibly. A blurry cat really does look like a small dog at 32×32. A cargo ship shot head-on really does look like a truck. The confusion pairs on CIFAR-10 track genuine visual similarity, the same way 4↔9 did on MNIST. That's the good news: the detective is looking at features, not memorizing pixels. Bad news: the features it's looking at are the features you'd look at too, which means the hard cases are genuinely hard.

The detective's three pairs of glasses are not something we program. They emerge — from the architecture, from the loss, and from what training images the detective happens to see. Which brings us to the most leveraged line item in the whole recipe.

Augmentation (personified)
I am the cheapest way to get more data — I invent it. You give me 50,000 labeled training images; I give you effectively infinite variants by cropping, flipping, and jittering the colors of each one at load time. The label stays the same — a cat flipped horizontally is still a cat, a truck shifted four pixels left is still a truck — so the loss function sees a different pixel pattern for the same target every epoch. The detective can't memorize its way out. I'm worth 5 to 10 percentage points of test accuracy and I cost nothing at inference time. I also prevent the tunnel vision that kills models in the wild: I make sure the detective can still recognize a cat when it's mirrored, off-center, or shot at sunset.
augmentation, mathematically
# Without augmentation, you minimize:
L(θ) = 𝔼_{(x, y) ~ D} [ ℓ(f_θ(x), y) ]

# With augmentation, you minimize:
L_aug(θ) = 𝔼_{(x, y) ~ D, t ~ T} [ ℓ(f_θ(t(x)), y) ]

# where T is a distribution over label-preserving transforms:
#   t(x) = RandomCrop ∘ HorizontalFlip ∘ ColorJitter ∘ ... (x)
#
# This asks the model to be invariant to t. It's a form of
# regularization: you've enlarged the input distribution D to D ∘ T,
# which is strictly harder to overfit than D alone.

The curves below show three training runs on the same ResNet-18, same initialization, same optimizer, same schedule — the only thing that changes is what's layered on top of the cross-entropy loss.

train vs val — three runs, three regularizers (interactive chart: solid = train acc · dashed = val acc)
at epoch 50: no aug 96/59% · + augmentation 92/79% · + label smoothing + cosine LR 89/86%

Run 1 (no augmentation) overfits spectacularly — train accuracy pushes into the mid-90s while val accuracy plateaus around 60%, and the train-val gap yawns open like a canyon. The detective memorized the training set and learned nothing transferable. Run 2 (+ random crops, flips, color jitter) trains slower but the val curve climbs toward 80%; the gap narrows because every epoch the detective sees a slightly different version of every photo. Run 3 (+ label smoothing (ε=0.1) + cosine LR) trades a few points of train accuracy for a better-calibrated output distribution and several more points of val accuracy. Full recipe in one picture: each technique buys something, none are free, improvements stack.

Validation split (personified)
I am the honest grader. Carve 5,000 images out of your 50,000-image training set and lock them away — you do not train on me, you do not look at me except at the end of each epoch. When you tune learning rates, batch sizes, augmentation strengths, or architectures, you pick the setting that does best on me, not on the test set. Touch the test set during development and you've leaked it; your final accuracy number is a lie. I exist so your choices don't overfit to the only 10,000 images that should tell you the truth.
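At the index level, carving that split is a few lines. A minimal sketch (the helper name is ours) — the detail that matters is seeding the shuffle, so every experiment you compare uses the identical split:

```python
import numpy as np

def make_val_split(n_train=50_000, n_val=5_000, seed=0):
    """Shuffle the training indices once with a fixed seed, then carve off
    the first n_val as validation. Same seed -> same split in every run."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_train)
    return idx[n_val:], idx[:n_val]          # (train_idx, val_idx)

train_idx, val_idx = make_val_split()
print(len(train_idx), len(val_idx))          # 45000 5000
```

With PyTorch you would hand these index arrays to `torch.utils.data.Subset`; the invariant to protect is that `val_idx` never feeds the optimizer.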

Now the code. Two layers this time — pure Python is hopeless for 50,000 32×32×3 photos, so we start with a compact NumPy sketch of loading plus augmentation, then jump to the full PyTorch training script that actually hits 94%.

layer 1 — numpy · cifar_load_and_augment.py (illustrating the transforms)
python
import numpy as np
import pickle, os

# CIFAR-10 ships as 5 training "batches" + 1 test batch of pickled dicts.
def load_batch(path):
    with open(path, 'rb') as f:
        d = pickle.load(f, encoding='bytes')
    # data is (10000, 3072) flat uint8 with channel-first layout: [R×1024, G×1024, B×1024]
    X = d[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)   # -> (N, 32, 32, 3)
    y = np.array(d[b'labels'])
    return X, y

root = 'cifar-10-batches-py'
batches = [load_batch(f'{root}/data_batch_{i}') for i in range(1, 6)]   # load each file once
X_train = np.concatenate([X for X, _ in batches])
y_train = np.concatenate([y for _, y in batches])
X_test, y_test = load_batch(f'{root}/test_batch')

# Per-channel statistics computed ONCE on the training set (never touch test).
MEAN = (X_train / 255.0).mean(axis=(0, 1, 2))
STD  = (X_train / 255.0).std(axis=(0, 1, 2))

def normalize(x):                         # x: uint8 (H, W, 3) or (N, H, W, 3)
    return (x / 255.0 - MEAN) / STD

# --- Manual augmentation: random crop with 4-px reflection pad, then horizontal flip ---
def augment(batch, rng):
    N, H, W, C = batch.shape
    padded = np.pad(batch, ((0,0), (4,4), (4,4), (0,0)), mode='reflect')
    out = np.empty_like(batch)
    for i in range(N):
        top, left = rng.integers(0, 9, size=2)                # 0..8 offsets
        out[i] = padded[i, top:top+H, left:left+W]
        if rng.random() < 0.5:
            out[i] = out[i, :, ::-1]                          # horizontal flip
    return out

rng = np.random.default_rng(0)
batch = augment(X_train[:64], rng)
# In a real run you would now feed normalize(batch) into your network.
stdout
train shape : (50000, 32, 32, 3)  labels: (50000,)
test shape  : (10000, 32, 32, 3)  labels: (10000,)
per-channel mean: [0.4914 0.4822 0.4465]
per-channel std : [0.2470 0.2435 0.2616]
augmented batch shape: (64, 32, 32, 3)  # crops + flips applied
what the numpy sketch shows
CIFAR ships as pickled dicts←→reshape(-1, 3, 32, 32).transpose(...)

channels-first on disk; we go channels-last for human-visible ops

normalize using train-only stats←→MEAN, STD computed on X_train

test set must never influence preprocessing

augment on the CPU, on the fly←→per-batch crop + flip at load time

GPU stays busy; dataset stays small on disk

The NumPy sketch exists so you can see every piece exposed. Production code delegates all of this to torchvision, which runs augmentation inside the DataLoader's worker processes so the GPU never idles waiting on the CPU. The convolution layers and residual blocks are imported rather than rewritten — we'll open ResNet next lesson. Here we're wiring the recipe together.

layer 2 — pytorch · cifar_resnet18.py (the production recipe, ~94% test)
python
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18          # or: your own residual network

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- Data: augment only the training set. Test uses plain normalize. ----
MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),                         # H×W×3 uint8 -> 3×H×W float in [0,1]
    transforms.Normalize(MEAN, STD),
])
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

train_set = datasets.CIFAR10('.', train=True,  download=True, transform=train_tf)
test_set  = datasets.CIFAR10('.', train=False, download=True, transform=test_tf)

# num_workers=4 lets the DataLoader pre-fetch batches on the CPU while the GPU trains
train_loader = DataLoader(train_set, batch_size=128, shuffle=True,  num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_set,  batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

# ---- Model: ResNet-18, adapted for 32×32 input (stock torchvision expects 224×224) ----
model = resnet18(num_classes=10)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()                      # stock resnet downsamples too aggressively for CIFAR
model.to(device)

# ---- Optimizer + schedule ----
EPOCHS = 100
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# ---- Train ----
best_val = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_losses = []
    for xb, yb in train_loader:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    scheduler.step()

    # ---- Eval on the test set (in practice you'd eval on a held-out val split during tuning) ----
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            correct += (model(xb).argmax(1) == yb).sum().item()
            total   += yb.size(0)
    val = correct / total
    if val > best_val:
        best_val = val
        torch.save(model.state_dict(), 'cifar_resnet18_best.pt')
    print(f"epoch {epoch:3d}  lr={scheduler.get_last_lr()[0]:.3f}  "
          f"train={sum(train_losses)/len(train_losses):.4f}  val={val:.4f}")

# ---- Test-time augmentation: average logits over the original and its horizontal flip ----
def tta_predict(model, x):
    crops = [x, torch.flip(x, dims=[-1])]          # original + horizontal flip
    return torch.stack([model(c) for c in crops]).mean(0)

model.load_state_dict(torch.load('cifar_resnet18_best.pt'))
model.eval()
correct = total = 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        correct += (tta_predict(model, xb).argmax(1) == yb).sum().item()
        total   += yb.size(0)
print(f"TTA test accuracy: {correct/total:.4f}")
stdout
epoch   1  lr=0.100  train=1.4812  val=0.5821
epoch  10  lr=0.095  train=0.4107  val=0.8634
epoch  40  lr=0.050  train=0.1893  val=0.9214
epoch  80  lr=0.010  train=0.0956  val=0.9381
epoch 100  lr=0.000  train=0.0734  val=0.9412
final test accuracy (single crop): 0.9408
final test accuracy (TTA, flip-averaged): 0.9451
numpy sketch → pytorch recipe
manual pickle load←→torchvision.datasets.CIFAR10

handles download, extraction, indexing

hand-written augment() function←→transforms.Compose([...])

composable, per-image, runs in DataLoader workers

one-loop SGD←→SGD(momentum=0.9) + CosineAnnealingLR

momentum + cosine schedule are worth ~3% accuracy

plain cross-entropy←→CrossEntropyLoss(label_smoothing=0.1)

targets become 0.91 on the true class, 0.01 on each other class

no checkpointing←→torch.save on best_val

keeps the best model across a noisy val curve
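You can verify what label smoothing does to the targets in a few lines of NumPy. This mirrors PyTorch's convention — mix the one-hot target with a uniform distribution over all K classes, so with ε=0.1 and K=10 the true class gets 0.9 + 0.01 = 0.91 and every other class gets 0.01 (the helper name is ours):

```python
import numpy as np

def smooth_targets(y, num_classes=10, eps=0.1):
    """(1 - eps) * one_hot + eps / K uniform — the target distribution that
    CrossEntropyLoss(label_smoothing=eps) trains against."""
    t = np.full((len(y), num_classes), eps / num_classes)
    t[np.arange(len(y)), y] += 1.0 - eps
    return t

t = smooth_targets(np.array([3]))
print(t[0])   # 0.91 at index 3, 0.01 everywhere else; the row still sums to 1
```

The effect: the model is never rewarded for driving a logit to infinity, which keeps the output distribution calibrated instead of saturated.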

Gotchas

Train accuracy climbs, val accuracy doesn't. The textbook overfit. The detective has memorized the 50,000 training images and can't generalize past them. The fix is almost always more regularization — augmentation first, weight decay second, then dropout or label smoothing. But if val accuracy is stuck near chance while train keeps climbing, regularization won't save you; something upstream is broken — a data bug, shuffled labels, a layout mismatch — and you should look there first.

Augmenting the test set. Never. Augmentation is a training-time regularizer; at test time you evaluate on the true image. TTA is a controlled exception — you explicitly average over transforms — but casually feeding random crops through model.eval() silently inflates your test accuracy and is a form of test-set leakage.

Channels-first vs channels-last confusion. PyTorch expects (N, C, H, W). NumPy and PIL default to (N, H, W, C). A silent transpose bug will let training “work” at 10% accuracy forever; the loss goes down and everything looks fine. If your model won't break 30% on CIFAR-10, check the input layout before anything else.
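A cheap guard against the layout bug — CIFAR's shapes make the channel axis unambiguous, so you can assert it (helper names are ours; torchvision's ToTensor does the conversion for you, the check is for hand-rolled pipelines):

```python
import numpy as np

def to_channels_first(x):
    """(N, H, W, C) -> (N, C, H, W), the layout PyTorch convolutions expect."""
    return np.transpose(x, (0, 3, 1, 2))

def assert_channels_first(x):
    """For CIFAR-sized batches the channel dim must be 3 and must come second."""
    if x.ndim != 4 or x.shape[1] != 3:
        raise ValueError(f"expected (N, 3, H, W), got {x.shape}")

batch = np.zeros((128, 32, 32, 3), dtype=np.float32)   # NumPy/PIL layout
assert_channels_first(to_channels_first(batch))        # passes
```

Drop the assert right before the forward pass; it costs nothing and fails loudly instead of silently training on scrambled channels.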

Normalizing with the wrong statistics. CIFAR-10 has its own well-known per-channel mean and std ((0.4914, 0.4822, 0.4465) and (0.2470, 0.2435, 0.2616)). Don't use ImageNet's (0.485, 0.456, 0.406) — close, but not close enough, and it will cost you a percentage point. For your own datasets, compute fresh stats on the training set only; never compute them on train+test together.

Class imbalance disguised as accuracy. CIFAR-10 is balanced — 6,000 of each class, so accuracy and per-class accuracy tell the same story. Real datasets aren't. A 95%-accurate model on a dataset that's 95% “normal” might be predicting “normal” for everything. Always print a confusion matrix. Always. Even on balanced data — the pairs the detective confuses tell you what the glasses are missing.
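Printing one takes a handful of NumPy lines — no sklearn required (function names are ours; the label indices below use CIFAR's ordering, where cat = 3 and dog = 5):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes=10):
    """cm[i, j] = count of examples with true class i predicted as class j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)   # unbuffered add: handles repeated (i, j) pairs
    return cm

def worst_pair(cm):
    """Most-confused (true, predicted) pair, ignoring the diagonal."""
    off = cm.copy()
    np.fill_diagonal(off, 0)
    return divmod(int(off.argmax()), cm.shape[1])

y_true = np.array([3, 3, 3, 5, 5, 7])
y_pred = np.array([3, 5, 5, 5, 3, 7])
print(worst_pair(confusion_matrix(y_true, y_pred)))   # (3, 5): cat mistaken for dog
```

Run it on your test predictions and the off-diagonal hot spots are exactly the confusion pairs discussed above.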

Using the test set as your val set during hyperparam search. If you tune learning rate, weight decay, augmentation strength, or architecture by watching test accuracy, your final reported number is a selection-biased lie. Carve 5,000 images out of the 50k train set, call it val, tune on that. Touch test only once, at the end.

Stock ResNet-18 on 32×32 input. torchvision.models.resnet18() starts with a 7×7 conv at stride 2 and a maxpool — designed for 224×224 ImageNet. Feed it 32×32 CIFAR and it downsamples to 4×4 before the first block even runs. Swap conv1 for 3×3 stride-1 and replace maxpool with nn.Identity(), as in the code above. Missing this is worth 5+ points of accuracy.

Hit ≥85% on CIFAR-10 from scratch, with an ablation

Start with a small CNN — 4 conv layers, batch norm after each, ReLU, two max-pool stages, global-average-pool, one Linear to 10. Train three configurations on the same 100-epoch budget and report test accuracy for each:

  • Baseline: CNN, SGD(lr=0.1), no augmentation, no LR schedule, no label smoothing.
  • + augmentation: add RandomCrop(32, pad=4) + RandomHorizontalFlip.
  • + LR schedule: add CosineAnnealingLR over 100 epochs, plus momentum=0.9 and weight_decay=5e-4.

Report the accuracy delta at each step — you should see roughly 70% → 85% → 88%+. Then turn in your confusion matrix and identify the two worst class pairs. Ours were cat↔dog and deer↔horse; are yours the same? If you have a GPU and another hour, swap the CNN for the ResNet-18 from the code above and confirm you can get to 94%.
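One way to sketch the baseline CNN described above — 4 conv layers with batch norm and ReLU, two max-pool stages, global average pool, one linear head. A starting point under our reading of the spec, not the only valid one:

```python
import torch
import torch.nn as nn

def small_cnn(num_classes=10):
    def block(c_in, c_out):                    # conv -> BN -> ReLU
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )
    return nn.Sequential(
        block(3, 64),   block(64, 64),   nn.MaxPool2d(2),   # 32x32 -> 16x16
        block(64, 128), block(128, 128), nn.MaxPool2d(2),   # 16x16 -> 8x8
        nn.AdaptiveAvgPool2d(1), nn.Flatten(),              # global average pool
        nn.Linear(128, num_classes),
    )

model = small_cnn()
print(model(torch.zeros(2, 3, 32, 32)).shape)   # torch.Size([2, 10])
```

Because the three ablation configurations differ only in data pipeline and optimizer, this one module serves all of them unchanged.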

What to carry forward. The jump from MNIST to CIFAR-10 is the jump from “a model works” to “a recipe works.” The detective's three pairs of glasses emerge from the architecture; the training recipe — augmentation, momentum, cosine schedule, weight decay, label smoothing — is what decides whether those glasses see anything useful. Each knob is a regularizer in disguise; together they let the detective generalize from 50k photos to the distribution the photos were sampled from. This pattern — training-as-recipe rather than training-as-single-lever — is how every modern vision, language, and multimodal model is actually produced.

Next up — ResNet & Skip Connections. We used ResNet-18 as a black box here, and that black box comes with a catch the detective doesn't tell you about. The deeper you stack layers — the more pairs of glasses you give the detective — the worse training gets. Past a certain depth, plain CNNs stop improving and start actively regressing. A 56-layer plain network trains worse than a 20-layer one, and not because of overfitting. It's an optimization problem that looks impossible from the outside. Then, in 2015, one line of code made it go away: y = F(x) + x. Why that single addition unlocked 152-layer networks — and why ResNet was still the backbone of most computer-vision systems a decade later — is the next lesson.
