Part III: Deep Learning for Computer Vision
Chapter 21: Training Recipes: Data, Augmentation & Transfer

Debugging Training: Curves, Overfitting & Sanity Checks

"My loss would not budge for six hours. They blamed the learning rate, the optimizer, my very architecture. The labels were shifted by one. I had been perfectly, diligently learning to predict the wrong answer the entire time."

A Network That Was Never Actually Broken
Big Picture

Most failed training runs are not failures of the architecture or even the recipe; they are bugs in the pipeline, and a handful of cheap sanity checks catch the great majority of them in minutes rather than hours. This closing section is a debugging manual. It teaches you to read loss and accuracy curves and diagnose the three canonical states (overfitting, underfitting, and the broken pipeline that is neither), to run the high-value sanity checks that surface bugs early (overfit a single batch, verify the loss at initialization, look at a real loaded batch), and to follow a systematic playbook when a model simply will not learn. The skill being built is not memorizing fixes but forming a habit: when training misbehaves, you have an ordered list of cheap experiments, not a panic.

You now have the full recipe from the previous five sections. This section is what to do when, despite all of it, the run misbehaves, which it will. The central reframing is diagnostic: training problems are not mysterious, they leave signatures in the curves and respond predictably to cheap probes. The most important habit is to assume a bug before assuming the model is fundamentally incapable, because the prior probability of a pipeline bug (a wrong label mapping, an un-shuffled loader, a forgotten normalization from Section 21.1, a frozen layer that should not be) vastly exceeds the probability that a well-known architecture simply cannot learn your task.

1. Reading the Curves: Three Canonical Shapes Beginner

The plot of training loss and validation loss against epochs is the single most informative diagnostic you have, and it has three readable shapes. Underfitting: both losses are high and roughly equal, and both are still falling or have plateaued at a poor value, meaning the model lacks the capacity, training time, or learning rate to fit even the training data. Overfitting: training loss keeps falling toward zero while validation loss bottoms out and then rises, the classic divergence that means the model is memorizing the training set at the expense of generalization. A healthy run: both losses fall together and the validation loss flattens near, but slightly above, the training loss. Figure 21.6.1 contrasts the three.

[Figure 21.6.1: three panels, Underfitting (both high, both flat), Healthy fit (both fall, small gap), Overfitting (train falls, val rises); curves: training loss, validation loss]
Figure 21.6.1: The three canonical training-curve shapes. Underfitting leaves both losses high and flat; a healthy run sees both fall with a small persistent gap; overfitting shows training loss diving while validation loss turns upward. The shape tells you which lever to reach for before you change anything.

Each shape prescribes a different fix. For underfitting, increase capacity, train longer, raise the learning rate, or reduce regularization and augmentation, because the model needs to fit harder. For overfitting, do the opposite: add the augmentation of Section 21.2, increase weight decay or dropout (the random unit-masking regularizer from Chapter 19), get more data, or simply early-stop at the validation minimum. A healthy run needs no intervention except to let it finish. The mistake to avoid is reaching for a random knob; the curve already told you which direction to move. The same loss-and-accuracy curves return for the detection and segmentation tasks of Chapter 23 and Chapter 24, where the validation metric is mean average precision (mAP) or mean intersection over union (mIoU) rather than top-1 accuracy but the three shapes read identically.
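The reading rules above can be sketched as a small heuristic function. This is an illustrative sketch, not a standard tool. The function name diagnose_curves and every threshold in it (stall_frac, gap_frac, rise_tol, and the task-specific poor_loss) are assumptions you would tune, and the function only labels the final state of two per-epoch loss lists. It also reports the validation-minimum epoch, which is the checkpoint early stopping would keep.

```python
def diagnose_curves(train_loss, val_loss, poor_loss,
                    stall_frac=0.05, gap_frac=0.15, rise_tol=0.05):
    """Label per-epoch loss curves with one of the canonical states.

    Every threshold is an illustrative assumption, not a standard value;
    poor_loss is the loss level you consider a bad fit for your task.
    Returns (label, best_epoch); best_epoch is the validation minimum,
    the checkpoint an early-stopping rule would keep.
    """
    best_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i])
    start, final_train, final_val = train_loss[0], train_loss[-1], val_loss[-1]

    # Broken pipeline: the training loss has barely moved from where it started.
    if final_train >= (1.0 - stall_frac) * start:
        return "stalled: suspect the pipeline", best_epoch
    # Overfitting: validation loss climbed back above its minimum while the
    # training loss kept falling past that epoch.
    if (final_val - val_loss[best_epoch] > rise_tol
            and final_train < train_loss[best_epoch]):
        return "overfitting", best_epoch
    # Underfitting: both losses are still poor and close to each other.
    if final_train > poor_loss and abs(final_val - final_train) < gap_frac * final_train:
        return "underfitting", best_epoch
    return "healthy", best_epoch


print(diagnose_curves([2.0, 1.0, 0.5, 0.2, 0.05, 0.01],
                      [2.0, 1.1, 0.7, 0.5, 0.6, 0.8], poor_loss=1.0))
```

The check order matters. A stalled run is caught first because it is neither overfitting nor ordinary underfitting, and it calls for the sanity checks of subsection 2 rather than a capacity or regularization change.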

2. The Single Most Valuable Sanity Check: Overfit One Batch Beginner

Before launching a full run, do this every time: take a single small batch and train on it repeatedly until the model fits it perfectly, driving the loss to near zero. A model with enough capacity should memorize a handful of examples easily within a few dozen steps. If it cannot, your training is broken in a way that no amount of full-dataset training will fix, and you have found out in thirty seconds instead of six hours. This one check catches an enormous fraction of real bugs: a detached gradient, a learning rate of zero, a label-shape mismatch, a loss that is not connected to the parameters, a frozen layer that should be trainable.

import torch

def overfit_one_batch(model, batch, optimizer, criterion, steps=200):
    """A model that cannot drive ONE batch to ~0 loss has a bug, not a data problem."""
    images, labels = batch
    model.train()
    for step in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            print(f"step {step:3d}  loss {loss.item():.4f}")
    return loss.item()

# Expected on a healthy setup (e.g. ResNet-18, batch of 16):
# step   0  loss 2.3107
# step  50  loss 0.0431
# step 100  loss 0.0009
# step 150  loss 0.0002      <- near zero: the optimization path works
# If the loss is stuck near 2.30 (= ln 10 for 10 classes), the gradient is not flowing.
Code Fragment 1: The overfit-one-batch check. overfit_one_batch loops the same batch through zero_grad, backward, and step for 200 iterations and prints the loss every 50. A correctly-wired model drives that single batch to near-zero loss in well under 200 steps. A loss stuck at $\ln K$ (the value for $K$ equally-likely classes, $2.30$ for ten) signals a disconnected gradient or a dead learning rate.
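Code Fragment 1 needs a real model and batch. A CPU-only sketch shows the contrast the check is designed to expose. It assumes a hypothetical small MLP and one synthetic batch in place of a real network and loader. With a working learning rate, the setup drives the loss toward zero; with a learning rate of 0, the loss stays frozen at its starting value.

```python
import torch
import torch.nn as nn

def loss_after_overfitting(model, x, y, lr, steps=200):
    """Train repeatedly on one fixed batch; return (first_loss, last_loss)."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses[0], losses[-1]

def make_mlp():
    # Hypothetical stand-in network; substitute your real model.
    return nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))

torch.manual_seed(0)
x = torch.randn(16, 20)            # one synthetic batch of 16 examples
y = torch.randint(0, 10, (16,))    # 10 classes, arbitrary labels to memorize

healthy_first, healthy_last = loss_after_overfitting(make_mlp(), x, y, lr=1e-2)
broken_first, broken_last = loss_after_overfitting(make_mlp(), x, y, lr=0.0)
print(f"healthy: {healthy_first:.3f} -> {healthy_last:.4f}")
print(f"lr=0:    {broken_first:.3f} -> {broken_last:.4f}")
```

A learning rate of exactly 0 makes every Adam update zero, so the parameters and therefore the loss never change. That flat line is the same signature a disconnected gradient leaves.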
Fun Fact

Overfitting is usually the villain of this chapter, the thing every regularizer is built to fight. The overfit-one-batch check is the one moment where you actively beg for it: you want the model to shamelessly memorize five images, and you are alarmed if it cannot. It is the only sanity test in deep learning where success means committing the exact sin the rest of the field spends its time preventing, which makes it a strangely satisfying thirty seconds.

Key Insight: Verify the Loss at Initialization

A second thirty-second check pairs with the first. Before any training, the loss on a fresh, randomly-initialized network should be close to the loss of pure chance. For $K$-class cross-entropy that is $\ln K$ (about $2.30$ for ten classes, $6.91$ for a thousand), because near-uniform initial predictions give each class probability about $1/K$, and $-\ln(1/K) = \ln K$. If your starting loss is wildly different, something is already wrong: a label-range bug (labels indexed from 1 instead of 0), a wrong number of output units, or a mis-applied loss reduction (summing instead of averaging multiplies the loss by the batch size). Checking that the initial loss is close to $\ln K$ confirms that the output layer, the label encoding, and the loss are all wired correctly before you waste a single epoch. These two checks together, overfit-one-batch and loss-at-init, are the highest-value habit in this entire chapter.
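The check is a few lines of code. This is a minimal sketch that assumes a helper named check_init_loss and an illustrative 10% tolerance (both hypothetical). A zero-initialized linear head makes every class equally likely, so its loss equals $\ln K$ exactly. Switching the loss to reduction="sum" reproduces the mis-applied reduction and inflates the starting value by the batch size.

```python
import math
import torch
import torch.nn as nn

def check_init_loss(model, images, labels, num_classes, rel_tol=0.1):
    """Compare the untrained loss with ln K; rel_tol is an illustrative tolerance."""
    with torch.no_grad():
        loss = nn.functional.cross_entropy(model(images), labels).item()
    expected = math.log(num_classes)
    return loss, expected, abs(loss - expected) <= rel_tol * expected

torch.manual_seed(0)
K, B = 10, 64
images, labels = torch.randn(B, 32), torch.randint(0, K, (B,))

# A zero-initialized linear head outputs identical logits, so every class gets
# probability exactly 1/K and the mean loss is exactly ln K.
head = nn.Linear(32, K)
nn.init.zeros_(head.weight)
nn.init.zeros_(head.bias)
loss, expected, ok = check_init_loss(head, images, labels, K)
print(f"init loss {loss:.4f}  expected {expected:.4f}  ok={ok}")

# The mis-applied reduction: summing instead of averaging inflates the
# starting loss by a factor of the batch size.
with torch.no_grad():
    summed = nn.functional.cross_entropy(head(images), labels, reduction="sum").item()
print(f"reduction='sum' gives {summed:.1f}, about {B} x ln K")
```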

Common Misconception: "Falling Training Loss Means Training Is Working"

It is tempting to read a smoothly decreasing training loss as proof that everything is correct. It is not. The loss only measures how well the model fits whatever targets the data loader hands it, so a pipeline that feeds wrong labels can still produce a textbook-perfect falling loss. A consistent off-by-one shift in the class indices, for instance, is just a relabeling that the network learns perfectly, and with enough capacity and epochs a network can even memorize labels that are pure noise. A driven-to-zero loss on one batch confirms the optimization path works; it says nothing about whether the labels are right. That is why the overfit-one-batch check and inspect_batch are separate steps: the first proves gradients flow, and only the second proves the model is being asked to learn the truth. Trust the validation metric and a visual batch check, never the training-loss curve alone.
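A tiny synthetic experiment makes the point concrete. The sketch uses four made-up 2-D clusters and a linear classifier, chosen only for speed, and shifts every label to the next class index. The training loss still falls toward zero, yet accuracy measured against the true classes is zero. The curve alone cannot tell you the labels are wrong.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
K = 4
# Four well-separated synthetic clusters, 50 points each (illustrative data).
centers = torch.tensor([[5.0, 0.0], [0.0, 5.0], [-5.0, 0.0], [0.0, -5.0]])
true_y = torch.arange(K).repeat_interleave(50)
x = centers[true_y] + 0.5 * torch.randn(len(true_y), 2)
shifted_y = (true_y + 1) % K    # planted bug: every class index moved by one

model = nn.Linear(2, K)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
losses = []
for step in range(300):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(x), shifted_y)
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    pred = model(x).argmax(dim=1)
acc_vs_shifted = (pred == shifted_y).float().mean().item()
acc_vs_true = (pred == true_y).float().mean().item()
print(f"train loss {losses[0]:.3f} -> {losses[-1]:.4f}")
print(f"accuracy vs loader labels {acc_vs_shifted:.2f}, vs true labels {acc_vs_true:.2f}")
```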

3. Look at the Data the Model Actually Sees Intermediate

A startling fraction of training bugs are invisible in code and obvious in a picture. The discipline is to pull one batch straight from the data loader, after all transforms, and actually look at it: render the images, print the labels, and confirm they match. This single habit surfaces several bugs at once:

- the augmentation gone wrong from Section 21.2 (a crop so aggressive the object is gone);
- the normalization error from Section 21.1 (un-normalized tiles that come out washed out or saturated, or with swapped colors because the channel order is wrong, the same RGB-versus-BGR trap first flagged in Chapter 0);
- the off-by-one label shift from this section's epigraph, where the model was diligently learning to predict each image's neighbor's label.

That last failure is the one the illustration below dramatizes. The code below dumps a labeled grid.

A diligent cartoon robot student proudly matches picture cards to name tags, but every tag is shifted one slot over so it confidently labels the cat with the dog's name, illustrating the off-by-one label bug where a model trains perfectly while learning the wrong answer, the kind of pipeline bug a quick look at the data reveals.
A smoothly falling loss can mean the model is perfectly learning the wrong task; look at one real batch before you blame the optimizer.
import torch, torchvision

def inspect_batch(loader, class_names, mean, std):
    """Render one real, post-transform batch with its labels. Look at it."""
    images, labels = next(iter(loader))
    # Undo normalization so the saved image is human-viewable.
    m = torch.tensor(mean).view(1, 3, 1, 1)
    s = torch.tensor(std).view(1, 3, 1, 1)
    raw = images * s + m
    # Measure the range BEFORE clamping: values far outside 0..1 point to a
    # mean/std mismatch or a missing or doubled scaling step.
    print("pixel range:", raw.min().item(), "to", raw.max().item())   # should be ~0..1
    vis = raw.clamp(0, 1)
    grid = torchvision.utils.make_grid(vis[:16], nrow=4)
    torchvision.utils.save_image(grid, "batch_check.png")
    print("labels:", [class_names[i] for i in labels[:16].tolist()])
    print("batch shape:", tuple(images.shape))

# Open batch_check.png and confirm: do the images look right, and do the
# printed labels match what you SEE in each tile? Most data bugs die here.
Code Fragment 2: Dumping one real post-transform batch as a labeled image grid. inspect_batch pulls a batch with next(iter(loader)), undoes normalization with the stored mean and std so the tiles are human-viewable, saves a 4-by-4 grid, and prints the matching labels and pixel range. The printed labels must match the visible content, which is the check that catches label shifts, broken augmentation, and channel-order mistakes.
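inspect_batch relies on your eyes. A numeric companion catches the label-range bug from the loss-at-init discussion (labels indexed from 1) without opening any image. The helper label_stats is a hypothetical sketch. It scans a few batches and reports the smallest and largest label, whether they fit in $[0, K)$, and the per-class counts. The counts also expose a badly skewed or single-class stream.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def label_stats(loader, num_classes, max_batches=50):
    """Summarize the labels a loader actually yields (hypothetical helper)."""
    chunks = []
    for i, (_, labels) in enumerate(loader):
        if i >= max_batches:
            break
        chunks.append(labels)
    labels = torch.cat(chunks)
    lo, hi = labels.min().item(), labels.max().item()
    in_range = lo >= 0 and hi < num_classes
    # Per-class counts only make sense once every label is a valid index.
    counts = torch.bincount(labels, minlength=num_classes).tolist() if in_range else None
    return lo, hi, in_range, counts

features = torch.zeros(40, 3)
good = torch.arange(40) % 4    # labels 0..3 for a 4-class head
bad = good + 1                 # labels indexed from 1: values 1..4
good_stats = label_stats(DataLoader(TensorDataset(features, good), batch_size=8), 4)
bad_stats = label_stats(DataLoader(TensorDataset(features, bad), batch_size=8), 4)
print("0-indexed:", good_stats)
print("1-indexed:", bad_stats)
```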
Practical Example: Six Hours Versus Six Minutes

Who: a junior engineer training a multi-class classifier on a new internal dataset, 2025. Situation: the loss started at a plausible value but barely decreased over a full overnight run, and validation accuracy hovered at chance. Problem: the team's first instinct was to tune the learning rate and try a different optimizer, burning most of a day on full re-runs. Decision: a senior engineer instead ran two checks. The loss-at-init was correct ($\ln K$), ruling out an output or label-encoding bug. But inspect_batch revealed the problem instantly: the custom dataset class read labels from a CSV that had been sorted by filename while the images were loaded in directory order, so every image carried a randomly-mismatched label. The model was being asked to learn noise. Result: fixing the label join, a two-line change, made the next run converge normally. Lesson: the six-hour learning-rate hunt could not have worked because the data was wrong, and six minutes of sanity checks would have found it before the overnight run ever started. Look at the data first.
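The fix in this example is to join labels to images by key rather than by position. This minimal sketch assumes a CSV with hypothetical filename and label columns, since the team's real file format is not given. pair_by_filename looks up each image by its filename and raises on any image without a label. A positional zip, by contrast, silently mislabels images as soon as the two orders differ.

```python
import csv
import os
import tempfile
from pathlib import Path

def load_label_map(csv_path):
    """Read a CSV with 'filename' and 'label' columns into {filename: label}."""
    with open(csv_path, newline="") as f:
        return {row["filename"]: row["label"] for row in csv.DictReader(f)}

def pair_by_filename(image_paths, label_map):
    """Attach each image's label by KEY (its filename), never by row position."""
    pairs, missing = [], []
    for p in image_paths:
        name = Path(p).name
        if name in label_map:
            pairs.append((p, label_map[name]))
        else:
            missing.append(name)
    if missing:
        # Fail loudly instead of silently shifting every later label.
        raise KeyError(f"{len(missing)} image(s) have no label, e.g. {missing[:3]}")
    return pairs

with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, "labels.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "label"])
        writer.writerows([("a.jpg", "cat"), ("b.jpg", "dog"), ("c.jpg", "bird")])
    label_map = load_label_map(csv_path)

image_paths = ["imgs/c.jpg", "imgs/a.jpg", "imgs/b.jpg"]   # "directory order"
positional = list(zip(image_paths, label_map.values()))    # the buggy join
keyed = pair_by_filename(image_paths, label_map)           # the fixed join
print("positional: ", positional)
print("by filename:", keyed)
```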

4. A Playbook for "It Will Not Learn" Intermediate

When a model refuses to learn, work the checks in order of cost, cheapest first, so you spend probe time in proportion to how likely and how cheap each cause is. The ordered playbook below is the distilled habit. It deliberately puts the data and wiring checks before any hyperparameter tuning, because a bug there makes all downstream tuning meaningless, exactly the trap the practical example fell into.

Table 21.6.1: The ordered debugging playbook for a run that will not learn.
| Step | Check | What it rules out |
|---|---|---|
| 1 | Verify loss at init equals $\ln K$ | Wrong output count, label-range bug, loss misconfiguration |
| 2 | inspect_batch: view images + labels | Label mismatch, broken augmentation, bad normalization |
| 3 | Overfit a single batch to ~0 loss | Disconnected gradient, dead learning rate, frozen layers |
| 4 | Confirm the loader shuffles training data | Ordered batches that stall or bias the optimizer |
| 5 | Sweep learning rate over decades ($10^{-5}$ to $10^{-1}$) | Too-small (no progress) or too-large (divergence) LR |
| 6 | Reduce regularization and augmentation, retry | Over-regularization masquerading as underfitting |
| 7 | Only now: change architecture or capacity | Genuine model-capacity limits (rarely the real cause) |
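Step 4 can be checked in a few lines without training anything. This heuristic sketch uses hypothetical helper names and returns three signals. First, whether the loader uses PyTorch's RandomSampler, which is what shuffle=True installs. Second, whether two fresh passes start with different batches. Third, how many classes the first batch contains. A loader with a custom sampler, such as a weighted or distributed one, can be shuffled yet fail the first test, so read the three signals together.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

def first_batch_labels(loader):
    return next(iter(loader))[1]

def shuffle_signals(loader):
    """Cheap heuristics for playbook step 4 (not proof of correct shuffling)."""
    uses_random_sampler = isinstance(loader.sampler, RandomSampler)
    # A fresh iterator over a shuffled loader should start with a different batch.
    differs_between_passes = not torch.equal(first_batch_labels(loader),
                                             first_batch_labels(loader))
    # A first batch holding a single class is a classic sign of sorted data.
    distinct_in_first_batch = first_batch_labels(loader).unique().numel()
    return uses_random_sampler, differs_between_passes, distinct_in_first_batch

torch.manual_seed(0)
labels = torch.arange(10).repeat_interleave(100)   # sorted by class, like a naive folder walk
dataset = TensorDataset(torch.zeros(len(labels), 1), labels)
ordered = shuffle_signals(DataLoader(dataset, batch_size=32, shuffle=False))
shuffled = shuffle_signals(DataLoader(dataset, batch_size=32, shuffle=True))
print("shuffle=False:", ordered)
print("shuffle=True: ", shuffled)
```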

Two recurring culprits deserve a direct mention because they are common and easily missed. An exploding or NaN loss most often means the learning rate is too high or the gradients need clipping. Lower the rate (halving it is a sensible first step) or clip the gradient norm: rescale the gradient vector so its overall magnitude never exceeds a set ceiling, which caps any single runaway update. A loss that is flat from step zero usually means the updates never reach the parameters the forward pass uses. Typical causes are a detached tensor, requires_grad mistakenly left False from the freezing of Section 21.3, or an optimizer built over a different model instance than the one being trained (for example, a model re-created after its optimizer was constructed). Both culprits show up immediately in the overfit-one-batch check, which is why it sits so high in the playbook. The Hands-On Lab at the end of this section is where the whole chapter comes together: you will assemble the data, augmentation, transfer, recipe, and sanity-check links into one script that trains and evaluates a real classifier on a small real dataset.
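Both culprits can be guarded against inside the training step. This minimal sketch uses hypothetical helper names. clipped_step refuses to apply a non-finite loss and clips the global gradient norm with torch.nn.utils.clip_grad_norm_. params_without_grad lists parameters that received no gradient, or an all-zero one, after backward, which is how a layer frozen by mistake shows up. The None test assumes gradients are cleared to None, the default behavior of zero_grad in recent PyTorch.

```python
import torch
import torch.nn as nn

def clipped_step(model, optimizer, loss, max_norm=1.0):
    """One update with a non-finite guard and global gradient-norm clipping."""
    if not torch.isfinite(loss):
        raise FloatingPointError(f"non-finite loss ({loss.item()}): lower the LR or check inputs")
    optimizer.zero_grad()
    loss.backward()
    # Rescale all gradients together so their combined L2 norm is at most
    # max_norm; the returned value is the norm measured before clipping.
    total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return total_norm.item()

def params_without_grad(model):
    """Parameters that got no (or an all-zero) gradient from the last backward."""
    return [name for name, p in model.named_parameters()
            if p.grad is None or not p.grad.any()]

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
for p in model[0].parameters():
    p.requires_grad = False    # planted bug: a layer frozen by mistake
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))
loss = nn.functional.cross_entropy(model(x), y)
grad_norm = clipped_step(model, opt, loss)
print("pre-clip grad norm:", round(grad_norm, 4))
print("no gradient:", params_without_grad(model))
```

Logging the pre-clip norm every few steps is cheap. A norm that climbs steadily before the loss turns to NaN is the early warning the clipping ceiling is meant to absorb.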

Library Shortcut: Curves, Sanity Checks, and Overfit-One-Batch For Free

Hand-logging metrics and writing your own sanity-check harness is fine, but training frameworks bundle all of it. PyTorch Lightning, for example, has a built-in flag that runs the overfit-one-batch check and exposes automatic logging to TensorBoard or Weights and Biases:

import lightning as L

trainer = L.Trainer(
    max_epochs=100,
    overfit_batches=1,        # built-in: overfit ONE batch to validate the pipeline
    fast_dev_run=False,       # set True to run 1 train+val batch end-to-end as a smoke test
    log_every_n_steps=10,     # automatic loss/metric logging for the curves of subsection 1
)
# trainer.fit(model, train_loader, val_loader)
# Curves go to the default logger (TensorBoard if it is installed); overfit_batches=1 reproduces subsection 2 with one flag.
Code Fragment 3: The overfit-one-batch check of subsection 2 and a smoke test as Lightning Trainer flags. overfit_batches=1 reproduces the overfit-one-batch test of Code Fragment 1, fast_dev_run runs one train-plus-validation batch as a smoke test, and log_every_n_steps sets how often the loss curves of subsection 1 are logged. Together they turn the manual harness into one-line configuration.

The framework handles device placement and metric logging for whatever your LightningModule records with self.log. The attached logger (TensorBoard or Weights and Biases) stores the train and validation curves, and the overfit-one-batch and smoke-test (fast_dev_run) routines become configuration flags. That turns the manual harness of subsection 2 into a one-line setting; looking at a real batch (subsection 3) still has to be done by hand. The experiment-tracking integration also means your curves are recorded and comparable across runs without extra code.

Research Frontier: Scaling Laws and Predicting the Curve Before You Run It

A major development of 2020 to 2026 reframes debugging from reactive to predictive. Neural scaling laws show that, for a fixed architecture family, final loss approximately follows a smooth power law in model size, dataset size, and compute. You can therefore fit the law on a few small, cheap runs and extrapolate the loss of an expensive large run before launching it. A measured curve that falls below or above the predicted scaling line is itself a debugging signal. Well below suggests a leak or contamination (the too-good-to-be-true flag from Section 21.1); well above suggests a recipe or pipeline bug. The same machinery powers compute-optimal training (the Chinchilla result on balancing model and data size) and the data-quality scaling work behind the foundation models of Chapter 25. The humble overfit-one-batch check and the grand scaling law are two ends of one idea: the loss curve is predictable, and a curve that defies prediction is telling you exactly where to look.
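The fit-and-extrapolate idea reduces to a straight line in log-log space. This sketch rests on three loud assumptions. It fits the simplest form $L(N) = a N^{-b}$, without the irreducible-loss term that published scaling laws usually include. The pilot losses are synthetic values drawn from an assumed law, not measurements. The 10% tolerance in the hypothetical flag_run helper is illustrative.

```python
import numpy as np

def fit_power_law(sizes, losses):
    """Fit L(N) = a * N**(-b) with a straight-line fit in log-log space."""
    slope, intercept = np.polyfit(np.log(sizes), np.log(losses), deg=1)
    return float(np.exp(intercept)), float(-slope)

def flag_run(measured, predicted, tol=0.10):
    """Compare a real run with the extrapolated loss (tol is an assumption)."""
    ratio = measured / predicted
    if ratio < 1 - tol:
        return "suspiciously good: check for leakage or contamination"
    if ratio > 1 + tol:
        return "worse than predicted: check the recipe and pipeline"
    return "on the predicted curve"

# Synthetic pilot-run losses drawn from an ASSUMED law a=20, b=0.3 with 1% noise.
# They stand in for a few cheap small runs; they are not real measurements.
rng = np.random.default_rng(0)
sizes = np.array([1e5, 3e5, 1e6, 3e6])
losses = 20.0 * sizes ** -0.3 * np.exp(rng.normal(0.0, 0.01, size=sizes.shape))

a, b = fit_power_law(sizes, losses)
predicted = a * 1e8 ** (-b)    # extrapolate roughly 30x beyond the largest pilot
print(f"fitted a={a:.2f}, b={b:.3f}, predicted loss at N=1e8: {predicted:.3f}")
print(flag_run(measured=1.3 * predicted, predicted=predicted))
```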

Hands-On Lab: Train and Evaluate a Classifier With the Full Recipe
Difficulty: Intermediate Duration: 60 to 90 minutes

Build one self-contained script, train_recipe.py, that takes a pretrained backbone and trains it on a small real dataset (Oxford-IIIT Pets, 37 classes) using every link in this chapter's chain: normalization and a proper split (Section 21.1), an augmentation pipeline with RandAugment and MixUp (Section 21.2), a re-headed fine-tuned backbone (Section 21.3), the modern recipe of AdamW with warmup-plus-cosine and label smoothing (Section 21.4), and the sanity checks of this section run before the full run. The deliverable is a trained model, a validation accuracy you can quote, and a saved loss-and-accuracy curve you can read with the three shapes of subsection 1.

What You'll Practice

  • Loading a real dataset with the correct ImageNet normalization and a clean train/validation split (Section 21.1).
  • Composing an augmentation pipeline (resize, crop, flip, RandAugment) and applying MixUp in the training step (Section 21.2).
  • Re-heading a pretrained backbone for a new label set and fine-tuning rather than training from scratch (Section 21.3).
  • Assembling the modern recipe: AdamW with weight decay, a warmup-plus-cosine schedule, and label smoothing (Section 21.4).
  • Running the loss-at-init and overfit-one-batch sanity checks before launching, then reading the training curve (this section).

Setup

pip install torch torchvision timm matplotlib

A GPU shortens training to a few minutes. Every step also runs on CPU at a smaller batch size, only much more slowly. The dataset, torchvision.datasets.OxfordIIITPet, downloads automatically on first run (about 800 MB) and splits into a trainval and a test partition, which you will use as train and validation.

Put the chapter's full recipe into practice below. Work the steps in order; each prints a checkpoint so you can confirm progress before the long step. A complete reference solution is folded at the end.

Step 1: Load the data with normalization and an augmentation pipeline

Two transform pipelines: a heavy one for training (random crop, flip, RandAugment) and a deterministic one for validation, both ending in the ImageNet normalization that every pretrained backbone expects (Section 21.1).

import torch
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader

MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),   # automated policy, Section 21.2
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
# TODO: build val_tf with a deterministic Resize(256) + CenterCrop(224),
#       then ToTensor and the SAME Normalize(MEAN, STD). No randomness here.
val_tf = ...

train = OxfordIIITPet("./data", split="trainval", download=True, transform=train_tf)
val   = OxfordIIITPet("./data", split="test",     download=True, transform=val_tf)
train_loader = DataLoader(train, batch_size=32, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val,   batch_size=64, shuffle=False, num_workers=0)
print(f"train={len(train)}  val={len(val)}  classes={len(train.classes)}")
Hint

The validation pipeline must be deterministic so the metric is reproducible: transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), then the identical Normalize(MEAN, STD). Augmentation belongs on the training set only.

Step 2: Re-head a pretrained backbone

Load a pretrained ResNet-50 and swap its 1000-class ImageNet head for a 37-class one. This is transfer learning's central move (Section 21.3): reuse the features, replace only the classifier.

import timm

device = "cuda" if torch.cuda.is_available() else "cpu"
# TODO: create a pretrained "resnet50" with num_classes set to the dataset's
#       class count, then move it to `device`. Hint: timm.create_model(...).
model = ...
print("head out features:", model.get_classifier().out_features)  # expect 37
Hint

timm.create_model("resnet50", pretrained=True, num_classes=len(train.classes)).to(device) reuses the pretrained trunk and re-initializes a 37-way head. The num_classes argument does the head surgery for you.

Step 3: Run the sanity checks before training

Never launch a long run unverified. Check the loss at initialization against $\ln K$ and confirm the model can overfit one batch, the two highest-value checks of this section.

import math
import torch.nn as nn

xb, yb = next(iter(train_loader))
xb, yb = xb.to(device), yb.to(device)

# Loss at init should be about ln(K): the chance level for K classes.
with torch.no_grad():
    init_loss = nn.CrossEntropyLoss()(model(xb), yb).item()
print(f"loss at init = {init_loss:.3f}  (expected about {math.log(len(train.classes)):.3f})")

# Overfit ONE batch: a healthy model drives it toward zero in well under 100 steps.
probe = nn.CrossEntropyLoss()
opt_probe = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for step in range(60):
    opt_probe.zero_grad()
    loss = probe(model(xb), yb)
    loss.backward(); opt_probe.step()
    # TODO: every 20 steps, print step and loss.item() so you can watch it fall.
    ...
Hint

A guard like if step % 20 == 0: print(f"overfit step {step:3d} loss {loss.item():.3f}") shows the loss diving toward zero. If it stalls near $\ln 37 \approx 3.6$, stop and fix the wiring before training. Re-create the model afterward so this probe does not pollute the real run.

Step 4: Assemble the modern recipe

Build the optimizer, schedule, and loss that strong recent (2024-2026) recipes commonly combine (Section 21.4): AdamW with decoupled weight decay, a linear warmup into cosine decay, and cross-entropy with label smoothing.

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = timm.create_model("resnet50", pretrained=True,
                          num_classes=len(train.classes)).to(device)  # fresh after probe
EPOCHS = 8
steps_per_epoch = len(train_loader)
warmup_steps = steps_per_epoch  # one warmup epoch

opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
# TODO: build a SequentialLR that does LinearLR warmup for `warmup_steps`,
#       then CosineAnnealingLR for the remaining steps. Milestone = [warmup_steps].
sched = ...
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)   # label smoothing, Section 21.4
Hint

Compose with SequentialLR(opt, schedulers=[LinearLR(opt, start_factor=0.01, total_iters=warmup_steps), CosineAnnealingLR(opt, T_max=EPOCHS*steps_per_epoch - warmup_steps)], milestones=[warmup_steps]), and call sched.step() once per optimizer step, not per epoch.
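Before spending a training run on the schedule, you can check it in isolation. This minimal sketch assumes small stand-in values (8 epochs of 10 steps) and a dummy parameter instead of the ResNet. It steps the optimizer and the SequentialLR once per iteration and records the learning rate. The rate should climb from $3\times10^{-6}$ to $3\times10^{-4}$ over the warmup and then decay toward zero along the cosine.

```python
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

EPOCHS, steps_per_epoch = 8, 10            # small stand-in values for a quick check
warmup_steps = steps_per_epoch
param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter; no model needed
opt = torch.optim.AdamW([param], lr=3e-4, weight_decay=0.05)
sched = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=0.01, total_iters=warmup_steps),
        CosineAnnealingLR(opt, T_max=EPOCHS * steps_per_epoch - warmup_steps),
    ],
    milestones=[warmup_steps],
)

lrs = [opt.param_groups[0]["lr"]]
for _ in range(EPOCHS * steps_per_epoch):
    opt.step()      # no gradients exist, so the dummy parameter stays unchanged
    sched.step()    # once per optimizer step, exactly as in the training loop
    lrs.append(opt.param_groups[0]["lr"])

print(f"start {lrs[0]:.1e}  peak {max(lrs):.1e}  end {lrs[-1]:.1e}")
```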

Step 5: Train with MixUp in the loop

Run the training epochs. Inside each step apply MixUp (Section 21.2): blend pairs of images and their labels with a Beta-sampled weight, then use the mixed-target cross-entropy. Step the scheduler every iteration.

import numpy as np

def mixup(x, y, num_classes, alpha=0.2):
    lam = float(np.random.beta(alpha, alpha))   # plain Python float for tensor math
    idx = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[idx]
    y1 = nn.functional.one_hot(y, num_classes).float()
    mixed_y = lam * y1 + (1 - lam) * y1[idx]   # mix the LABELS too
    return mixed_x, mixed_y

for epoch in range(EPOCHS):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        mx, my = mixup(xb, yb, len(train.classes))
        opt.zero_grad()
        # Soft-target cross-entropy: -sum(target * log_softmax(logits)).
        logp = nn.functional.log_softmax(model(mx), dim=1)
        loss = -(my * logp).sum(dim=1).mean()
        loss.backward()
        opt.step()
        sched.step()   # TODO: confirm this is per-STEP, not per-epoch
    print(f"epoch {epoch}  last train loss {loss.item():.3f}  lr {sched.get_last_lr()[0]:.2e}")
Hint

With MixUp the targets are soft (a blend of two one-hot vectors), so the standard CrossEntropyLoss on hard labels does not apply during training; the soft-target form -(target * log_softmax(logits)).sum(1).mean() is the correct loss. Keep loss_fn with label smoothing for any non-MixUp baseline you compare against.
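A quick numerical check confirms that the soft-target expression is the right loss. The sketch runs on random logits with 37 classes to match the lab and shows three equivalences. On one-hot targets, the soft form matches F.cross_entropy exactly. On MixUp targets, it equals the λ-weighted blend of the two hard-label losses. It also matches F.cross_entropy called with probability targets, which PyTorch has accepted since version 1.10.

```python
import torch
import torch.nn.functional as F

def soft_ce(logits, target_probs):
    # Same expression as the Step 5 loop: -sum(target * log_softmax(logits)), batch mean.
    return -(target_probs * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

torch.manual_seed(0)
K, B = 37, 8
logits = torch.randn(B, K)
y = torch.randint(0, K, (B,))
one_hot = F.one_hot(y, K).float()

# 1) With one-hot targets the soft form is exactly the usual cross-entropy.
gap_hard = (soft_ce(logits, one_hot) - F.cross_entropy(logits, y)).abs().item()

# 2) With MixUp targets it equals the lam-weighted blend of two hard-label losses.
lam, perm = 0.7, torch.randperm(B)
mixed = lam * one_hot + (1 - lam) * one_hot[perm]
blend = lam * F.cross_entropy(logits, y) + (1 - lam) * F.cross_entropy(logits, y[perm])
gap_blend = (soft_ce(logits, mixed) - blend).abs().item()

# 3) F.cross_entropy also accepts class-probability targets directly.
gap_builtin = (soft_ce(logits, mixed) - F.cross_entropy(logits, mixed)).abs().item()
print(gap_hard, gap_blend, gap_builtin)
```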

Step 6: Evaluate on the held-out validation set

Switch to eval mode (this freezes batch-norm statistics and disables dropout) and compute top-1 accuracy on the validation loader, the number that actually reports whether the recipe worked.

@torch.no_grad()
def accuracy(model, loader):
    model.eval()
    correct = total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb).argmax(dim=1)
        # TODO: accumulate correct predictions and the total count
        ...
    return 100.0 * correct / total

print(f"validation top-1 accuracy: {accuracy(model, val_loader):.2f}%")
Hint

correct += (preds == yb).sum().item() and total += yb.size(0). The @torch.no_grad() decorator and model.eval() are both required: the first saves memory, the second makes batch-norm use running statistics instead of the batch.

Step 7: Plot and read the curve

Record per-epoch training loss and validation accuracy and plot them. Reading this curve with the three shapes of subsection 1 is what turns a finished run into a diagnosis.

import matplotlib.pyplot as plt
# Assume you appended to `train_losses` and `val_accs` (one value per epoch) above.
fig, ax1 = plt.subplots(figsize=(6, 4))
ax1.plot(train_losses, color="tab:blue", label="train loss")
ax1.set_xlabel("epoch"); ax1.set_ylabel("train loss", color="tab:blue")
ax2 = ax1.twinx()
ax2.plot(val_accs, color="tab:orange", label="val acc")
ax2.set_ylabel("val accuracy (%)", color="tab:orange")
plt.title("Recipe training curve"); plt.tight_layout()
plt.savefig("recipe_curve.png", dpi=120)
print("saved recipe_curve.png")
Hint

Collect the metrics by initializing train_losses, val_accs = [], [] before the epoch loop and appending the last training loss and accuracy(model, val_loader) at the end of each epoch. A healthy curve shows train loss falling while validation accuracy rises and then flattens.

Expected Output

The sanity checks print a loss at init near $\ln 37 \approx 3.61$ and an overfit-one-batch loss diving below $0.1$ within sixty steps, confirming the pipeline before the real run. Training prints eight epochs with the learning rate rising during the warmup epoch and then decaying along the cosine curve, and the per-step MixUp loss noisier than a plain run because the targets are blended. On a GPU the whole run finishes in a few minutes; validation top-1 accuracy lands around 88 to 93 percent, far above what training the same ResNet-50 from scratch on this 37-class, roughly 3700-image set would reach, the transfer-learning payoff of Section 21.3 made concrete. The saved curve shows training loss falling steadily while validation accuracy climbs and plateaus, the healthy shape of Figure 21.6.1.

Stretch Goals

  • Ablate the recipe: rerun with MixUp off, then with label smoothing off, then with the schedule replaced by a constant learning rate, and tabulate the validation accuracy each change costs. This tests the chapter's thesis that the recipe, not the architecture, moves the number.
  • Add model EMA (an exponential moving average of the weights, Section 21.4): keep a shadow copy updated each step and evaluate it instead of the raw model. Report the accuracy difference.
  • Make the training split class-imbalanced by heavily subsampling a few breeds, then add the class-balanced or focal loss of Section 21.5 and measure whether per-class validation recall on the rare breeds improves.
Complete Solution
import math
import numpy as np
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet

device = "cuda" if torch.cuda.is_available() else "cpu"
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Step 1: data with augmentation + normalization
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
train = OxfordIIITPet("./data", split="trainval", download=True, transform=train_tf)
val   = OxfordIIITPet("./data", split="test",     download=True, transform=val_tf)
K = len(train.classes)
train_loader = DataLoader(train, batch_size=32, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val,   batch_size=64, shuffle=False, num_workers=0)

# Step 3: sanity checks (on a throwaway model)
probe_model = timm.create_model("resnet50", pretrained=True, num_classes=K).to(device)
xb, yb = next(iter(train_loader))
xb, yb = xb.to(device), yb.to(device)
with torch.no_grad():
    print(f"loss at init = {nn.CrossEntropyLoss()(probe_model(xb), yb).item():.3f} "
          f"(expected about {math.log(K):.3f})")
opt_probe = torch.optim.Adam(probe_model.parameters(), lr=1e-3)
probe_model.train()
for step in range(60):
    opt_probe.zero_grad()
    loss = nn.CrossEntropyLoss()(probe_model(xb), yb)
    loss.backward(); opt_probe.step()
    if step % 20 == 0:
        print(f"overfit step {step:3d} loss {loss.item():.3f}")
del probe_model, opt_probe

# Step 2 + 4: fresh model and the modern recipe
model = timm.create_model("resnet50", pretrained=True, num_classes=K).to(device)
EPOCHS = 8
steps_per_epoch = len(train_loader)
warmup_steps = steps_per_epoch
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
sched = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=0.01, total_iters=warmup_steps),
        CosineAnnealingLR(opt, T_max=EPOCHS * steps_per_epoch - warmup_steps),
    ],
    milestones=[warmup_steps],
)

# Step 5: MixUp training loop
def mixup(x, y, num_classes, alpha=0.2):
    lam = float(np.random.beta(alpha, alpha))   # plain Python float for tensor math
    idx = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[idx]
    y1 = nn.functional.one_hot(y, num_classes).float()
    mixed_y = lam * y1 + (1 - lam) * y1[idx]
    return mixed_x, mixed_y

@torch.no_grad()
def accuracy(model, loader):
    model.eval()
    correct = total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb).argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    return 100.0 * correct / total

train_losses, val_accs = [], []
for epoch in range(EPOCHS):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        mx, my = mixup(xb, yb, K)
        opt.zero_grad()
        logp = nn.functional.log_softmax(model(mx), dim=1)
        loss = -(my * logp).sum(dim=1).mean()
        loss.backward()
        opt.step()
        sched.step()
    acc = accuracy(model, val_loader)
    train_losses.append(loss.item()); val_accs.append(acc)
    print(f"epoch {epoch}  train loss {loss.item():.3f}  "
          f"val acc {acc:.2f}%  lr {sched.get_last_lr()[0]:.2e}")

# Step 6 + 7: final accuracy and curve
print(f"final validation top-1 accuracy: {val_accs[-1]:.2f}%")
fig, ax1 = plt.subplots(figsize=(6, 4))
ax1.plot(train_losses, color="tab:blue", label="train loss")
ax1.set_xlabel("epoch"); ax1.set_ylabel("train loss", color="tab:blue")
ax2 = ax1.twinx()
ax2.plot(val_accs, color="tab:orange", label="val acc")
ax2.set_ylabel("val accuracy (%)", color="tab:orange")
plt.title("Recipe training curve"); plt.tight_layout()
plt.savefig("recipe_curve.png", dpi=120)
print("saved recipe_curve.png")
Exercise 21.6.1: Diagnose From the Curve Alone Conceptual

For each described curve, name the state (underfitting, overfitting, healthy, or broken pipeline) and the single best first action: (a) training loss falls to $0.01$, validation loss falls then climbs from $0.4$ to $0.8$; (b) both losses sit flat at $2.30$ from step zero on a ten-class problem; (c) both losses fall steadily and end close together at $0.3$ and $0.35$; (d) both losses are stuck high at $1.8$ and still slowly falling after many epochs. Justify each diagnosis from the curve shapes of subsection 1.

Exercise 21.6.2: Plant Three Bugs, Catch Them With Checks Coding

Start from a working CIFAR-10 training script. Create three broken copies, each with exactly one planted bug: (a) requires_grad = False left on the whole model, (b) labels shifted by one with modulo wraparound, (c) the learning rate set to $0$. For each, run only the loss-at-init check, the inspect_batch check, and the overfit-one-batch check, and record which check catches which bug. Confirm that the three cheap checks together catch all three bugs without a single full training run.

Exercise 21.6.3: Build Your Own Learning-Rate Range Test Analysis

Implement a learning-rate range test: starting from a tiny learning rate, train for a few hundred steps while exponentially increasing the rate each step, and plot loss against learning rate on a log scale. Identify the range where loss falls fastest (a good peak learning rate sits roughly an order of magnitude below where the loss starts to diverge). Compare the rate you find to the one used in the recipe of Section 21.4, and write a paragraph on how this test removes the guesswork from step 5 of the playbook in Table 21.6.1.