Vision Transformer (ViT)
Treating image patches as tokens — the bridge to attention.
Picture a jigsaw puzzle box. You tip it onto the table and out fall a few hundred little squares: a sliver of sky here, a corner of an ear there, a chunk of wheel, a patch of fur. You don't see the image yet. You see pieces. That's the move this whole lesson rests on: take an image, cut it into puzzle pieces, and read the pieces like a sentence. A 224×224 photo sliced into 16×16 patches is a 196-token paragraph. Each patch is a word. The transformer reads them in order, left to right and top to bottom, and works out what the picture shows the same way it works out what a sentence means.
For thirty years, that was heresy. Vision meant convolutions: grids, kernels, pooling, built-in inductive biases that told a network images are 2D and nearby pixels are related. Language meant recurrence, then attention: sequences, tokens, transformers. Two architectures, two worlds, two PhD tracks.
Then in October 2020 a Google team submitted a paper called An Image Is Worth 16×16 Words and quietly knocked the wall down. Their claim: take the transformer block that was eating NLP, feed it pixels in puzzle-piece chunks instead of words, and if you scale the pretraining data far enough it matches or beats the best convnets of the day. This is the Vision Transformer, ViT, and it's a big part of why your 2025 multimodal model can use one architecture for both modalities instead of two.
This lesson is the bridge. You've seen CNNs in this section. You've heard of transformers in the wild but haven't built one — that's a whole section by itself, coming later. For today the transformer encoder is a black box. The part we care about — the genuinely novel piece — is the one at the front: how do you cut an image into a paragraph the transformer can actually read?
The ViT hypothesis, in one sentence: an image is a sequence of patches. Not a 2D grid with spatial structure the network needs to respect via convolutions — a sequence, flat, ordered, just like a sentence of words. The transformer then treats each patch the way BERT treats a word: embed it, add a positional code so it knows where the piece sits in the grid, mix all the tokens with self-attention, and read out a prediction.
This is more radical than it sounds. A CNN is structurally forced to care about locality — a 3×3 kernel literally cannot look at two pixels that are a hundred pixels apart on the first layer. A ViT has no such constraint. On layer one, every patch can attend to every other patch. The sky piece in the top corner can shake hands with the grass piece in the bottom corner on step one. The network learns locality (or doesn't) from data alone.
Start with the surgery: dicing the image. A 224×224 image with 16×16 patches gives you a 14×14 grid, 196 puzzle pieces, each a little 16×16×3 tile. Those 196 pieces are your sequence. From here on, the transformer doesn't see an image. It sees a list of 196 vectors, laid out end to end like words on a page, each one carrying the contents of one piece.
Three things to notice:
- Patch size is a hyperparameter with quadratic cost. Halving the patch side quadruples the sequence length, and attention is O(N²) in sequence length. Going from 16×16 to 8×8 patches makes attention 16× more expensive (the per-token layers, which scale linearly in N, only get 4× more expensive). Smaller pieces, longer paragraph, bigger bill.
- Each patch is flattened to a vector. A 16×16×3 RGB piece becomes a 768-dimensional vector (just 16 · 16 · 3 = 768). No averaging, no pooling: the raw pixels are the token. That this equals ViT-Base's d_model of 768 is a coincidence of this configuration; the projection that follows maps it to whatever width the model uses.
- The 2D structure is lost at this step. The grid collapses into a flat list; the puzzle pieces come out of the box in a line. We'll bolt position back on via a positional encoding, because the transformer has no other way to know that patch 42 is above patch 56.
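A few lines of arithmetic make the quadratic-cost point concrete. This is a minimal sketch that counts only tokens and the N × N attention-score matrix per head for a 224×224 image; the helper name num_patches is hypothetical, and the MLP and projection costs are deliberately ignored.

```python
# Token count and attention-score count for a 224×224 image at a few patch sizes.
# Only the N × N score matrix per head is counted; MLP and projections are ignored.
IMG_SIDE = 224

def num_patches(img_side: int, patch: int) -> int:
    """Patch tokens (without [CLS]) for a square image."""
    assert img_side % patch == 0, "patch size must divide the image side"
    return (img_side // patch) ** 2

for patch in (32, 16, 8):
    n = num_patches(IMG_SIDE, patch)
    print(f"P={patch:>2}: N={n:>3} tokens, N^2={n * n:>7,} attention scores per head")

# Halving the patch side: 4× the tokens, 16× the attention scores.
ratio = num_patches(IMG_SIDE, 8) ** 2 / num_patches(IMG_SIDE, 16) ** 2
print(f"8x8 vs 16x16 attention cost: {ratio:.0f}x")  # 16x
```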
Now the quiet reveal: the step that explains why the whole vision-as-language thing even works. Input image x ∈ ℝ^(H×W×C). Reshape into N patches of size P×P×C, where N = HW/P². Flatten each patch, stack the results into a matrix, and multiply by a single learned matrix E to project to model dimension d_model. That last step, patch times matrix, is the piece to stare at. It plays the same role as a word-embedding lookup, which is itself a one-hot vector times a matrix; ViT simply feeds the matrix a dense vector of pixels instead of a one-hot symbol. NLP tokenizes text into symbols and maps each symbol to a vector. ViT tokenizes an image into pieces and maps each piece to a vector. Same machine, different side of the puzzle box.
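$$
\begin{aligned}
&x \in \mathbb{R}^{H \times W \times C} && \text{input image} \\
&x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}, \quad N = HW / P^2 && \text{reshape to patches} \\
&x_p E, \quad E \in \mathbb{R}^{(P^2 \cdot C) \times d_{\text{model}}} && \text{linear projection} \\
&z_0 = [\, x_{\text{CLS}} ;\, x_p E \,] + E_{\text{pos}}, \quad E_{\text{pos}} \in \mathbb{R}^{(N+1) \times d_{\text{model}}} && \text{prepend [CLS], add positional encoding} \\
&z_0 \in \mathbb{R}^{(N+1) \times d_{\text{model}}} && \text{final shape}
\end{aligned}
$$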
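The word-embedding analogy can be checked in a few lines of NumPy. This sketch uses made-up tiny sizes (d_model = 8, a 5-symbol vocabulary, a 4×4×3 patch) purely for illustration: a table lookup is the same as multiplying a one-hot vector by the embedding table, while the patch embedding multiplies a dense pixel vector by E. Either way, each token ends up as one d_model-sized vector.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8                                   # tiny width, for illustration only

# Word side: a 5-symbol vocabulary with a learned (here random) embedding table.
vocab_size = 5
W_emb = rng.normal(size=(vocab_size, d_model))
token_id = 3
lookup = W_emb[token_id]                      # the usual table lookup
one_hot = np.eye(vocab_size)[token_id]        # (5,) with a single 1 at index 3
as_matmul = one_hot @ W_emb                   # the same lookup written as a matrix product

# Image side: a flattened 4×4×3 patch (48 numbers) times a projection matrix E.
P, C = 4, 3
E = rng.normal(size=(P * P * C, d_model))     # (48, 8)
patch = rng.normal(size=(P * P * C,))         # dense pixel values, not one-hot
patch_token = patch @ E                       # (8,)

print(np.allclose(lookup, as_matmul))         # True
print(lookup.shape, patch_token.shape)        # (8,) (8,)
```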
For ViT-Base on 224×224 images: P=16, N=196, d_model=768. After patch embedding + CLS + positional encoding you have a 197 × 768 tensor. That's a sequence of length 197 — 197 tokens, a paragraph with one puzzle piece per word plus one “whole image” slot up front. Every transformer block that follows sees exactly that. It neither knows nor cares that the tokens came from pixels. To the attention machinery, this is indistinguishable from a sentence.
The positional encoding E_pos is a learned (N+1) × d_model table: one row per sequence position, including the CLS slot. It's the grid reference that lets the model learn that patch 42 sits at row 3, column 0, without hard-coding 2D structure into the network. ViT uses learned 1D position embeddings rather than the fixed sinusoidal ones from the original Attention Is All You Need paper. In the ViT paper's ablations, 2D-aware and relative position embeddings performed about the same as the 1D learned ones (dropping position information entirely was clearly worse), so 1D-learned is the simplest thing that works.
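To make the indexing concrete, here is a tiny sketch of how a patch index maps to grid coordinates and to a row of E_pos. It assumes the row-major flattening used in the code in this lesson (left to right, top to bottom) and [CLS] in slot 0; the helper names are hypothetical. The learned table itself stores no explicit row or column; the model has to learn what each row means.

```python
GRID = 14  # 224 / 16 patches per side for ViT-Base/16

def patch_to_row_col(patch_idx: int, grid: int = GRID) -> tuple[int, int]:
    """Row-major (left to right, top to bottom) patch index -> (row, col)."""
    return divmod(patch_idx, grid)

def patch_to_seq_pos(patch_idx: int) -> int:
    """Position in the token sequence, and so the row of E_pos; slot 0 is [CLS]."""
    return patch_idx + 1

for idx in (0, 13, 42, 56, 195):
    row, col = patch_to_row_col(idx)
    print(f"patch {idx:>3}: row {row:>2}, col {col:>2}, E_pos row {patch_to_seq_pos(idx)}")
```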
I am a 16×16 square of pixels — a scrap of fur, a slice of sky, a corner of a wheel. Alone I mean very little: 768 numbers that could be anything. But in the sequence I live in, my neighbors and I will shout at each other through attention until we agree on what the whole picture is. I am a token. Treat me like one.
Now the interesting part. Once the puzzle pieces are embedded and positionally coded, every transformer block does self-attention: each token computes a query, a key, and a value, and each token's output is a weighted sum of all the tokens' values (its own included), with weights given by a softmax over query·key similarity. You haven't built attention yet, so take that on faith for now. What matters here is the consequence: patches can, and do, attend to anywhere in the image. The ear piece can ask the nose piece a question on step one.
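For a peek inside the black box, here is a minimal single-head sketch of scaled dot-product self-attention in NumPy. The projection weights are random stand-ins for learned ones. Real ViT-Base runs 12 such heads of width 64 in parallel and adds an output projection, both omitted here. The returned weight matrix is 197 × 197: every token scored against every token, which is why distance in the image plays no special role.

```python
import numpy as np

def self_attention(z, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention.

    z: (N, d_model) tokens. Wq, Wk, Wv: (d_model, d_head) projections.
    Returns the (N, d_head) outputs and the (N, N) attention weights.
    """
    q, k, v = z @ Wq, z @ Wk, z @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])           # (N, N): every token vs every token
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
N, d_model, d_head = 197, 768, 64                     # ViT-Base/16 sizes, one head
z0 = rng.normal(size=(N, d_model)).astype(np.float32)
Wq, Wk, Wv = (rng.normal(size=(d_model, d_head)).astype(np.float32) * 0.02 for _ in range(3))
out, attn = self_attention(z0, Wq, Wk, Wv)
print(out.shape, attn.shape)                          # (197, 64) (197, 197)
```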
Pick a query patch in a trained ViT and plot how strongly it attends to every other piece of the grid. Early layers often attend locally: ViT has to learn the CNN-style locality bias from data, rediscovering the “nearby pieces go together” prior that a convolution gets for free. Deeper layers often go long-range, hooking a background patch to a foreground object, or linking the left and right sides of a symmetric thing. A probing study by Raghu et al. (2021) found that ViT's lower layers contain a mix of local heads, which behave somewhat like convolutions, and global heads, and that the local heads largely fail to emerge when pretraining data is limited. The network builds locality where it's useful and the data supports it.
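“Local” versus “global” heads are usually quantified with something like a mean attention distance, in the spirit of the attention-distance plots in the ViT paper. The function below is a sketch under stated assumptions, not the paper's exact code. It assumes [CLS] has already been removed from the attention matrix, that patches are in row-major order, and that distance is measured between patch positions in pixels. The two toy heads bracket the range: one that only looks at itself, and one that spreads attention evenly.

```python
import numpy as np

def mean_attention_distance(attn: np.ndarray, grid: int, patch: int) -> np.ndarray:
    """Attention-weighted average pixel distance from each query patch.

    attn: (N, N) weights between patch tokens only ([CLS] removed), rows summing
          to 1, patches in row-major order on a grid x grid layout.
    Returns an (N,) array of distances in pixels.
    """
    rows, cols = np.divmod(np.arange(grid * grid), grid)
    pos = np.stack([rows, cols], axis=1) * patch                          # (N, 2) in pixels
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)     # (N, N)
    return (attn * dist).sum(axis=-1)

grid, patch = 14, 16
N = grid * grid
local = np.eye(N)                       # each patch attends only to itself
uniform = np.full((N, N), 1.0 / N)      # each patch spreads attention evenly

local_dist = mean_attention_distance(local, grid, patch).mean()
global_dist = mean_attention_distance(uniform, grid, patch).mean()
print(local_dist)    # 0.0
print(global_dist)   # much larger: attention spread across the whole image
```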
This is a different failure mode from a CNN's. A CNN can't look far without stacking depth or dilating kernels. A ViT can look anywhere from layer one but has to learn from scratch what to pay attention to. The trade is inductive bias for expressive power — and it's only a good trade when you have enough data to fill that freedom with signal.
One last trick before we code. The transformer outputs a sequence — 197 vectors in, 197 vectors out, same paragraph in, same paragraph out, contents rewritten. But classification wants a single vector to feed a linear head. How do you pool a paragraph down to one word?
ViT borrows BERT's move. Prepend a learned token, called [CLS], to the sequence. It has no corresponding patch: it's a free-floating learned vector, the same on every example, trained like any other parameter. A blank sticker stuck to the front of the puzzle pieces. As it passes through the transformer blocks it attends to every patch and every patch attends to it. By the final block its embedding has soaked up information from the entire grid. You read it off, feed it to a linear classifier, and out come your logits.
I don't represent any part of the picture. I represent the whole picture. I'm a learned sponge, prepended at position zero, that spends twelve transformer blocks asking every patch what they think. By the time the final layer hands me off to the classifier head, I am the image, compressed into 768 numbers. Use me, then discard me — my only job is to pool.
Three layers, same progression you've seen everywhere in this series. NumPy to show the patch mechanics with nothing up our sleeves. PyTorch to show a single ViT block we can actually backprop through. timm to show what a real ViT call looks like in a production repo.
```python
import numpy as np

# Fake image: 3 channels, 224×224.
rng = np.random.default_rng(0)
img = rng.normal(size=(3, 224, 224)).astype(np.float32)
C, H, W = img.shape
P = 16   # patch side
D = 768  # d_model
assert H % P == 0 and W % P == 0, "image side must be divisible by patch size"

# (1) Reshape into an (N_h, N_w) grid of (C, P, P) patches, then flatten.
# Trick: reshape + transpose. It's the same data in a different order.
Nh, Nw = H // P, W // P                        # 14, 14
patches = img.reshape(C, Nh, P, Nw, P)         # (3, 14, 16, 14, 16)
patches = patches.transpose(1, 3, 0, 2, 4)     # (14, 14, 3, 16, 16)
patches = patches.reshape(Nh * Nw, C * P * P)  # (196, 768), row-major patch order

# (2) Linear projection to d_model. For a 16×16×3 patch this is a 768→768 map;
# the matrix is learned in a real model, random here for illustration.
E = rng.normal(size=(C * P * P, D)).astype(np.float32) * 0.02
proj = patches @ E                             # (196, 768)

# (3) Prepend [CLS], add positional encoding. Both learned in real ViT.
cls = rng.normal(size=(1, D)).astype(np.float32) * 0.02
tokens = np.concatenate([cls, proj], axis=0)   # (197, 768)
E_pos = rng.normal(size=(Nh * Nw + 1, D)).astype(np.float32) * 0.02
z0 = tokens + E_pos                            # (197, 768), ready for the transformer

print(f"image: {img.shape}")
print(f"patch grid: ({Nh}, {Nw})")
print(f"flat patches: {patches.shape}")
print(f"projected: {proj.shape}")
print(f"with [CLS]+pos: {z0.shape}")
```

```text
image: (3, 224, 224)
patch grid: (14, 14)
flat patches: (196, 768)
projected: (196, 768)
with [CLS]+pos: (197, 768)
```
Into PyTorch. Two things to notice. First, patch extraction is usually written as a Conv2d with stride = kernel = P. It is technically a convolution, but with non-overlapping windows it computes exactly the same linear map as the reshape-and-project you just did in NumPy, applied independently to each patch; the convolution form is a compact way to express it that frameworks can dispatch to an optimized kernel. The pieces don't overlap, so no pixel contributes to more than one token: that's what keeps this “tokenization” and not “feature extraction.” Second, the transformer encoder block comes pre-built as nn.TransformerEncoderLayer. We call it; we don't define it.
```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Image → flat token sequence via a strided convolution."""
    def __init__(self, img_size=224, patch=16, in_ch=3, d_model=768):
        super().__init__()
        self.n_patches = (img_size // patch) ** 2
        # Conv2d(kernel=patch, stride=patch) = non-overlapping patch projection.
        self.proj = nn.Conv2d(in_ch, d_model, kernel_size=patch, stride=patch)

    def forward(self, x):                       # x: (B, 3, 224, 224)
        x = self.proj(x)                        # (B, D, 14, 14)
        x = x.flatten(2).transpose(1, 2)        # (B, 196, D)
        return x


class ViTLite(nn.Module):
    """One-block ViT skeleton; set n_blocks=12 for a ViT-Base-sized encoder."""
    def __init__(self, img_size=224, patch=16, d_model=768, n_heads=12, n_blocks=1, n_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch, 3, d_model)
        N = self.patch_embed.n_patches
        # Learned [CLS] and 1D positional encoding, both nn.Parameter.
        # Zeros keep the example simple; many implementations use a small random init.
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos = nn.Parameter(torch.zeros(1, N + 1, d_model))
        # Transformer encoder: attention + MLP + LayerNorm (pre-norm, as in ViT) as one call.
        block = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            batch_first=True, activation='gelu', norm_first=True,
        )
        # enable_nested_tensor=False avoids a warning: that fast path doesn't support norm_first=True.
        self.encoder = nn.TransformerEncoder(block, num_layers=n_blocks, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):                                  # (B, 3, 224, 224)
        B = x.size(0)
        tokens = self.patch_embed(x)                       # (B, 196, D)
        cls = self.cls.expand(B, -1, -1)                   # (B, 1, D)
        z = torch.cat([cls, tokens], dim=1) + self.pos     # (B, 197, D)
        z = self.encoder(z)                                # (B, 197, D)
        z = self.norm(z)
        return self.head(z[:, 0])                          # CLS → logits (B, 1000)


model = ViTLite()
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 1000])
```

NumPy ↔ PyTorch, side by side:
- img.reshape(...).transpose(...).reshape(...) ↔ nn.Conv2d(stride=patch, kernel=patch): same linear map, but differentiable and able to run as a single GPU convolution kernel
- cls = rng.normal(size=(1, D)) ↔ self.cls = nn.Parameter(torch.zeros(1, 1, D)): nn.Parameter registers it for gradient descent
- hand-rolled self-attention loop ↔ nn.TransformerEncoderLayer(...): multi-head attention + MLP + LayerNorm in one line
- np.concatenate([cls, tokens]) + E_pos ↔ torch.cat([cls, tokens], dim=1) + self.pos: same op, different dtype and device story
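The claim that a stride = kernel Conv2d is the same linear map as reshape-and-project is easy to check numerically. This is a minimal sketch with random weights (not timm's). It copies the convolution kernel into an nn.Linear and compares the two paths. The tolerance atol=1e-4 is an assumption chosen to absorb float32 summation-order differences between the two kernels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
x = torch.randn(B, C, H, W)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)

# Copy the conv kernel into the Linear layer: (D, C, P, P) -> (D, C*P*P).
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

# Conv path: (B, D, 14, 14) -> (B, 196, D).
out_conv = conv(x).flatten(2).transpose(1, 2)

# Reshape path, same ordering as the NumPy code: each patch flattened as (C, P, P).
Nh, Nw = H // P, W // P
patches = (x.reshape(B, C, Nh, P, Nw, P)
             .permute(0, 2, 4, 1, 3, 5)              # (B, Nh, Nw, C, P, P)
             .reshape(B, Nh * Nw, C * P * P))        # (B, 196, 768)
out_linear = linear(patches)

print(out_conv.shape, out_linear.shape)
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

The weight reshape only works because the per-patch flattening order (C, P, P) matches the Conv2d weight layout; flattening patches as (P, P, C) instead would need a matching permute of the weights.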
Layer three is what you'd actually ship. timm, the PyTorch Image Models library, carries a wide range of ViT variants with pretrained weights (ImageNet-21k, LAION-trained CLIP towers, and more). In practice, most people don't train a ViT from scratch on their own data; they finetune a timm checkpoint. The pieces you just wrote by hand (patch embed, CLS, pos embed) are sitting right there with the same shapes and nearly the same names, just with weights that actually work.
```python
import timm
import torch

# One call loads architecture + weights pretrained on ImageNet-21k,
# finetuned on ImageNet-1k. The model is ready for inference.
model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()

# timm exposes the same pieces we built by hand.
print("patch embed weight shape: ", model.patch_embed.proj.weight.shape)
print("CLS token shape: ", model.cls_token.shape)
print("pos embed shape: ", model.pos_embed.shape)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)  # (1, 1000): ImageNet classes
print("logits shape: ", logits.shape)
```

```text
patch embed weight shape:  torch.Size([768, 3, 16, 16])
CLS token shape:  torch.Size([1, 1, 768])
pos embed shape:  torch.Size([1, 197, 768])
logits shape:  torch.Size([1, 1000])
```

The checkpoint's config (model.default_cfg) also records the preprocessing it expects: crop_pct=0.9, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5].

Ours ↔ timm:
- ViTLite(n_blocks=1) ↔ timm.create_model('vit_base_patch16_224'): same architecture at n_blocks=12, plus pretrained weights
- model.patch_embed.proj ↔ model.patch_embed.proj: timm uses the exact same Conv2d trick, so ours lines up 1:1
- self.cls / self.pos ↔ model.cls_token / model.pos_embed: same parameters, same shapes; naming converges
Two follow-up architectures worth knowing, because they're the pragmatic compromises you'll actually meet in current codebases:
- DeiT (Touvron et al. 2021). “Data-efficient image Transformers.” Essentially the same architecture plus an extra distillation token that learns to match a CNN teacher's predictions, trained with heavy augmentation and regularization. Trained on ImageNet-1k alone, it reaches accuracy competitive with strong convnets, no JFT required. This is what made ViT accessible to non-Google labs.
- Hybrid ViTs. Replace the patchify-by-Conv2d with a small CNN stem (say, a ResNet's early stages, as in the original paper's hybrid models). The CNN does the low-level feature extraction, which is what CNNs are good at, and the transformer does long-range mixing over the resulting tokens. More broadly, modern backbones blend the two traditions in different proportions: Swin adds convolution-like local windows and a feature hierarchy to a transformer, while ConvNeXt modernizes a pure CNN using transformer design choices.
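A hybrid stem can be sketched in a few lines of PyTorch. TinyConvStem below is hypothetical, not the ResNet stem from the ViT paper. It uses four stride-2 3×3 convolutions (widths chosen arbitrarily) to downsample 224 → 14, the same grid a 16×16 patchify gives. A 1×1 convolution then projects to d_model, which amounts to a patch embedding with patch size 1 on the feature map. Its output is a drop-in replacement for the PatchEmbed output above.

```python
import torch
import torch.nn as nn

class TinyConvStem(nn.Module):
    """Hypothetical CNN stem: (B, 3, 224, 224) image -> (B, 196, d_model) tokens."""
    def __init__(self, in_ch=3, d_model=768, widths=(64, 128, 256, 512)):
        super().__init__()
        layers, ch = [], in_ch
        for w in widths:  # each stage halves H and W: 224 -> 112 -> 56 -> 28 -> 14
            layers += [nn.Conv2d(ch, w, kernel_size=3, stride=2, padding=1),
                       nn.BatchNorm2d(w), nn.ReLU(inplace=True)]
            ch = w
        self.features = nn.Sequential(*layers)
        self.proj = nn.Conv2d(ch, d_model, kernel_size=1)   # "patch size 1" on the feature map

    def forward(self, x):                    # x: (B, 3, 224, 224)
        x = self.features(x)                 # (B, 512, 14, 14)
        x = self.proj(x)                     # (B, D, 14, 14)
        return x.flatten(2).transpose(1, 2)  # (B, 196, D)

stem = TinyConvStem()
tokens = stem(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 768])
```

Because the token count and width match, the rest of the ViT ([CLS], positional table, encoder, head) needs no change.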
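Patch size must divide the image size. A plain stride-16 Conv2d on a 225×225 image silently drops the leftover pixel row and column and still produces a 14×14 grid; library implementations may raise an error instead. If you want arbitrary resolution, you must pad, resize, or pick a patch size that divides both sides. timm's default ViT preprocessing resizes with bicubic interpolation, which is not free: resampling softens edges.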
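Padding is the simplest of those options. A minimal sketch, assuming zero padding on the bottom and right edges is acceptable (pad_to_multiple is a hypothetical helper; if you pad after normalization, zero means “average color” rather than black):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, patch: int) -> torch.Tensor:
    """Zero-pad the bottom/right of a (B, C, H, W) batch so H and W divide by `patch`."""
    H, W = x.shape[-2:]
    pad_h = (-H) % patch
    pad_w = (-W) % patch
    # F.pad takes pads for the last dim first: (left, right, top, bottom).
    return F.pad(x, (0, pad_w, 0, pad_h))

x = torch.randn(1, 3, 225, 225)
print(pad_to_multiple(x, 16).shape)  # torch.Size([1, 3, 240, 240])
```

Note the catch: 240×240 gives a 15×15 grid of 225 patches, so a model pretrained with 196 positions also needs its positional table resized.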
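Positional encoding doesn't transfer to new resolutions. If you pretrain at 224 (196 patches) and finetune at 384 (576 patches), the positional encoding table is the wrong size. The standard fix is 2D interpolation of E_pos: set the CLS row aside, treat the remaining 196 rows as a 14×14 image with d_model channels, upsample it to 24×24 (bicubic is the usual choice), and re-attach the CLS row. Every ViT codebase implements this, and it is a classic source of subtle bugs, such as interpolating the CLS row along with the grid or scrambling the row-major order.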
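Here is a minimal sketch of that resize in PyTorch. It assumes a (1, 1 + grid², D) table with [CLS] in row 0 and patches in row-major order, matching the shapes used throughout this lesson. resize_pos_embed is a hypothetical helper; libraries such as timm typically handle this for you when you load a checkpoint at a new image size.

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    """Resize a (1, 1 + old_grid**2, D) position table to (1, 1 + new_grid**2, D).

    The [CLS] row is kept as-is; only the patch rows are treated as a
    (D, old_grid, old_grid) image and bicubically resized.
    """
    cls_pos, grid_pos = pos[:, :1], pos[:, 1:]
    D = pos.shape[-1]
    grid_pos = grid_pos.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)   # (1, D, H, W)
    grid_pos = F.interpolate(grid_pos, size=(new_grid, new_grid),
                             mode="bicubic", align_corners=False)
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)  # row-major tokens
    return torch.cat([cls_pos, grid_pos], dim=1)

pos_224 = torch.randn(1, 197, 768)   # stand-in for a table trained at 224×224, patch 16
pos_384 = resize_pos_embed(pos_224, old_grid=14, new_grid=24)
print(pos_384.shape)                 # torch.Size([1, 577, 768])
```

Both bugs mentioned above would show up here: slicing off the CLS row keeps it out of the interpolation, and the permute pair keeps rows and columns in the order the patch embedding produces them.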
CLS position matters, and it isn't the only pooling choice. ViT reads the classifier input from the CLS slot at position 0. Some variants mean-pool over all patch tokens (“GAP”) instead, which in practice performs about as well and is sometimes slightly better on ImageNet. If you swap pooling strategies you must retrain (or at least finetune) the head, because it was fit to one kind of feature. Don't mix.
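Both pooling choices are one-liners on the encoder output. This sketch assumes the (B, 1 + N, D) layout with [CLS] at position 0 used throughout this lesson; pool_tokens is a hypothetical helper, and real GAP models often drop the [CLS] token altogether. Note that both modes return the same (B, D) shape. That is exactly why swapping them raises no error and quietly breaks a head trained on the other.

```python
import torch

def pool_tokens(z: torch.Tensor, mode: str = "cls") -> torch.Tensor:
    """Reduce encoder output (B, 1 + N, D) to one vector per image.

    mode="cls": take the [CLS] slot at position 0 (original ViT).
    mode="gap": average the N patch tokens, ignoring [CLS].
    """
    if mode == "cls":
        return z[:, 0]
    if mode == "gap":
        return z[:, 1:].mean(dim=1)
    raise ValueError(f"unknown pooling mode: {mode!r}")

z = torch.randn(2, 197, 768)  # stand-in for the final encoder output
print(pool_tokens(z, "cls").shape, pool_tokens(z, "gap").shape)
# torch.Size([2, 768]) torch.Size([2, 768])
```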
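Normalization is per-channel and happens before patching. The input normalization (whatever mean/std the checkpoint was trained with: ImageNet's statistics for many models, 0.5/0.5 for the timm ViT above) is applied pixel-wise to the whole image before patching; it is not per-patch. A common bug is to lose track of this and normalize twice or not at all.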
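A small NumPy sketch of the right way and two wrong ways. It assumes pixels already scaled to [0, 1] and the 0.5/0.5 per-channel stats from the timm config above; many other checkpoints use ImageNet's mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225). patchify repeats the reshape/transpose from the NumPy walkthrough.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.uniform(0.0, 1.0, size=(3, 224, 224))    # (C, H, W), pixels scaled to [0, 1]
mean = np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1)  # per-channel stats (assumed 0.5/0.5)
std = np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1)

def patchify(x: np.ndarray, P: int = 16) -> np.ndarray:
    """(C, H, W) -> (N, C*P*P), same reshape/transpose as the NumPy walkthrough."""
    C, H, W = x.shape
    x = x.reshape(C, H // P, P, W // P, P).transpose(1, 3, 0, 2, 4)
    return x.reshape(-1, C * P * P)

# Right: normalize once, per channel, on the whole image, then patch.
correct = patchify((img - mean) / std)

# Bug 1: normalizing a second time (e.g. in both the dataloader and a model wrapper).
twice = patchify(((img - mean) / std - mean) / std)

# Bug 2: standardizing each patch by its own statistics instead of the dataset's.
raw = patchify(img)
per_patch = (raw - raw.mean(axis=1, keepdims=True)) / raw.std(axis=1, keepdims=True)

print(correct.min() >= -1.0 and correct.max() <= 1.0)                # True: [0, 1] -> [-1, 1]
print(np.allclose(correct, twice), np.allclose(correct, per_patch))  # False False
```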
Write a PatchEmbedScratch module that takes an image (B, 3, 224, 224) and returns (B, 197, 768) — patches + CLS + learned positional encoding — without using nn.Conv2d. Use pure reshape + transpose + nn.Linear, following the NumPy code above.
Then load timm.create_model('vit_base_patch16_224', pretrained=True), copy its patch_embed.proj Conv2d weights into your nn.Linear (you'll need a reshape — the Conv2d weight is (768, 3, 16, 16), your Linear weight is (768, 768)), and confirm that your output matches timm's to within floating-point noise: torch.allclose(mine, theirs, atol=1e-5).
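Bonus: time both on a 32-image batch. Don't assume a winner. The Conv2d route may hit an optimized convolution kernel, while the reshape route pays for an extra memory copy before a matmul that is itself highly optimized, so the result depends on your hardware, dtype, and library versions. Either way, it's the usual “equivalent math, unequal hardware” lesson.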
What to carry forward. ViT's one-sentence contribution: an image is a sequence of puzzle pieces, and you can run a transformer on it. The patch embedding is a linear projection (a Conv2d with stride = kernel under the hood, playing the same role as a word-embedding lookup), the CLS token pools the paragraph into a single vector, and the positional encoding gives the model back the 2D layout you flattened away. The trade versus a CNN is less inductive bias, more data hunger, higher ceiling. Pretrained on ImageNet-21k, ViT roughly matches comparable CNNs; at JFT-300M scale it pulls ahead. With less data, CNNs or hybrids usually win unless you bring DeiT-style training tricks. The deeper lesson, the one that pays off for the rest of this curriculum, is that attention doesn't care what the tokens are. Pixels, words, audio frames, protein residues: if you can tokenize it, a transformer can eat it. Vision is just language with different puzzle pieces.
Up next — Build GPT. You've seen two tokenizers now: one that cuts images into fixed 16×16 squares, and one that handles embeddings for words. Image tokens are easy to eyeball — you can literally see them on the grid. But what actually counts as a token for a language model? “Cat” is a token. Is “cats”? Is “cat’s”? Is the space before it? The answer changes how everything downstream behaves — vocab size, context window, which spellings the model can even express. Next lesson: Tokenizer (Byte Pair Encoding). We stop taking the word “token” for granted, build the tokenizer GPT actually uses from scratch, and see why “ hello” and “hello” are not the same token.
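- [01] Dosovitskiy et al. · ICLR 2021 · An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (the original ViT paper)
- [02] Touvron et al. · ICML 2021 · Training data-efficient image transformers & distillation through attention (DeiT)
- [03] Raghu, Unterthiner, Kornblith, Zhang, Dosovitskiy · NeurIPS 2021 · Do Vision Transformers See Like Convolutional Neural Networks? (representation probing of ViT)
- [04] Zhang, Lipton, Li, Smola · Dive into Deep Learning (d2l.ai)
- [05] Wightman · PyTorch Image Models (timm) · reference implementations and pretrained weights for many ViT variants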