Part III: Deep Learning for Computer Vision
Chapter 22: Vision Transformers

"For thirty years they told me a pixel only cared about its neighbors. Then someone cut me into sixteen-by-sixteen squares, lined them up like words, and asked every square what it thought of every other square. Turns out the corner had opinions about the sky all along."

A Mildly Overfit Vision Transformer

Big Picture

A Vision Transformer treats an image not as a grid to be convolved but as a short sequence of patches, and it lets self-attention learn which patches should talk to which, replacing the hard-wired locality of the convolution with a global mixing operation that is learned from data. This trade buys you a representation that can relate the top-left corner to the bottom-right one in a single layer, and that scales remarkably well when you have enough data. It costs you the convolution's built-in assumptions about locality and translation, the very inductive biases that let CNNs learn from modest datasets. The whole chapter is about when that trade is worth it. We build the attention block in the vision setting, turn an image into a sequence with the patch-embedding recipe, learn the augmentation and distillation tricks that made ViTs trainable without a private hundred-million-image dataset, study the hierarchical designs that brought back some of the convolution's efficiency, and finish with an honest, evidence-based comparison of CNNs and transformers and the hybrids that now win most benchmarks.

Remember the Chapter in Three Words: Patch, Position, Attend

If you keep one thing from this chapter, keep the recipe that turns a convolutional mind into a transformer one. Patch: cut the image into fixed squares and flatten each into a token (Section 22.2). Position: add a learned position vector so the model knows where each patch came from, because attention by itself treats the patches as an unordered bag. Attend: let every token weigh every other token, with weights computed from the content itself (Section 22.1). Patch, position, attend is the entire front end of a Vision Transformer, and you will meet the same three-beat pattern, lightly disguised, in detection, segmentation, video, and text-to-image generation for the rest of the book.

Chapter Overview

For four chapters you have built vision around the convolution. Chapter 19 made the kernel learnable; Chapter 20 stacked kernels into ResNet, Inception, and ConvNeXt; Chapter 21 taught you to train any of them well. Every one of those designs bakes in two assumptions that turn out to be choices, not laws of nature: that a feature depends mostly on its local neighborhood, and that a feature detector useful in one place is useful everywhere. Those assumptions are exactly why CNNs learn so efficiently. They are also a ceiling. A convolution cannot, in one layer, relate a hand to the racket it is holding on the far side of the frame; it must build that relationship slowly, layer by layer, as the receptive field grows. The Vision Transformer asks a different question: what if we let the network decide, per image and per layer, which regions should influence which, with no locality assumption at all?

The answer arrived from an unlikely direction. The transformer was built for language, where a sentence is already a sequence of tokens and attention relates every word to every other word. The 2020 ViT paper made one audacious move: chop the image into fixed-size patches, flatten each patch into a vector, and feed the resulting sequence of "visual words" into a standard transformer encoder, almost unchanged from the language version. With enough data the result matched or beat the best CNNs. Section 22.1 builds the engine: scaled dot-product attention, multi-head attention, and the full transformer block, in the vision setting and from scratch, so the rest of the chapter rests on mechanics you have implemented rather than imported.

Section 22.2 is the heart of the chapter: the ViT architecture itself. You will see exactly how an image becomes a sequence (the patch embedding, which we will recognize as a strided convolution in disguise), why a learnable class token and positional embeddings are needed, and how the encoder stack and classification head finish the job. Section 22.3 confronts the original ViT's dirty secret: it needed roughly 300 million images to shine, far more than most teams will ever have. DeiT showed how strong augmentation and a clever distillation token let a ViT train to competitive accuracy on plain ImageNet, on a single machine, in a few days. This is the section that turned ViTs from a curiosity into a tool.

Section 22.4 brings back structure. The plain ViT keeps the same resolution and the same global attention at every layer, which is both wasteful and ill-suited to dense tasks like detection and segmentation. Swin Transformer and the pyramid family reintroduce the multi-scale hierarchy of the image pyramids from Chapter 4 and the feature hierarchies of Chapter 20, computing attention in local windows and shifting those windows between layers to recover global reach at near-linear cost. Section 22.5 closes the chapter with the comparison the whole part has been building toward: inductive bias versus scale, where the crossover sits, and why the architectures that dominate in 2024 to 2026 are neither pure CNN nor pure transformer but deliberate hybrids.
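
To make the near-linear claim concrete, here is a back-of-the-envelope count of query-key scores in a single attention layer. The sizes are assumptions chosen for illustration (a 224x224 input, 4x4 patches, and 7x7 windows, which match Swin's published defaults), the helper names are hypothetical, and the count ignores heads, channels, and the projection layers.

```python
# Count query-key score entries for global vs. windowed attention.
# Assumed sizes for illustration: 224x224 input, 4x4 patches, 7x7 windows.

def global_pairs(tokens: int) -> int:
    """Every token scores every token: N * N entries."""
    return tokens * tokens

def window_pairs(tokens: int, window: int) -> int:
    """Each token scores only the window * window tokens in its own window."""
    return tokens * window * window

side = 224 // 4                      # 56 tokens per side
tokens = side * side                 # 3136 tokens in total
g = global_pairs(tokens)
w = window_pairs(tokens, window=7)
print(tokens, g, w, g // w)          # 3136 9834496 153664 64
```

Doubling the image side quadruples the token count, so the global count grows 16-fold while the windowed count grows only 4-fold. That is the sense in which window attention is linear in image size.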

The unifying thread is attention itself, and it does not stop here. The self-attention you build in Section 22.1 becomes the mask-transformer decoder of Chapter 24, the temporal attention of video models in Chapter 26, and, most consequentially, the cross-attention that lets a text prompt condition an image generator in Chapter 33 and Chapter 34. Learn the patch-and-attend pattern well here, because you will meet it, lightly disguised, for the rest of the book.

Prerequisites

You should have read Chapter 18: Neural Networks & PyTorch for Vision for the tensor mechanics, layer normalization, residual connections, and training loop this chapter assembles into a transformer. Chapter 19: Convolutional Neural Networks is essential, because the central argument of the chapter is a contrast with the convolution's inductive biases, and because the patch embedding turns out to be a convolution. Chapter 20: CNN Architectures supplies the ResNet and ConvNeXt baselines that ViTs are measured against, and Chapter 21: Training Recipes supplies the augmentation, schedule, and regularization that ViTs depend on even more heavily than CNNs do. Comfort with matrix multiplication, the softmax function, and the dot product as a similarity measure (linear algebra you have used since Chapter 3) makes the attention derivation in Section 22.1 concrete rather than abstract.

The One-Line Trade to Carry Through the Chapter: Bias Buys Data, Scale Buys Freedom

Every design choice in this chapter is a move along a single axis. The convolution's built-in priors (locality, translation equivariance) are paid-up data efficiency: bias buys data, so a CNN learns from modest datasets. A plain ViT throws those priors away and must repay the bill in examples or augmentation, but in exchange its flexibility keeps improving at a scale where the CNN's biases become a ceiling: scale buys freedom. Read the whole chapter as people negotiating this trade: DeiT pays the bill with augmentation and a CNN teacher (Section 22.3), Swin buys the priors back as window attention and a pyramid (Section 22.4), and the winning hybrids (Section 22.5) simply refuse to choose. When you size your own data and resolution against Table 22.5.1, you are pricing exactly this trade.

What's Next?

With attention and the patch-and-attend recipe in hand, you are ready to see the transformer leave classification behind and take over the rest of computer vision. Chapter 23: Object Detection is the immediate sequel: the DETR family reframes detection as set prediction with a transformer decoder, dropping the hand-designed anchors and non-maximum suppression of classical detectors, and the hierarchical backbones of Section 22.4 are exactly what modern detectors sit on top of. From there the attention thread runs through Chapter 24, where mask transformers and the Segment Anything Model turn pixels into promptable masks, through video transformers in Chapter 26, and into the generative half of the book, where the cross-attention you build here is the mechanism that lets a sentence steer a diffusion model. The convolution gave you the first decade of deep vision; attention is writing the second. Before moving on, assemble the chapter's three-beat recipe into one runnable program in the Hands-On Lab below, where patch, position, and attend become a small Vision Transformer you train and evaluate on CIFAR-10.

Hands-On Lab: Build and Train a Vision Transformer from Scratch

Duration: about 60 to 90 minutes
Difficulty: Intermediate

Objective

Assemble the chapter's three-beat recipe (patch, position, attend) into one runnable program: a compact Vision Transformer that turns a 32x32 CIFAR-10 image into a sequence of patch tokens, prepends a learnable class token, adds positional embeddings, mixes the tokens through a stack of pre-norm transformer blocks you wrote by hand, and classifies from the class token. You then train it for a few epochs on CIFAR-10 and read off both a test accuracy and a class-token attention map, so the patch-position-attend pipeline produces a measurable number and a picture of where the model looks.

What You'll Practice

  • Turning an image into a sequence with the patch embedding, the strided convolution in disguise from Section 22.2.
  • Prepending a learnable class token and adding positional embeddings so attention knows order and has a vector to classify from (Section 22.2).
  • Implementing scaled dot-product multi-head attention and a pre-norm transformer block with residuals, from scratch, the engine of Section 22.1.
  • Training a ViT on a small dataset with the augmentation-heavy, data-efficient mindset of Section 22.3, and reading a real test accuracy.
  • Extracting and visualizing the class-token attention to see which patches the prediction rests on, the interpretability payoff of Section 22.1.

Setup

One deep-learning stack and no manual download; torchvision fetches CIFAR-10 on first run. A GPU finishes the suggested five epochs in a few minutes, but the code runs on CPU too (slower). Install with:

pip install torch torchvision matplotlib

Everything below lives in one script. The model is deliberately small (patch size 4, embedding dimension 192, depth 6) so it trains to a meaningful accuracy on CIFAR-10 in minutes rather than the days a full ViT-Base needs.
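
As a rough sense of how small this model is, the parameter count can be worked out by hand from the layer shapes used in the steps below. This is a sketch, not part of the lab script: the function name is hypothetical, and it assumes the lab's defaults (MLP ratio 4, biases on every linear layer, two LayerNorms per block plus a final one).

```python
def vit_param_count(img=32, patch=4, in_ch=3, dim=192, depth=6,
                    mlp_ratio=4, classes=10):
    """Count learnable parameters of the lab's ViT from its layer shapes."""
    n_patches = (img // patch) ** 2
    patch_embed = in_ch * patch * patch * dim + dim        # Conv2d weight + bias
    tokens = dim + (n_patches + 1) * dim                   # class token + positions
    attn = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)   # qkv and output projections
    hidden = dim * mlp_ratio
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)   # two linear layers
    norms = 2 * (2 * dim)                                  # two LayerNorms per block
    block = attn + mlp + norms
    final = 2 * dim + (dim * classes + classes)            # final norm + linear head
    return patch_embed + tokens + depth * block + final

print(vit_param_count())                                   # 2693578
```

That is about 2.7 million parameters, against roughly 86 million for ViT-Base. The number of heads does not appear in the count because the heads only split the existing channels.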

Steps

Step 1: Embed the image as a sequence of patch tokens

Cut each image into non-overlapping patches and flatten each into a vector, exactly the patch embedding of Section 22.2. The cleanest implementation is a single strided convolution whose kernel and stride both equal the patch size: it slices and projects in one operation.

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch=4, in_ch=3, dim=192):
        super().__init__()
        self.n_patches = (img_size // patch) ** 2      # 8x8 = 64 tokens
        # TODO: create a Conv2d with in_ch inputs, dim outputs, and both
        # kernel_size and stride equal to `patch`. This is the "strided
        # convolution in disguise" that splits and projects each patch at once.
        self.proj = ...

    def forward(self, x):                              # x: (B, 3, 32, 32)
        x = self.proj(x)                               # (B, dim, 8, 8)
        # TODO: flatten the 8x8 grid into 64 tokens and put dim last,
        # producing shape (B, 64, dim). Hint: .flatten(2).transpose(1, 2)
        return x

Hint

self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch) and in forward, return x.flatten(2).transpose(1, 2). Setting the stride equal to the kernel size guarantees the patches never overlap, which is what makes each output cell a distinct patch token.
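
To see the "convolution in disguise" claim for yourself, the following sketch computes the patch tokens two ways and compares them: once with the strided convolution, and once by cutting the patches out explicitly and applying the same weights as a plain linear map. The variable names and the random input are illustrative only, and the comparison uses F.unfold, which is not otherwise needed in the lab.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, dim = 4, 192
conv = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
x = torch.randn(2, 3, 32, 32)                              # stand-in batch of images

with torch.no_grad():
    # Route 1: strided convolution, then flatten the 8x8 grid into 64 tokens.
    tokens_conv = conv(x).flatten(2).transpose(1, 2)        # (2, 64, 192)

    # Route 2: cut out each 4x4 patch, flatten it to 48 numbers, and apply
    # the convolution's own weights as an ordinary matrix multiply.
    patches = F.unfold(x, kernel_size=patch, stride=patch)  # (2, 48, 64)
    patches = patches.transpose(1, 2)                       # (2, 64, 48)
    weight = conv.weight.reshape(dim, -1)                   # (192, 48)
    tokens_linear = patches @ weight.T + conv.bias          # (2, 64, 192)

max_diff = (tokens_conv - tokens_linear).abs().max().item()
print(tokens_conv.shape, max_diff)                          # difference is float noise
```

The two routes agree up to floating-point rounding, so the convolution form is a convenience and a speed choice rather than a different model.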

Step 2: Prepend the class token and add positional embeddings

Attention treats its inputs as an unordered set, so two pieces are needed before mixing: a learnable class token that travels with the sequence and becomes the vector you classify from, and a positional embedding added to every token so the model knows where each patch came from. Both are the parameters introduced in Section 22.2.

class TokenPrep(nn.Module):
    def __init__(self, n_patches, dim):
        super().__init__()
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))           # one class token
        # TODO: create self.pos as an nn.Parameter of zeros with shape
        # (1, n_patches + 1, dim): one position vector per patch plus the
        # class token. Initialize both with nn.init.trunc_normal_(..., std=0.02).
        self.pos = ...
        nn.init.trunc_normal_(self.cls, std=0.02)
        nn.init.trunc_normal_(self.pos, std=0.02)

    def forward(self, x):                              # x: (B, 64, dim)
        b = x.shape[0]
        cls = self.cls.expand(b, -1, -1)              # (B, 1, dim)
        x = torch.cat([cls, x], dim=1)                # (B, 65, dim)
        # TODO: add the positional embedding self.pos to x and return it.
        return x

Hint

self.pos = nn.Parameter(torch.zeros(1, n_patches + 1, dim)) and return x + self.pos. The + 1 is the slot for the class token; forget it and the addition will raise a shape error, which is the fastest way to remember why the count is 65 not 64.
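
The claim that attention sees an unordered set can be checked directly. The sketch below uses a deliberately tiny single-head attention (the sizes, the helper name, and the random stand-in for a learned positional embedding are all assumptions for illustration) and shuffles the tokens with and without positions added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n = 16, 5
qkv = nn.Linear(dim, dim * 3)

def self_attention(x):
    """Single-head scaled dot-product self-attention with no position information."""
    q, k, v = qkv(x).chunk(3, dim=-1)
    attn = (q @ k.transpose(-2, -1) * dim ** -0.5).softmax(dim=-1)
    return attn @ v

x = torch.randn(1, n, dim)
perm = torch.tensor([3, 0, 4, 1, 2])          # an arbitrary reordering of the tokens
pos = torch.randn(1, n, dim)                  # stand-in for a learned positional embedding

with torch.no_grad():
    # Without positions, shuffling the input only shuffles the output the same way.
    same = torch.allclose(self_attention(x)[:, perm],
                          self_attention(x[:, perm]), atol=1e-5)
    # With positions, the slot a token lands in changes what it becomes.
    differs = not torch.allclose(self_attention(x + pos)[:, perm],
                                 self_attention(x[:, perm] + pos), atol=1e-5)
print(same, differs)
```

Both flags should come out True: without the positional term, a scrambled image and the original produce the same set of token outputs, so the model has no way to tell them apart.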

Step 3: Implement multi-head scaled dot-product attention

This is the engine of Section 22.1: project every token to queries, keys, and values, score each query against every key with a scaled dot product, softmax the scores into weights, and take the weighted sum of values. Splitting the channels into heads lets the block attend to several relationships at once.

import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads=3):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5           # 1/sqrt(d_head)
        self.qkv = nn.Linear(dim, dim * 3)            # produce Q, K, V at once
        self.out = nn.Linear(dim, dim)

    def forward(self, x):                              # x: (B, N, dim)
        b, n, d = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.heads, d // self.heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)              # (3, B, heads, N, d_head)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # TODO: compute attention weights as softmax over the last axis of
        # (q @ k.transpose(-2, -1)) * self.scale, then multiply by v.
        attn = ...
        out = ...
        out = out.transpose(1, 2).reshape(b, n, d)    # merge heads back
        return self.out(out), attn                    # also return weights for Step 6

Hint

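attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1) then out = attn @ v. The scale factor 1/sqrt(d_head) keeps the dot products from growing with dimension and saturating the softmax, the reason it appears in Exercise 22.1.1. In production you would replace these two lines with F.scaled_dot_product_attention(q, k, v), which applies the same 1/sqrt(d_head) scale by default and dispatches to a fused, FlashAttention-style kernel when one is available. The fused call returns only the output, not the weights, so keep the explicit version in this lab: Step 6 needs attn.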
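
A quick way to convince yourself that the explicit two-line attention and the fused call compute the same thing is to run both on random tensors. This is a sketch on stand-in data with the lab's shapes (3 heads, 65 tokens, 64 channels per head); it assumes PyTorch 2.0 or newer, where F.scaled_dot_product_attention is available.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, heads, n, d_head = 2, 3, 65, 64
q, k, v = (torch.randn(b, heads, n, d_head) for _ in range(3))

# Explicit path: builds the full (n x n) weight matrix, so it can be inspected.
attn = (q @ k.transpose(-2, -1) * d_head ** -0.5).softmax(dim=-1)
manual = attn @ v

# Fused path: same math with the default 1/sqrt(d_head) scale, weights not returned.
fused = F.scaled_dot_product_attention(q, k, v)

max_diff = (manual - fused).abs().max().item()
row_sums = attn.sum(dim=-1)                   # every row of weights sums to 1
print(manual.shape, max_diff)
```

The outputs match up to floating-point rounding. The practical difference is memory and speed, plus the fact that only the explicit path gives you attn to visualize.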

Step 4: Stack pre-norm transformer blocks into the encoder

Wrap attention and an MLP in the pre-norm residual block of Section 22.1: LayerNorm, then the sublayer, then add the input back. Stacking these blocks is the entire body of a Vision Transformer; the class token accumulates evidence from the patches as it passes up the stack.

class Block(nn.Module):
    def __init__(self, dim, heads=3, mlp_ratio=4):
        super().__init__()
        self.n1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, heads)
        self.n2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        a, attn = self.attn(self.n1(x))
        # TODO: apply the two pre-norm residual connections:
        # x = x + a   (attention sublayer)
        # x = x + self.mlp(self.n2(x))   (MLP sublayer)
        return x, attn

Hint

x = x + a then x = x + self.mlp(self.n2(x)), and return x, attn. Pre-norm (normalize before the sublayer, add the raw input) leaves an unnormalized residual path running through the whole stack, which is what lets deep transformers train stably with little or no warmup; post-norm stacks are notoriously twitchy at depth.

Step 5: Assemble the full ViT and train it on CIFAR-10

Chain patch embedding, token prep, the block stack, a final norm, and a linear head over the class token, the complete topology of Section 22.2. Then train a few epochs on CIFAR-10 with the random-crop and flip augmentation that Section 22.3 shows ViTs depend on when data is scarce.

class ViT(nn.Module):
    def __init__(self, img=32, patch=4, dim=192, depth=6, heads=3, classes=10):
        super().__init__()
        self.embed = PatchEmbed(img, patch, 3, dim)
        self.prep = TokenPrep(self.embed.n_patches, dim)
        self.blocks = nn.ModuleList([Block(dim, heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, classes)

    def forward(self, x):
        x = self.prep(self.embed(x))
        last_attn = None
        for blk in self.blocks:
            x, last_attn = blk(x)
        x = self.norm(x)
        # TODO: classify from the class token only (index 0 along the token
        # axis): return self.head(x[:, 0]). Also return last_attn for Step 6.
        return ..., last_attn

import torchvision as tv
from torchvision import transforms as T

dev = "cuda" if torch.cuda.is_available() else "cpu"
train_tf = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(),
                      T.ToTensor()])
test_tf = T.ToTensor()
tr = tv.datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
te = tv.datasets.CIFAR10("./data", train=False, download=True, transform=test_tf)
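# Worker processes (num_workers > 0) need this code to sit under an
# `if __name__ == "__main__":` guard on Windows and macOS, as in the Complete Solution.
trl = torch.utils.data.DataLoader(tr, batch_size=128, shuffle=True, num_workers=2)
tel = torch.utils.data.DataLoader(te, batch_size=256, num_workers=2)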

model = ViT().to(dev)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

for epoch in range(5):
    model.train()
    for xb, yb in trl:
        xb, yb = xb.to(dev), yb.to(dev)
        opt.zero_grad()
        logits, _ = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        opt.step()
    print(f"epoch {epoch} done, last batch loss {loss.item():.3f}")

Hint

The classifier line is return self.head(x[:, 0]), last_attn. The class token sits at index 0 because Step 2 concatenated it before the patches; every patch token influenced it through attention, so it carries a global summary the linear head can read. If memory is tight, drop batch_size to 64.

Step 6: Evaluate and visualize the class-token attention

Measure test accuracy, then average the last block's attention from the class token over its heads and reshape it back to the 8x8 patch grid. The result is a heat map of which patches the prediction leaned on, the interpretability view that Exercise 22.1.2 and Section 22.1 promised.

model.eval()
correct = total = 0
with torch.no_grad():
    for xb, yb in tel:
        xb, yb = xb.to(dev), yb.to(dev)
        logits, _ = model(xb)
        correct += (logits.argmax(1) == yb).sum().item()
        total += yb.size(0)
print(f"test accuracy: {100 * correct / total:.1f}%")

# Attention map for one test image, averaged over heads.
import matplotlib.pyplot as plt
img, _ = te[0]
with torch.no_grad():
    _, attn = model(img.unsqueeze(0).to(dev))         # attn: (1, heads, 65, 65)
# TODO: take the class token's attention to the 64 patch tokens
# (attn[0, :, 0, 1:]), average over the heads (dim=0), reshape to (8, 8),
# and show it with plt.imshow over the input image.
cls_attn = ...
plt.imshow(img.permute(1, 2, 0)); plt.imshow(cls_attn, alpha=0.6, cmap="inferno")
plt.title("class-token attention"); plt.axis("off"); plt.savefig("vit_attn.png")

Hint

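cls_attn = attn[0, :, 0, 1:].mean(0).reshape(8, 8).cpu(). The 1: skips the class token's attention to itself; mean(0) collapses the heads. Then upsample the 8x8 map to 32x32 with cls_attn = cls_attn.repeat_interleave(4, 0).repeat_interleave(4, 1), as the Complete Solution does. plt.imshow draws each array in its own pixel coordinates, so an 8x8 overlay would otherwise sit over only the top-left 8x8 pixels of the 32x32 image instead of lining up with it.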
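
The slicing and upsampling can be checked on stand-in data before you have a trained model. In this sketch the attention tensor is random (rows pushed through a softmax so they sum to 1, like real weights); only the shapes and the indexing mirror the lab.

```python
import torch

torch.manual_seed(0)
heads, n_tokens, grid, patch = 3, 65, 8, 4
# Stand-in for the last block's weights: (batch, heads, queries, keys).
attn = torch.randn(1, heads, n_tokens, n_tokens).softmax(dim=-1)

# Row 0 is the class token as query; columns 1: are the 64 patch tokens as keys.
cls_attn = attn[0, :, 0, 1:].mean(0).reshape(grid, grid)              # (8, 8)
upsampled = cls_attn.repeat_interleave(patch, 0).repeat_interleave(patch, 1)

# Each 4x4 block of the upsampled map repeats one patch's weight.
block_is_constant = torch.equal(upsampled[:patch, :patch],
                                cls_attn[0, 0].expand(patch, patch))
print(cls_attn.shape, upsampled.shape, block_is_constant)
```

The map sums to slightly less than 1 because the class token's attention to itself was dropped, which is harmless for a heat map since imshow rescales the colors anyway.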

Expected Output

Two artifacts. First, a printed test accuracy: this small from-scratch ViT, trained for only five epochs with basic augmentation, lands in roughly the 65 to 75 percent range on CIFAR-10, well above the 10 percent chance baseline and a concrete demonstration that the patch-position-attend pipeline learns. Train longer (30 to 50 epochs) and it climbs past 80 percent. Second, vit_attn.png, a heat map in which the class token's attention is concentrated on a subset of patches, often the object rather than the background, a visual check that attention has started to look at what matters. The exact number and the exact map will vary with seed and hardware; what should hold is accuracy far above chance and an attention map that is not uniform.

Stretch Goals

  • Vary the patch size to 2 and to 8 and retrain; smaller patches mean more tokens, more compute, and usually higher accuracy, the trade Exercise 22.2.2 asks you to feel firsthand.
  • Add a distillation token alongside the class token and train it against a small pretrained CNN's logits, a miniature of the DeiT recipe from Section 22.3; compare accuracy with and without it.
  • Replace the two attention lines in Step 3 with F.scaled_dot_product_attention(q, k, v) and time an epoch both ways to feel the fused-kernel speedup (the fused call does not return the weights, so return None for attn while timing), then connect it to the hierarchical, near-linear designs of Section 22.4.
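
For the distillation stretch goal, one possible starting point is the hard-label variant of the DeiT loss: the class head is trained against the true labels and the distillation head against the teacher's argmax prediction, with the two cross-entropies averaged. The sketch below assumes your model exposes two sets of logits (one per token); the function name and the random stand-in tensors are hypothetical, and a real run would take teacher_logits from a frozen pretrained CNN.

```python
import torch
import torch.nn.functional as F

def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """Average of two cross-entropies: class head vs. true labels,
    distillation head vs. the teacher's hard (argmax) predictions."""
    teacher_labels = teacher_logits.argmax(dim=1)
    loss_cls = F.cross_entropy(cls_logits, labels)
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)
    return 0.5 * loss_cls + 0.5 * loss_dist

torch.manual_seed(0)
batch, classes = 8, 10
labels = torch.randint(0, classes, (batch,))
cls_logits = torch.randn(batch, classes, requires_grad=True)   # from the class token
dist_logits = torch.randn(batch, classes, requires_grad=True)  # from the distillation token
teacher_logits = torch.randn(batch, classes)                   # stand-in for a frozen CNN

loss = hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels)
loss.backward()                                                # gradients reach both heads
print(loss.item())
```

At test time a common choice is to average the two heads' softmax outputs, but that, like the equal 0.5 weighting, is a design decision worth treating as a knob rather than a rule.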

Library Shortcut: timm Gives You a Trained-Recipe ViT in Three Lines

The model above is roughly ninety lines and exposes every token and every attention weight on purpose. The timm library (the reference source from the bibliography) ships the same architecture, plus the full DeiT training recipe and pretrained weights, behind one call: import timm then model = timm.create_model("vit_tiny_patch16_224", pretrained=True, num_classes=10) and fine-tune, resizing the CIFAR-10 images to the 224x224 input this checkpoint expects. That is a 90-to-3 reduction, and the pretrained weights skip the data-hunger problem of Section 22.3 entirely. Build the ViT once by hand to understand what those three lines hide; reach for timm every time after.

Complete Solution

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torchvision import transforms as T
import matplotlib.pyplot as plt

class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch=4, in_ch=3, dim=192):
        super().__init__()
        self.n_patches = (img_size // patch) ** 2
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)           # (B, 64, dim)

class TokenPrep(nn.Module):
    def __init__(self, n_patches, dim):
        super().__init__()
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, n_patches + 1, dim))
        nn.init.trunc_normal_(self.cls, std=0.02)
        nn.init.trunc_normal_(self.pos, std=0.02)

    def forward(self, x):
        b = x.shape[0]
        cls = self.cls.expand(b, -1, -1)
        x = torch.cat([cls, x], dim=1)
        return x + self.pos

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads=3):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, d = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.heads, d // self.heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).reshape(b, n, d)
        return self.out(out), attn

class Block(nn.Module):
    def __init__(self, dim, heads=3, mlp_ratio=4):
        super().__init__()
        self.n1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, heads)
        self.n2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        a, attn = self.attn(self.n1(x))
        x = x + a
        x = x + self.mlp(self.n2(x))
        return x, attn

class ViT(nn.Module):
    def __init__(self, img=32, patch=4, dim=192, depth=6, heads=3, classes=10):
        super().__init__()
        self.embed = PatchEmbed(img, patch, 3, dim)
        self.prep = TokenPrep(self.embed.n_patches, dim)
        self.blocks = nn.ModuleList([Block(dim, heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, classes)

    def forward(self, x):
        x = self.prep(self.embed(x))
        last_attn = None
        for blk in self.blocks:
            x, last_attn = blk(x)
        x = self.norm(x)
        return self.head(x[:, 0]), last_attn

if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    train_tf = T.Compose([T.RandomCrop(32, padding=4),
                          T.RandomHorizontalFlip(), T.ToTensor()])
    test_tf = T.ToTensor()
    tr = tv.datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
    te = tv.datasets.CIFAR10("./data", train=False, download=True, transform=test_tf)
    trl = torch.utils.data.DataLoader(tr, batch_size=128, shuffle=True, num_workers=2)
    tel = torch.utils.data.DataLoader(te, batch_size=256, num_workers=2)

    model = ViT().to(dev)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    for epoch in range(5):
        model.train()
        for xb, yb in trl:
            xb, yb = xb.to(dev), yb.to(dev)
            opt.zero_grad()
            logits, _ = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            opt.step()
        print(f"epoch {epoch} done, last batch loss {loss.item():.3f}")

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in tel:
            xb, yb = xb.to(dev), yb.to(dev)
            logits, _ = model(xb)
            correct += (logits.argmax(1) == yb).sum().item()
            total += yb.size(0)
    print(f"test accuracy: {100 * correct / total:.1f}%")

    img, _ = te[0]
    with torch.no_grad():
        _, attn = model(img.unsqueeze(0).to(dev))
    cls_attn = attn[0, :, 0, 1:].mean(0).reshape(8, 8).cpu()
    cls_attn = cls_attn.repeat_interleave(4, 0).repeat_interleave(4, 1)
    plt.imshow(img.permute(1, 2, 0))
    plt.imshow(cls_attn, alpha=0.6, cmap="inferno")
    plt.title("class-token attention")
    plt.axis("off")
    plt.savefig("vit_attn.png")
    print("wrote vit_attn.png")

Bibliography & Further Reading

Foundational Papers

Vaswani, A. et al. "Attention Is All You Need." NeurIPS (2017). arXiv:1706.03762
The transformer paper. It introduced scaled dot-product attention, multi-head attention, and the encoder-decoder block of Section 22.1. Written for language, but the block it defines is the engine inside every architecture in this chapter.
Dosovitskiy, A. et al. "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR (2021). arXiv:2010.11929
The Vision Transformer (ViT) of Section 22.2. Splits an image into 16x16 patches, embeds each as a token, and feeds a near-standard transformer encoder. The paper that started this chapter, including the key finding that ViTs need large pretraining data to beat CNNs.
Touvron, H. et al. "Training data-efficient image transformers & distillation through attention." ICML (2021). arXiv:2012.12877
DeiT of Section 22.3. Strong augmentation plus a distillation token learning from a CNN teacher let a ViT reach competitive ImageNet accuracy with no external data, on a single 8-GPU node, in a few days.
Liu, Z. et al. "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows." ICCV (2021). arXiv:2103.14030
Swin of Section 22.4. Window-local attention with shifts between layers gives a multi-scale feature pyramid at linear cost in image size, making transformers practical backbones for detection and segmentation.
Wang, W. et al. "Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions." ICCV (2021). arXiv:2102.12122
PVT of Section 22.4. An early pyramid transformer that shrinks the token sequence stage by stage and uses spatial-reduction attention, an alternative route to the multi-scale features that dense prediction needs.

Recent Research (2022-2026)

Liu, Z. et al. "A ConvNet for the 2020s." CVPR (2022). arXiv:2201.03545
ConvNeXt, the counterargument of Section 22.5. By porting transformer-era design and training choices back into a pure CNN, it matches Swin, showing much of the ViT advantage was recipe rather than attention.
Dehghani, M. et al. "Scaling Vision Transformers to 22 Billion Parameters." ICML (2023). arXiv:2302.05442
ViT-22B, the scale end of Section 22.5. Demonstrates that the transformer's clean scaling story carries into the tens of billions of parameters, the regime where attention's lack of inductive bias becomes an asset rather than a liability.
Oquab, M. et al. "DINOv2: Learning Robust Visual Features without Supervision." TMLR (2024). arXiv:2304.07193
A frontier ViT backbone trained self-supervised, previewing Chapter 25. Its features transfer to detection, segmentation, and depth without fine-tuning, the modern payoff of the patch-and-attend representation.
Darcet, T. et al. "Vision Transformers Need Registers." ICLR (2024). arXiv:2309.16588
A 2024 fix discussed in Section 22.5: adding a few extra learnable "register" tokens removes high-norm artifact tokens in ViT attention maps, cleaning up the features and the interpretability of where the model looks.

Tools & Libraries

Wightman, R. timm (PyTorch Image Models). github.com/huggingface/pytorch-image-models
The reference implementation source for ViT, DeiT, Swin, and dozens of hybrids, with pretrained weights and the exact training recipes of Sections 22.3 and 22.4. The library shortcut behind most of this chapter's code.
Hugging Face Transformers, Vision models. huggingface.co/docs/transformers
High-level AutoModel and AutoImageProcessor APIs that load ViT, DeiT, and Swin with their preprocessing in a few lines, used in the library shortcuts of Sections 22.2 and 22.3.
PyTorch nn.MultiheadAttention and torch.nn.functional.scaled_dot_product_attention. pytorch.org/docs
The built-in, fused (FlashAttention-backed) attention primitive that replaces the from-scratch attention of Section 22.1 with a single, memory-efficient call in production code.

Tutorials & Explainers

Phil Wang (lucidrains). vit-pytorch. github.com/lucidrains/vit-pytorch
A clean, heavily-starred from-scratch implementation of ViT and many variants, the ideal companion for reading the Section 22.2 architecture line by line.
Alammar, J. "The Illustrated Transformer." (2018). jalammar.github.io/illustrated-transformer
The most accessible visual walkthrough of attention, query-key-value scoring, and multi-head mechanics that Section 22.1 formalizes from scratch. It builds the same intuition pictorially before you meet the equations. Ideal for beginners who want a mental model of attention before reading the math.

Datasets & Benchmarks

Sun, C. et al. "Revisiting Unreasonable Effectiveness of Data in Deep Learning Era (JFT-300M)." ICCV (2017). arXiv:1707.02968
The 300-million-image internal dataset that the original ViT pretrained on, the concrete reason for the data-hunger of Section 22.3 and the scale crossover of Section 22.5.
Deng, J. et al. "ImageNet: A Large-Scale Hierarchical Image Database." CVPR (2009). image-net.org
The 1.28-million-image benchmark on which DeiT (Section 22.3) proved data-efficient training, and the common yardstick for every CNN-versus-ViT comparison in Section 22.5.