Part III: Deep Learning for Computer Vision
Chapter 28: Efficient Vision & Edge Deployment

The Efficiency Toolbox: Quantization, Pruning & Distillation

"My teacher had four hundred layers and a god complex. I have eighteen layers and a deadline. Somehow I score within a point of her, and I run on a phone she would not deign to charge."

A Distilled Student Network With No Time to Waste
Big Picture

You can usually make a trained vision model four to ten times smaller and faster while losing only a percentage point or two of accuracy, because the model you trained is enormously over-parameterized for the task it actually performs. Three techniques do most of the work. Quantization lowers the numeric precision of weights and activations from 32-bit float to 8-bit integer, shrinking memory by a factor of four and replacing slow floating-point math with fast integer math. Pruning deletes the weights and channels that contribute least, leaving a smaller network. Distillation trains a small student to mimic a large teacher, transferring accuracy that the student cannot reach from labels alone. This section builds an honest mental model of each, with runnable code, and shows how they stack.

You have spent Part III making models more accurate. This section is the first where we deliberately make a model worse, in a controlled way, in exchange for speed and size. That trade is the heart of deployment. The model that comes out of Chapter 21 at the top of the validation curve is carrying weights at 32-bit floating-point precision, dense connectivity it does not need, and a parameter count chosen for trainability rather than inference cost. Each of the three techniques here attacks one of those excesses, as the illustration below sketches. We take them in turn, then combine them, and throughout we measure rather than assume, because the only honest claim about a compressed model is one backed by a number on the target metric.

A bulky friendly robot at a makeover gym where three helpers round its edges, snip off extra arms, and a tiny apprentice mimics it, leaving a smaller faster robot that fits on a phone, depicting quantization, pruning, and distillation shrinking a vision model.
A trained model arrives over-built; quantization, pruning, and distillation each trim a different kind of excess until it fits the hardware.

1. Why Over-Parameterization Leaves Room (Beginner)

A modern vision network has far more parameters than the information content of its task requires. A ResNet-50 from Chapter 20 has roughly 25 million parameters and reaches about 76 percent top-1 accuracy on ImageNet; a MobileNetV2 with under 4 million parameters reaches about 72 percent. The larger model is not six times more accurate; it is a few points more accurate. Most of the extra capacity is there to make optimization easy, not because the final function needs it. This is the slack that compression exploits. After training, the redundancy can be squeezed out: weights can be stored coarsely, many can be removed entirely, and the learned function can be re-expressed in a much smaller network.

It helps to separate the two costs we are trying to reduce. The first is memory: how many bytes the weights occupy, which sets the download size, the on-device storage, and the memory bandwidth burned moving weights to the compute units. The second is compute: how many arithmetic operations a forward pass performs, usually counted in multiply-accumulates or floating-point operations (FLOPs), which together with the hardware's throughput sets latency. Quantization attacks both at once. Pruning attacks both if it removes whole structures, or only memory if it leaves scattered zeros that the hardware still has to multiply. Distillation attacks both by changing the architecture entirely. Figure 28.1.1 places the three techniques on these two axes.

[Figure 28.1.1 panels. Title: Three ways to shrink a trained vision model. Quantization: float32 to int8, same shape with coarser numbers; memory: 4x smaller; compute: int math; cost: rounding error. Pruning: remove weights or channels, smaller network; memory: fewer params; compute: if structured; cost: capacity lost. Distillation: new small student imitates teacher, new architecture; memory: by design; compute: by design; cost: needs teacher. Footer: All three compose: distill, then prune, then quantize the result.]
Figure 28.1.1: The three model-level compression techniques and what each reduces. Quantization keeps the architecture and lowers numeric precision. Pruning removes parameters or whole channels. Distillation replaces the architecture with a smaller one trained to imitate the original. They are complementary and routinely stacked.
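
To make the two cost axes concrete, here is a minimal back-of-the-envelope sketch in plain Python. The layer dimensions and the helper name `conv_cost` are hypothetical, chosen only for illustration; the formulas are the standard weight count and multiply-accumulate count of a bias-free convolution.

```python
def conv_cost(c_in, c_out, kernel, out_h, out_w, bytes_per_weight=4):
    """Return (weight_bytes, macs) for one bias-free 2D convolution."""
    n_weights = c_out * c_in * kernel * kernel
    macs = n_weights * out_h * out_w      # every weight is applied once per output position
    return n_weights * bytes_per_weight, macs

# Hypothetical layer: 3x3 convolution, 256 -> 256 channels, 14x14 output map.
fp32_bytes, fp32_macs = conv_cost(256, 256, 3, 14, 14, bytes_per_weight=4)
int8_bytes, int8_macs = conv_cost(256, 256, 3, 14, 14, bytes_per_weight=1)
# Structured pruning that removes half of the output channels.
pruned_bytes, pruned_macs = conv_cost(256, 128, 3, 14, 14, bytes_per_weight=4)

for name, nbytes, macs in [("float32", fp32_bytes, fp32_macs),
                           ("int8", int8_bytes, int8_macs),
                           ("50% channels pruned", pruned_bytes, pruned_macs)]:
    print(f"{name:20s} {nbytes / 1e6:5.2f} MB  {macs / 1e6:6.1f} M MACs")
```

Note what the count does and does not capture: int8 leaves the number of multiply-accumulates unchanged and wins on compute only because each one is cheaper, whereas structured pruning reduces the count itself.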
Key Insight: Measure the Metric, Not the Megabytes

It is tempting to report compression as a size ratio: "we shrank the model by 4x." That number is necessary but not sufficient. A model that is 4x smaller but 5 points less accurate may be useless for the product, and a model that is 4x smaller and 0.3 points less accurate may be a free win. Every compression decision in this chapter is a point on an accuracy-versus-cost curve, and the only way to know where you are on that curve is to evaluate the compressed model on the same metric, the same way, that you evaluated the original. Compress, then measure, then decide. A size number without an accuracy number is half a sentence.

2. Quantization: Computing in Eight Bits (Intermediate)

Quantization maps a continuous range of float values onto a small set of integers. The standard scheme is affine (asymmetric) quantization to 8-bit integers. Given a tensor of floats with observed range $[\beta_{\min}, \beta_{\max}]$, we pick a positive scale $s$ and an integer zero-point $z$ so that a float value $x$ maps to an integer $q$ and back:

$$q = \mathrm{round}\!\left(\frac{x}{s}\right) + z, \qquad \hat{x} = s\,(q - z)$$

The scale $s = (\beta_{\max} - \beta_{\min}) / (2^b - 1)$ for $b$ bits spreads the float range across the $2^b$ integer levels (256 levels for int8). The zero-point $z$ is the integer that the real value zero maps to; because $z$ is itself an integer, zero is represented exactly even when the range is asymmetric (important because padding and ReLU produce many exact zeros). For any $x$ inside the range, the reconstructed value $\hat{x}$ differs from $x$ by at most half a step, $s/2$. This rounding error, plus clipping for any value that falls outside the range, is the entire accuracy cost of quantization, and keeping $s$ small (a tight range) keeps the rounding part small.

Why does throwing away 24 of the 32 bits barely hurt? A float32 weight can take any of roughly four billion distinct values, but the network never needed that resolution. Snapping each weight to its nearest of only 256 evenly spaced levels moves it by at most half a step. That perturbation is typically smaller than the noise the weight already survived during training, and because the rounding errors across a layer are roughly independent in sign, they partly cancel when summed, so the output barely moves. The model was carrying 32 bits of precision to make optimization smooth, not because the learned function needed it. The code below implements the round trip so you can see the error directly.

import torch

def quantize_affine(x, num_bits=8):
    """Affine int8 quantization of a float tensor; returns ints, scale, zero-point."""
    qmin, qmax = 0, 2 ** num_bits - 1
    beta_min, beta_max = x.min().item(), x.max().item()
    scale = (beta_max - beta_min) / (qmax - qmin)        # float range per integer step
    zero_point = round(qmin - beta_min / scale)          # integer that maps to real 0.0
    zero_point = max(qmin, min(qmax, zero_point))         # clamp into the integer range
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int32)
    return q, scale, zero_point

def dequantize_affine(q, scale, zero_point):
    """Map integers back to floats: x_hat = scale * (q - zero_point)."""
    return scale * (q.float() - zero_point)

w = torch.randn(1000) * 0.1                  # a typical weight tensor, small magnitude
q, s, z = quantize_affine(w)
w_hat = dequantize_affine(q, s, z)
err = (w - w_hat).abs().max().item()
print(f"scale={s:.6f}  zero_point={z}  max abs error={err:.6f}")
print(f"theoretical bound (scale/2) = {s/2:.6f}")
# Example output from one run (no seed is set, so the digits vary):
# scale=0.002545  zero_point=131  max abs error=0.001268
# theoretical bound (scale/2) = 0.001273
Code Fragment 1: Affine int8 quantization round trip on a weight tensor. The measured maximum reconstruction error sits just under the theoretical half-step bound $s/2$, confirming that the rounding error is bounded by the quantization step and that a tighter float range (smaller scale) means smaller error.

The reason this is fast, not just small, is that the actual matrix multiplications can run in integer arithmetic. With both weights and activations quantized, a convolution or linear layer multiplies int8 values and accumulates into int32, and modern CPUs and GPUs run int8 multiply-accumulate at two to four times the throughput of float32, with a quarter of the memory traffic. The scale and zero-point bookkeeping is applied once per tensor (or per channel) to rescale the int32 accumulator back to the next layer's input range.
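
The integer arithmetic described above can be sketched for a single dot product in plain Python. This is a minimal illustration under stated assumptions, not how any particular backend implements it: the `quantize` helper and the sample values are hypothetical, and Python's unbounded integers stand in for the int32 accumulator.

```python
def quantize(values, num_bits=8):
    """Affine-quantize a list of floats; returns (ints, scale, zero_point)."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # Widen the range to include 0.0 so that real zero is exactly representable.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

weights = [0.12, -0.40, 0.33, 0.05, -0.21, 0.27, -0.09, 0.18]   # hypothetical values
acts = [1.50, 0.00, 2.20, 0.70, 3.10, 0.00, 1.10, 0.40]          # hypothetical ReLU outputs

qw, sw, zw = quantize(weights)
qx, sx, zx = quantize(acts)

# Integer-only inner loop: subtract zero-points, multiply, accumulate in a wide integer.
acc = sum((a - zw) * (b - zx) for a, b in zip(qw, qx))
approx = sw * sx * acc            # a single float rescale per output, not per product
exact = sum(w * x for w, x in zip(weights, acts))

print(f"integer accumulator = {acc}")
print(f"rescaled = {approx:.4f}   float reference = {exact:.4f}")
```

The point of the sketch is where the floating-point work happens: once, after the accumulation, as the product of the two scales.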

2.1 Post-Training Quantization Versus Quantization-Aware Training

There are two ways to get a quantized model, and they trade engineering effort against accuracy. Post-training quantization (PTQ) takes an already-trained float model and quantizes it directly, using a small unlabeled calibration set to observe the activation ranges $[\beta_{\min}, \beta_{\max}]$ at each layer. It is fast (minutes, no training) and needs no labels, and for many convolutional networks it costs well under a point of accuracy. Quantization-aware training (QAT) goes further: it inserts simulated quantization (the round-trip above) into the forward pass during a short fine-tune, so the network learns weights that are robust to the rounding. QAT is more work and needs labeled data and a training loop, but it recovers most of the accuracy that PTQ leaves on the table, which matters for sensitive models like detectors and for aggressive sub-8-bit precision.

QAT has one trainability problem to solve first. The key trick that solves it is the straight-through estimator: the non-differentiable rounding is treated as the identity during the backward pass, so gradients flow through it. This trick is necessary because the rounding function is flat between every pair of integer levels, so its true derivative is zero almost everywhere; back-propagating that honest gradient would deliver exactly zero to every weight and the fine-tune could never learn, so the estimator substitutes a derivative of one and lets the forward-pass rounding error guide the update instead. The code below applies PTQ to a torchvision model.
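
As a brief aside before the PTQ workflow, the straight-through estimator can be seen on a one-weight toy problem in plain Python. This is a sketch under stated assumptions: the step size, target, learning rate, and the helper names `fake_quant` and `train` are all hypothetical, and the gradient is written out by hand rather than taken from an autograd engine.

```python
STEP = 0.1   # hypothetical quantization step (the scale s)

def fake_quant(w, step=STEP):
    """Forward pass of simulated quantization: snap w to the nearest grid level."""
    return step * round(w / step)

def train(use_ste, x=1.0, target=0.73, lr=0.1, iters=20):
    """Minimize (fake_quant(w) * x - target)^2 by gradient descent on w."""
    w = 0.0
    for _ in range(iters):
        err = fake_quant(w) * x - target     # forward pass uses the rounded weight
        dloss_dq = 2.0 * err * x             # derivative of the loss w.r.t. the rounded weight
        dq_dw = 1.0 if use_ste else 0.0      # true derivative of round() is 0; STE uses 1
        w -= lr * dloss_dq * dq_dw
    return w

w_true = train(use_ste=False)
w_ste = train(use_ste=True)
print(f"true gradient: w = {w_true:.3f}, quantized weight = {fake_quant(w_true):.1f}")
print(f"STE gradient:  w = {w_ste:.3f}, quantized weight = {fake_quant(w_ste):.1f}")
```

With the honest zero gradient the weight never leaves its starting point; with the estimator it drifts until its rounded value sits on a grid level next to the target. The PTQ workflow itself follows.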

import torch
from torchvision.models import resnet18, ResNet18_Weights

# 1. A trained float model, set to eval mode (quantization fuses BN into conv).
model_fp32 = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

# 2. Configure post-training static quantization for the x86 backend.
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

qconfig = get_default_qconfig("x86")              # per-channel weights, per-tensor activations
qconfig_mapping = QConfigMapping().set_global(qconfig)
example_inputs = (torch.randn(1, 3, 224, 224),)   # prepare_fx expects a tuple of example inputs

prepared = prepare_fx(model_fp32, qconfig_mapping, example_inputs)   # inserts observers

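# 3. Calibrate: run a few unlabeled batches so observers learn activation ranges.
#    Random tensors stand in here only to keep the snippet self-contained;
#    feed real, preprocessed images in practice.
with torch.no_grad():
    for _ in range(32):                            # ~32 representative batches is usually plenty
        prepared(torch.randn(8, 3, 224, 224))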

# 4. Convert observers into real int8 quantized ops.
model_int8 = convert_fx(prepared)

import os

def size_mb(m):
    """Save the state dict to a temporary file and return its size in megabytes."""
    torch.save(m.state_dict(), "tmp.pt")
    mb = os.path.getsize("tmp.pt") / 1e6
    os.remove("tmp.pt")
    return mb

print(f"fp32 size = {size_mb(model_fp32):.1f} MB")   # fp32 size = 46.8 MB
print(f"int8 size = {size_mb(model_int8):.1f} MB")   # int8 size = 11.9 MB
Code Fragment 2: Post-training static int8 quantization of a ResNet-18 with the PyTorch FX quantization workflow. Calibration runs unlabeled batches so the observers record activation ranges; the converted model is roughly a quarter of the float size. Per-channel weight quantization (one scale per output channel) is what keeps the accuracy drop small.
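
The per-channel remark in the caption can be checked numerically with a small plain-Python sketch. The two weight "channels" and the helper names `roundtrip` and `max_err` are hypothetical; the quantizer is the same affine scheme as Code Fragment 1, reduced to lists.

```python
def roundtrip(values, lo, hi, num_bits=8):
    """Quantize then dequantize `values` with an affine scheme over [lo, hi]."""
    levels = 2 ** num_bits - 1
    scale = (hi - lo) / levels
    zero_point = round(-lo / scale)
    out = []
    for v in values:
        q = max(0, min(levels, round(v / scale) + zero_point))
        out.append(scale * (q - zero_point))
    return out

def max_err(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

channels = [
    [0.90, -0.75, 0.40, -1.00, 0.62],         # hypothetical wide-range output channel
    [0.008, -0.004, 0.010, -0.009, 0.002],    # hypothetical narrow-range output channel
]

# Per-tensor: one range, hence one scale, shared by every channel.
lo = min(min(c) for c in channels)
hi = max(max(c) for c in channels)
per_tensor_err = [max_err(c, roundtrip(c, lo, hi)) for c in channels]

# Per-channel: each output channel gets its own range and scale.
per_channel_err = [max_err(c, roundtrip(c, min(c), max(c))) for c in channels]

for i in range(len(channels)):
    print(f"channel {i}: per-tensor err = {per_tensor_err[i]:.6f}, "
          f"per-channel err = {per_channel_err[i]:.6f}")
```

The wide channel sets the shared scale, so under per-tensor quantization the narrow channel's weights collapse onto two or three integer levels. Giving each channel its own scale removes that coupling.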

Notice the calibration step. Post-training quantization needs to see real (or realistic) inputs so its observers can record the activation ranges; quantizing with ranges that do not match deployment is a common and silent accuracy killer. We will return to exactly this point in Section 28.5, because the calibration distribution is the first thing that drifts in production.
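
A small plain-Python sketch shows why a mismatched calibration range is costly. The two activation distributions are hypothetical (uniform draws with a fixed seed), and the min/max range stands in for what a simple observer would record; real observers may use other statistics.

```python
import random

def roundtrip(v, lo, hi, levels=255):
    """Affine quantize/dequantize one value; anything outside [lo, hi] saturates."""
    scale = (hi - lo) / levels
    zero_point = round(-lo / scale)
    q = max(0, min(levels, round(v / scale) + zero_point))
    return scale * (q - zero_point)

random.seed(0)
calib = [random.uniform(0.0, 1.0) for _ in range(1000)]    # hypothetical calibration activations
deploy = [random.uniform(0.0, 4.0) for _ in range(1000)]   # hypothetical deployment activations

lo, hi = min(calib), max(calib)     # the range a min/max observer would record
err_calib = max(abs(v - roundtrip(v, lo, hi)) for v in calib)
err_deploy = max(abs(v - roundtrip(v, lo, hi)) for v in deploy)

print(f"max error on calibration-like inputs: {err_calib:.4f}")
print(f"max error on wider deployment inputs: {err_deploy:.4f}")
```

On inputs that resemble the calibration set the error stays near the half-step bound. On the wider inputs every value above the calibrated maximum is clipped, and the error grows to the size of the activations themselves.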

Library Shortcut: Quantized Export in One Flag

The FX workflow above is the from-scratch path that teaches you what calibration is. In a deployment pipeline built on a model zoo, the same int8 quantization is usually one argument on the export call. Ultralytics, for example, calibrates, quantizes, and exports a YOLO detector to an int8 TensorRT engine (other int8-capable targets such as OpenVINO work the same way) with a single export call:

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
# int8=True triggers calibration on the given dataset and writes a quantized engine.
model.export(format="engine", int8=True, data="coco128.yaml")
# -> yolo11n.engine, an int8 TensorRT engine ready for the GPU runtime of Section 28.2
Code Fragment 3: The same post-training int8 quantization as Code Fragment 2, but in three lines using Ultralytics. The single int8=True flag on model.export triggers calibration over coco128.yaml and emits a ready-to-run yolo11n.engine, folding the observer insertion and per-channel scale selection of the FX workflow into one argument. You still supply the calibration data, because no library can guess what your deployment inputs look like.

The library handles observer insertion, calibration-set iteration, per-channel scale selection, and the runtime-specific packaging, perhaps eighty lines of careful code, behind int8=True. You still owe it a representative calibration dataset; that responsibility never disappears.

3. Pruning: Removing What Does Not Earn Its Place (Intermediate)

Pruning sets some weights to zero permanently. The simplest and most widely used criterion is magnitude pruning: weights with the smallest absolute value are assumed to contribute least to the output, so you remove the bottom $p$ percent by magnitude. The subtlety is the shape the removed weights take, which determines whether you save compute or only memory.

Unstructured pruning zeroes individual weights anywhere in a tensor, producing a sparse matrix. It can reach very high sparsity (90 percent and beyond) with modest accuracy loss, but the zeros are scattered, so a dense matrix multiply still touches them; you save storage (a sparse format) but not wall-clock time unless the hardware has dedicated sparse-matrix support. Structured pruning removes whole units (an entire output channel of a convolution, an entire attention head, an entire neuron), so the resulting tensor is simply smaller and dense. Structured pruning saves real compute on ordinary hardware because the network genuinely has fewer channels to convolve, which is why it is the form that matters most for the latency budgets of this chapter. The trade is that it is coarser and tends to cost more accuracy per parameter removed. Figure 28.1.2 contrasts the two on a single weight matrix.

[Figure 28.1.2 panels. Title: Two kinds of pruning on the same 6x6 weight matrix. Left, unstructured (scattered zeros): same shape; zeros still multiplied; saves memory, not compute. Right, structured (whole rows removed): two rows marked pruned; genuinely smaller dense matrix; saves memory and compute.]
Figure 28.1.2: Unstructured pruning (left) zeroes scattered individual weights; the matrix keeps its shape and ordinary hardware still multiplies the zeros, so only memory is saved. Structured pruning (right) removes whole rows (here, whole output channels); the matrix is genuinely smaller and dense, so both memory and compute drop on any hardware.
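
The contrast in the figure can be reproduced on a tiny matrix in plain Python. This is a minimal sketch: the 4x4 weights and the helper names are hypothetical, rows play the role of output channels, and `dense_macs` counts what an ordinary dense matrix-vector product would execute.

```python
def unstructured_prune(matrix, fraction):
    """Zero the smallest-magnitude `fraction` of entries; the shape is unchanged."""
    flat = sorted(abs(v) for row in matrix for v in row)
    k = int(len(flat) * fraction)
    if k == 0:
        return [row[:] for row in matrix]
    threshold = flat[k - 1]               # ties at the threshold are pruned too
    return [[0.0 if abs(v) <= threshold else v for v in row] for row in matrix]

def structured_prune(matrix, fraction):
    """Drop the rows (output channels) with the smallest L2 norm; the matrix shrinks."""
    norms = [sum(v * v for v in row) ** 0.5 for row in matrix]
    n_drop = int(len(matrix) * fraction)
    by_norm = sorted(range(len(matrix)), key=lambda i: norms[i])
    keep = sorted(by_norm[n_drop:])       # preserve the original row order
    return [matrix[i][:] for i in keep]

def dense_macs(matrix):
    """Multiply-accumulates of a dense matrix-vector product, zeros included."""
    return sum(len(row) for row in matrix)

W = [                                      # hypothetical 4x4 weight matrix
    [0.90, -0.10, 0.40, 0.05],
    [0.02, 0.03, -0.01, 0.04],
    [-0.70, 0.60, 0.20, -0.30],
    [0.08, -0.50, 0.15, 0.80],
]

sparse = unstructured_prune(W, 0.5)
small = structured_prune(W, 0.5)
zeros = sum(v == 0.0 for row in sparse for v in row)

print(f"unstructured: shape {len(sparse)}x{len(sparse[0])}, {zeros} zeros, {dense_macs(sparse)} MACs")
print(f"structured:   shape {len(small)}x{len(small[0])}, {dense_macs(small)} MACs")
```

Both variants remove half of the parameters, but only the structured one changes the shape, and therefore only it changes the work a dense kernel performs.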

The other half of pruning is recovery. Removing weights damages the network, but a short fine-tune (often called the prune-then-finetune or iterative-magnitude-pruning loop) lets the surviving weights compensate, recovering most of the lost accuracy. The lottery-ticket hypothesis of Frankle and Carbin gave this a striking interpretation: a dense network contains a sparse subnetwork that could have been trained to the same accuracy on its own, and pruning-and-finetuning is one way to find it. The code below applies structured channel pruning to a single convolution and marks where the recovery fine-tune would go.

import torch
import torch.nn.utils.prune as prune
from torchvision.models import resnet18, ResNet18_Weights

model = resnet18(weights=ResNet18_Weights.DEFAULT)
conv = model.layer4[1].conv2          # a 512-channel 3x3 convolution

# Structured pruning: remove 30% of OUTPUT channels (dim=0) by L2 norm.
# This is the form that actually shrinks compute, not just storage.
prune.ln_structured(conv, name="weight", amount=0.30, n=2, dim=0)

# Count channels whose entire weight is now zero.
mask = conv.weight_mask                       # prune stores a 0/1 mask
dead = (mask.sum(dim=(1, 2, 3)) == 0).sum().item()
print(f"channels pruned: {dead} of {mask.shape[0]}")   # channels pruned: 154 of 512 (round(0.30 * 512))

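# In a real workflow, fine-tune here while the mask is active, then call
# prune.remove to fold the mask into the weight and drop the reparametrization:
prune.remove(conv, "weight")                  # makes the pruning permanent
print("weight is a plain tensor again, same shape, with the pruned channels zeroed")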
Code Fragment 4: Structured channel pruning of one ResNet-18 convolution with torch.nn.utils.prune. ln_structured with dim=0 ranks output channels by their L2 norm and zeroes the weakest 30 percent; in a real workflow you then fine-tune to recover accuracy before calling prune.remove to make the change permanent. To realize the speedup you must additionally rebuild the layer with the surviving channels, which libraries like Torch-Pruning automate.
Common Misconception: Sparsity Is Not Speed

It is natural to assume that pruning 50 percent of a network's weights makes it run twice as fast. It usually does not. Zeroing scattered weights (unstructured pruning) leaves the tensor the same shape, and ordinary CPU and GPU matrix multiplies still load and multiply those zeros, so you save storage but the wall-clock latency barely moves. A 90-percent-sparse ResNet can run no faster than the dense original on a standard GPU. Speed comes only from structurally removing whole channels and then physically rebuilding the layer as a smaller dense convolution, or from the rare hardware that accelerates a specific sparsity pattern (NVIDIA's 2:4 sparsity in the research frontier below). The diagnostic question to ask yourself: after pruning, did the weight tensor's shape change? If not, you have saved a download, not a millisecond.

Practical Example: The Drone That Could Not Lift Its Model

An agricultural-robotics startup built a weed-detection model for an autonomous sprayer drone. Their segmentation network from Chapter 24, a strong U-Net variant, scored 0.91 mean intersection over union (IoU) in the lab on a workstation GPU. On the drone's onboard module the same model ran at 4 frames per second, far below the 20 the flight controller needed to spray at speed, and the extra power draw cut flight time by a fifth. The decision was not to buy a bigger chip (more weight, less flight time, worse economics) but to compress. They distilled the U-Net into a MobileNetV2-backbone student, structurally pruned 40 percent of the student's channels with a recovery fine-tune, then quantized to int8 for the onboard runtime. The result ran at 27 frames per second at 0.88 mean IoU, a 3-point drop that the agronomists judged invisible in the field because the sprayer's nozzle footprint was coarser than the segmentation error. The lesson: the right move was not better hardware but a model that fit the hardware, and the 3-point accuracy loss was free because the downstream actuator could not resolve it anyway. Always compress against the metric the product actually cares about, not the one on the leaderboard.

4. Distillation: Learning From a Larger Teacher (Intermediate)

Distillation trains a small student network to reproduce the outputs of a large, accurate teacher. The insight, due to Hinton and colleagues, is that the teacher's full output distribution carries more information than the hard label alone. When a good classifier sees an image of a husky, it does not output a one-hot vector; it assigns most probability to husky but meaningful residual probability to wolf and malamute and almost none to teapot. That pattern of relative probabilities, sometimes called dark knowledge, tells the student which classes are confusable, a far richer training signal than the single correct answer (see the illustration below).

A wise owl teacher whispers a husky picture's full ranking of plausible answers (dog, wolf, malamute, faintly teapot) to a small eager student robot, illustrating how distillation transfers the teacher's soft probability distribution rather than just the single correct label.
The teacher does not just say what the answer is; it reveals what the wrong answers nearly were, and that ranking is the dark knowledge a small student cannot learn from labels alone.

We expose this dark knowledge by softening the teacher's logits with a temperature $T$ in the softmax:

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

A temperature $T > 1$ flattens the distribution, amplifying the small probabilities that hold the inter-class structure. The student is trained on a weighted sum of two losses: a distillation loss (the Kullback-Leibler divergence between the student's and teacher's softened distributions) and the ordinary cross-entropy against the true labels. A mixing weight (called $\alpha$ in the code below) sets how much the student should trust the teacher's soft guidance versus the hard ground-truth label, with a value near 0.7 leaning on the teacher while still anchoring to the correct answer. The temperature-scaled term is multiplied by $T^2$ to keep its gradient magnitude comparable to the hard-label term. The same softmax and cross-entropy you met in Chapter 18 reappear here, now comparing two networks instead of a network and a label. The code below implements the loss.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Combine soft-target KL (from the teacher) with hard-label cross-entropy."""
    # Soft targets: both distributions softened by temperature T.
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    soft_student = F.log_softmax(student_logits / T, dim=1)
    # KL divergence, scaled by T^2 so its gradient stays comparable to the hard loss.
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (T * T)
    # Standard supervised loss against the ground-truth labels.
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1.0 - alpha) * ce

# Toy shapes: batch of 16, 1000 classes.
student_logits = torch.randn(16, 1000, requires_grad=True)
teacher_logits = torch.randn(16, 1000)             # teacher in eval mode, no grad
labels = torch.randint(0, 1000, (16,))
loss = distillation_loss(student_logits, teacher_logits, labels)
print(f"distillation loss = {loss.item():.4f}")   # roughly 3 for random logits like these; varies per run
Code Fragment 5: The knowledge-distillation loss: a temperature-softened KL term that pulls the student toward the teacher's full output distribution, plus the usual hard-label cross-entropy, blended by alpha. The T*T factor compensates for the gradient shrinkage that temperature scaling introduces, a detail that is easy to omit and that quietly weakens the soft-target signal if you do.
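
To see what the temperature does to a single distribution, here is a minimal plain-Python sketch of the softened softmax. The four logits are hypothetical values chosen to echo the husky example, and `softmax_t` is an illustrative helper rather than a library function.

```python
import math

def softmax_t(logits, T=1.0):
    """Temperature-scaled softmax: p_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)                               # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

classes = ["husky", "malamute", "wolf", "teapot"]
logits = [9.0, 6.0, 5.0, -2.0]                    # hypothetical teacher logits

p1 = softmax_t(logits, T=1.0)
p4 = softmax_t(logits, T=4.0)

for name, a, b in zip(classes, p1, p4):
    print(f"{name:9s} T=1: {a:.4f}   T=4: {b:.4f}")
```

At $T = 1$ nearly all of the mass sits on husky and the other classes are almost invisible to a loss function. At $T = 4$ the ranking is unchanged but malamute and wolf carry enough probability for the student to learn from.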

Distillation is the most architecturally flexible of the three techniques because the student can be any network you like. It is also how several well-known compact models are made: DistilBERT in language, and in vision the DeiT transformers, which are trained with a dedicated distillation token against a convolutional teacher. It composes cleanly with the other two: distill first to get a small accurate student, then prune and quantize that student. We will see in Chapter 33 that distillation also drives the step-reduction methods that turn fifty-step diffusion samplers into one-or-two-step generators, the same idea applied to the number of inference iterations rather than the number of parameters.

5. Stacking the Techniques (Advanced)

In practice you rarely use one technique alone. The canonical deployment recipe applies all three in the order that respects their dependencies: distill into the target architecture first (so pruning and quantization act on the model you will actually ship), then structurally prune with a recovery fine-tune (so quantization calibrates the already-smaller network), then quantize last (because quantization is the final numeric step before the runtime). The three-word recipe is worth memorizing in exactly this order, distill, prune, quantize, because the order is the logic: each step should act on the model the next step will compress, so you choose the architecture before you thin it, and thin it before you round it. Table 28.1.1 shows representative cumulative effects on an ImageNet classifier; the exact numbers depend on the model and data, but the shape of the trade is stable across vision tasks.

Table 28.1.1: Representative cumulative effect of stacking compression on an ImageNet classifier. Numbers are illustrative of the typical trade, not a single measured run.

| Stage | Size | Relative latency | Top-1 accuracy |
|---|---|---|---|
| Float32 teacher (ResNet-50) | 98 MB | 1.00x | 76.1% |
| Distilled student (MobileNetV2) | 14 MB | 0.28x | 72.3% |
| + Structured pruning (30%, fine-tuned) | 9 MB | 0.20x | 71.6% |
| + Int8 quantization | 2.6 MB | 0.09x | 71.1% |

Read the table as a path down the accuracy-versus-cost curve. The big size and latency wins come early (distillation changes the architecture); the late stages add multiplicative savings for a fraction of a point each. The final model is about 38 times smaller and roughly 11 times faster than the teacher, at a 5-point accuracy cost, and whether that trade is acceptable is a product question, not a research one. As Table 28.1.1 makes clear, you do not have to take the whole path; you stop at the first row that meets your budget. The histogram-and-statistics view of distributions from Chapter 2 is worth recalling here, because choosing a quantization range is exactly choosing where to clip a distribution's tails.
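
The arithmetic behind that reading, and the "stop at the first row that meets your budget" rule, can be sketched in a few lines of plain Python. The rows are copied from Table 28.1.1 (illustrative numbers, as the table caption says); the helper `first_stage_within` and the example budget are hypothetical.

```python
# (stage, size in MB, relative latency, top-1 accuracy in %) from Table 28.1.1
stages = [
    ("Float32 teacher (ResNet-50)", 98.0, 1.00, 76.1),
    ("Distilled student (MobileNetV2)", 14.0, 0.28, 72.3),
    ("+ Structured pruning (30%, fine-tuned)", 9.0, 0.20, 71.6),
    ("+ Int8 quantization", 2.6, 0.09, 71.1),
]

_, base_size, base_lat, base_acc = stages[0]
for name, size, lat, acc in stages:
    print(f"{name:40s} {base_size / size:5.1f}x smaller  "
          f"{base_lat / lat:5.1f}x faster  {base_acc - acc:4.1f} points lost")

def first_stage_within(max_size_mb, max_rel_latency):
    """Return (stage, accuracy) for the first row that fits both budgets, else None."""
    for name, size, lat, acc in stages:
        if size <= max_size_mb and lat <= max_rel_latency:
            return name, acc
    return None

# Hypothetical budget: at most 10 MB and at most a quarter of the teacher's latency.
choice = first_stage_within(max_size_mb=10.0, max_rel_latency=0.25)
print("chosen stage:", choice)
```

Because the rows are ordered from most to least accurate, the first row that fits the budget is also the most accurate one that fits.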

Research Frontier: Sub-4-Bit and Hardware-Native Sparsity (2024-2026)

The frontier has pushed well below int8. Weight-only quantization to 4 bits and below is now standard for large models, with GPTQ (Frantar et al. 2022, arXiv:2210.17323) and AWQ (Lin et al. 2023, arXiv:2306.00978) and their successors making 4-bit inference near-lossless for many networks by quantizing in an error-aware order and protecting salient channels. NVIDIA's Hopper GPUs added native FP8 tensor cores and the Blackwell generation added native FP4 (the NVFP4 format on its fifth-generation tensor cores), so sub-8-bit is now a hardware-accelerated format rather than a software trick, and some of the 2024-2025 generation of efficient vision and vision-language models are trained in FP8 from the start. On the sparsity side, NVIDIA's 2:4 structured sparsity (at least two zeros in every group of four consecutive weights) is the rare unstructured-looking pattern that hardware accelerates directly, and PyTorch's torch.ao and the SparseGPT line of work (Frantar & Alistarh 2023, arXiv:2301.00774) make it practical to reach it without retraining. For the very largest vision-language models of 2025-2026, these techniques are not optional polish; they are the only reason the models fit in deployable memory at all.
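
The 2:4 pattern itself is easy to state in code. The sketch below builds it by plain magnitude selection within each group of four, which is a simple stand-in for illustration and not the selection rule that SparseGPT or any vendor tool uses; the function name and sample weights are hypothetical.

```python
def prune_2_of_4(weights):
    """Zero the two smallest-magnitude weights in every consecutive group of four."""
    if len(weights) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    out = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Indices of the two largest magnitudes in this group survive.
        keep = set(sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2])
        out.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return out

row = [0.9, -0.1, 0.05, -0.6, 0.2, 0.3, -0.25, 0.01]   # hypothetical weight row
sparse_row = prune_2_of_4(row)
print(sparse_row)   # [0.9, 0.0, 0.0, -0.6, 0.0, 0.3, -0.25, 0.0]
```

The result is exactly 50 percent sparse with a fixed budget per group, which is what lets hardware store two values plus a small index per group instead of four values.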

Fun Fact

The "dark knowledge" in distillation has an almost philosophical flavor: the teacher teaches the student not just what the answer is, but what the wrong answers nearly were. A network told only "this is a 7" learns less than one told "this is a 7, but it is the kind of 7 that looks a little like a 1 and nothing at all like an 8." The wrong-but-plausible alternatives encode the geometry of the class space, and that geometry is most of what a large model knows that a small one struggles to discover on its own.

6. Summary and the Road to a Runtime

We have three composable tools. Quantization lowers numeric precision to int8 (or below), cutting memory by a factor of four and unlocking fast integer math, available as cheap post-training quantization or higher-fidelity quantization-aware training. Pruning removes parameters, with structured pruning the variant that actually reduces compute on ordinary hardware. Distillation trains a small student to inherit a large teacher's accuracy through temperature-softened targets. Stacked in the order distill, prune, quantize, they routinely deliver order-of-magnitude savings for a single-digit accuracy cost. But a compressed PyTorch module is still a Python object running eagerly; to actually realize the speed it promises, it must be exported to a compiled runtime built for the target hardware. That is the subject of Section 28.2, where ONNX, TensorRT, and OpenVINO turn the shrunk graph into a deployable engine. Before moving on, put all three compression tools together in the Hands-On Lab at the end of this section, which builds a small compression studio that shrinks a pretrained classifier and reports the accuracy-size-latency trade for each technique.

Exercise 28.1.1: Why the Zero-Point Matters (Conceptual)

Symmetric quantization fixes the zero-point at the midpoint of the integer range and forces the float range to be symmetric around zero, $[-\beta, \beta]$. Asymmetric (affine) quantization lets the zero-point float so the range can be asymmetric. Consider the output of a ReLU, which is non-negative and often has many exact zeros. Explain in two or three sentences why symmetric quantization wastes half its integer codes on such a tensor, and why representing the real value zero exactly (no rounding error at zero) matters when that tensor feeds the next layer. Relate your answer to the role of padding zeros in the convolutions of Chapter 19.

Exercise 28.1.2: Measure the Real Pruning Trade (Coding)

Take a pretrained ResNet-18 and a small labeled validation set (CIFAR-10 or a subset of ImageNet works). Apply prune.ln_structured channel pruning at sparsity levels of 10, 30, 50, and 70 percent, each followed by three epochs of recovery fine-tuning, and record accuracy at every level. Plot accuracy against sparsity. Then, for the 50 percent level, actually rebuild the pruned layers as smaller dense convolutions (or use Torch-Pruning) and measure the real CPU latency before and after. Write one paragraph reconciling the storage saving, the theoretical FLOP saving, and the measured latency: do they agree, and if not, why?

Exercise 28.1.3: Distillation Temperature Sweep (Analysis)

Using the distillation_loss function from subsection 4, train a small student (a 4-layer CNN) to imitate a ResNet-18 teacher on CIFAR-10, sweeping the temperature $T$ over $\{1, 2, 4, 8, 16\}$ at fixed $\alpha = 0.7$. Record final student accuracy for each. Explain the shape of the resulting curve: why very low temperature (the soft targets approach hard labels) and very high temperature (the soft targets approach a uniform distribution) both tend to hurt, and where the useful middle lies. Connect the analysis to what the softened distribution is actually communicating about the class geometry.

Hands-On Lab: A Model Compression Studio

Duration: about 75 minutes
Difficulty: Intermediate

Objective

Build a small command-line compression studio that takes a pretrained image classifier and reports, in one table, the accuracy, on-disk size, and CPU latency of the float32 baseline against three compressed variants: a distilled student, a structured-pruned model, and an int8 dynamically quantized model. The artifact is a single reusable script that turns the accuracy-size-latency trade of Table 28.1.1 from theory into a measurement you ran yourself.

What You'll Practice

  • Post-training int8 quantization with the torch.ao quantization API.
  • Structured channel pruning with torch.nn.utils.prune and brief recovery fine-tuning.
  • Knowledge distillation with the temperature-scaled soft-target loss from subsection 4.
  • Measuring real CPU latency and model size, not just FLOP estimates.
  • Reading the accuracy-versus-cost curve to pick the variant that meets a budget.

Setup

A CPU is enough; no GPU required. Install the dependencies and let torchvision cache CIFAR-10 (about 170 MB) on first run.

pip install "torch>=2.2" "torchvision>=0.17"

Steps

Step 1: Load data and a pretrained teacher

Set up a CIFAR-10 train and test loader, and load a ResNet-18 with its final layer resized to ten classes. Fine-tune the teacher for two quick epochs so it is a reasonable model to compress and to distill from.

import torch, torch.nn as nn, torchvision as tv
from torchvision.transforms import v2

device = "cuda" if torch.cuda.is_available() else "cpu"
tf = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
                 v2.Normalize([0.49, 0.48, 0.45], [0.25, 0.24, 0.26])])
train = tv.datasets.CIFAR10("./data", train=True,  download=True, transform=tf)
test  = tv.datasets.CIFAR10("./data", train=False, download=True, transform=tf)
train_dl = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=2)
test_dl  = torch.utils.data.DataLoader(test,  batch_size=256, num_workers=2)

teacher = tv.models.resnet18(weights="IMAGENET1K_V1")
teacher.fc = nn.Linear(teacher.fc.in_features, 10)
# TODO: write train_model(model, loader, epochs) and evaluate(model, loader) -> accuracy.
#       (Call it train_model, not train: the name `train` already holds the dataset above.)
# Hint: standard cross-entropy loop; evaluate runs under torch.no_grad() on the model's own device.
teacher = teacher.to(device)
# TODO: train_model(teacher, train_dl, epochs=2)
Hint

Keep the loops minimal: an Adam optimizer at lr=1e-3, cross-entropy loss, and a counter of correct predictions over the test set. Two epochs are enough to reach the high seventies; the lab is about the compression delta, not the absolute number.

Step 2: Write the measurement harness

Every variant must be judged on the same three axes. Write helpers that report model size on disk and median CPU latency for a single image, so each later step produces one comparable row.

import os, time, copy, statistics

def size_mb(model):
    torch.save(model.state_dict(), "tmp.pt")
    mb = os.path.getsize("tmp.pt") / 1e6
    os.remove("tmp.pt"); return mb

def cpu_latency_ms(model, runs=50):
    model = model.to("cpu").eval()
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        for _ in range(5): model(x)            # warm up
        # TODO: time `runs` forward passes and return the MEDIAN in milliseconds.
        # Hint: collect time.perf_counter() deltas into a list, use statistics.median.
Hint

Use the median rather than the mean: one scheduling hiccup can double a single measurement, and the median is robust to that. Always measure latency on CPU with batch size one so the int8 model (CPU-only in PyTorch eager mode) is comparable to the others.
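The robustness claim is easy to see on synthetic numbers. The timings below are invented for illustration: 49 clean runs at 6 ms and a single 60 ms hiccup.

```python
import statistics

# Synthetic timings in milliseconds: 49 clean runs plus one scheduling hiccup.
samples_ms = [6.0] * 49 + [60.0]

mean_ms = statistics.mean(samples_ms)
median_ms = statistics.median(samples_ms)

print(f"mean   = {mean_ms:.2f} ms")    # pulled upward by the single outlier
print(f"median = {median_ms:.2f} ms")  # unaffected by it
```

One bad sample out of fifty moves the mean by more than a millisecond and leaves the median where the other forty-nine runs put it.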

Step 3: Distill a MobileNetV2 student

Train a smaller MobileNetV2 to imitate the teacher using the temperature-scaled soft-target loss. This is the architecture change that buys the biggest single jump down the cost curve.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean") * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

student = tv.models.mobilenet_v2(weights=None, num_classes=10).to(device)
teacher.eval()
# TODO: train `student` for ~3 epochs. For each batch, get teacher_logits under
#       torch.no_grad(), then backprop distillation_loss(student(x), teacher_logits, y).
Hint

Put the teacher in inference mode with teacher.eval(), so its batch-norm statistics stay fixed, and wrap its forward pass in torch.no_grad() so no gradients flow into it. Only the student's parameters go to the optimizer. The student trains normally; only the loss function changed. Reuse the temperature and alpha from Exercise 28.1.3 if you already swept them.
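To see what each term of the loss contributes, here is the same formula worked for a single example in plain Python. It is a sketch that mirrors distillation_loss for a batch of one; the logits are hypothetical and the function name distillation_loss_one is introduced only for this illustration.

```python
import math

def log_softmax(logits, T=1.0):
    """Log-softmax of logits / T, computed with the log-sum-exp trick."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]

def distillation_loss_one(student, teacher, label, T=4.0, alpha=0.7):
    """Return (total, soft, hard) for one example."""
    log_p_s = log_softmax(student, T)
    log_p_t = log_softmax(teacher, T)
    # KL(teacher || student) on the temperature-softened distributions.
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    soft = kl * T * T                      # T^2 rescales the softened term
    hard = -log_softmax(student)[label]    # ordinary cross-entropy at T = 1
    return alpha * soft + (1 - alpha) * hard, soft, hard

teacher_logits = [9.0, 5.0, 4.0]   # hypothetical
student_logits = [1.0, 2.0, 0.5]   # hypothetical, disagrees with the teacher
total, soft, hard = distillation_loss_one(student_logits, teacher_logits, label=0)
print(f"soft={soft:.3f}  hard={hard:.3f}  total={total:.3f}")

# A student that already matches the teacher pays no soft penalty.
print(distillation_loss_one(teacher_logits, teacher_logits, label=0))
```

When the student's logits equal the teacher's, the soft term vanishes and only the weighted cross-entropy remains, which is a handy sanity check for your own training loop.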

Step 4: Structured-prune the student

Apply channel-level (structured) pruning to the student's convolutions, then fine-tune briefly to recover the lost accuracy. Structured pruning, unlike unstructured masking, is the variant that can actually shrink compute on ordinary hardware.

import torch.nn.utils.prune as prune

pruned = copy.deepcopy(student)
for module in pruned.modules():
    if isinstance(module, nn.Conv2d) and module.out_channels > 8:
        # TODO: prune 30% of output channels by L2 norm.
        # Hint: prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)
        pass
# TODO: fine-tune `pruned` for ~2 epochs with the masks still attached, so the
#       zeroed channels stay zero while the surviving weights recover.
# TODO: only then make the pruning permanent with prune.remove(module, "weight")
#       on every convolution you pruned.
Hint

Pruning with the prune API zeroes channels but keeps the tensor shape, so on-disk size barely moves until you rebuild the layers as smaller dense convolutions (or use the Torch-Pruning library, see the stretch goals). Keep the masks attached while you fine-tune and call prune.remove only afterwards: once the mask is gone, gradient updates are free to regrow the zeroed channels. For this lab, measure the accuracy recovery and note that the latency win needs the dense rebuild; that gap is the lesson.
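The difference between masking and rebuilding can be sketched without any framework. The example below assumes a hypothetical layer with ten output channels of three weights each, ranks channels by L2 norm, and prunes the weakest 30 percent. The helper names are illustrative and are not part of torch.nn.utils.prune.

```python
import math

def channel_l2(filters):
    """L2 norm of each output channel; `filters` is a list of flat weight lists."""
    return [math.sqrt(sum(w * w for w in f)) for f in filters]

def structured_mask(filters, amount):
    """1 keeps a channel, 0 prunes it; the lowest-norm channels are pruned."""
    n_prune = round(amount * len(filters))
    norms = channel_l2(filters)
    weakest = sorted(range(len(filters)), key=lambda i: norms[i])[:n_prune]
    return [0 if i in weakest else 1 for i in range(len(filters))]

# Hypothetical layer: 10 output channels, 3 weights each, norms growing with the index.
filters = [[0.1 * (i + 1), -0.1 * (i + 1), 0.05 * (i + 1)] for i in range(10)]
mask = structured_mask(filters, amount=0.3)

# Masking: same shape, pruned channels are all zeros (what the prune API leaves you with).
masked = [[w * m for w in f] for f, m in zip(filters, mask)]
# Dense rebuild: pruned channels are physically dropped.
rebuilt = [f for f, m in zip(filters, mask) if m]

print("mask:            ", mask)
print("masked channels: ", len(masked))
print("rebuilt channels:", len(rebuilt))
```

In a real network the rebuild also has to trim the matching input channels of the next layer and the batch-norm parameters in between, which is the bookkeeping a dedicated library handles for you.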

Step 5: Dynamically quantize to int8

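Apply post-training dynamic quantization to the distilled student. This is the cheapest technique: no retraining and one function call. In PyTorch eager mode it converts the weights of nn.Linear (and recurrent) layers to int8 for CPU inference; nn.Conv2d modules stay in float even if you list them, so on a convolution-heavy student expect the size to move less than the four-fold ideal.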

from torch.ao.quantization import quantize_dynamic

q_student = copy.deepcopy(student).to("cpu").eval()
# TODO: quantize q_student with quantize_dynamic, targeting {nn.Linear, nn.Conv2d},
#       dtype=torch.qint8. Compare size_mb(q_student) with the float student and
#       print(q_student) to see which layers were actually converted.
Hint

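Dynamic quantization is quantize_dynamic(model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8). It quantizes weights ahead of time and activations on the fly, so it needs no calibration data, which is exactly why it is the right first thing to try. The default eager-mode mapping only swaps nn.Linear and recurrent modules, so the nn.Conv2d entry is harmless but has no effect; print the returned model to see what changed. For full static int8 with calibration, which does cover convolutions, see the stretch goals.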
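How much a quantization call saves depends on which layers it actually converts. The arithmetic below is a rough sketch: the parameter counts are illustrative round figures for a convolution-heavy classifier with a small linear head, not measurements, and it ignores batch-norm parameters, scales, zero-points, and serialization overhead.

```python
def state_mb(param_counts, int8_kinds=()):
    """Approximate weight storage in MB: 1 byte per int8 weight, 4 per float32 weight."""
    total_bytes = sum(n * (1 if kind in int8_kinds else 4)
                      for kind, n in param_counts.items())
    return total_bytes / 1e6

# Illustrative counts only (assumed, not measured).
params = {"conv": 2_200_000, "linear": 12_810}

float_mb = state_mb(params)
linear_only_mb = state_mb(params, {"linear"})
all_int8_mb = state_mb(params, {"conv", "linear"})

print(f"all float32:        {float_mb:.2f} MB")
print(f"linear layers int8: {linear_only_mb:.2f} MB")
print(f"conv + linear int8: {all_int8_mb:.2f} MB")
```

If your measured size looks like the middle line instead of the last one, the convolutions were not converted, and the static path is the next thing to try.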

Step 6: Assemble the comparison table

Run every variant through the same accuracy, size, and latency harness and print one table. This is the deliverable: your own version of Table 28.1.1, measured rather than quoted.

rows = []
for name, model in [("Float32 teacher", teacher),
                    ("Distilled student", student),
                    ("+ Pruned (30%)", pruned),
                    ("+ Int8 quantized", q_student)]:
    # TODO: append (name, evaluate(model, test_dl), size_mb(model), cpu_latency_ms(model)).
    pass

print(f"{'Variant':22} {'Acc%':>6} {'MB':>7} {'ms':>7}")
for n, acc, mb, ms in rows:
    print(f"{n:22} {acc*100:6.1f} {mb:7.1f} {ms:7.2f}")
Hint

Evaluate the int8 model on CPU; the others can evaluate on whatever device trained them but must be measured for latency on CPU for a fair comparison. If a row's size or latency does not move the way the chapter predicts, that is a finding worth a sentence, not a bug to hide.

Expected Output

A four-row table that traces the same downward path as Table 28.1.1. Exact numbers depend on your training budget, but the shape is consistent: the distilled student is several times smaller and faster than the teacher for a few points of accuracy, pruning trims a little more accuracy (and, once the layers are rebuilt dense, latency), and int8 quantization cuts the storage of every layer it converts to roughly a quarter with a sub-point accuracy change. A representative run looks like this.

Variant                  Acc%      MB      ms
Float32 teacher          78.4    44.8   18.30
Distilled student        75.1     8.8    6.10
+ Pruned (30%)           73.9     8.7    5.95
+ Int8 quantized         74.8     2.4    4.40

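The pruned row's size and latency barely move because the prune API keeps tensor shapes; closing that gap is the first stretch goal, and noticing it is the point of Step 4's hint. If your int8 row shrinks much less than the one above, print q_student and check which modules were replaced: in eager mode, dynamic quantization swaps nn.Linear layers and leaves nn.Conv2d in float, so convolution weights only reach int8 through the static path in the stretch goals.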
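Reading the table against a budget can be made mechanical. The sketch below takes the four rows of the representative run above and returns the most accurate variant that fits a size or latency limit; the function best_under_budget and the example budgets are illustrative.

```python
# (variant, accuracy %, size MB, latency ms) from the representative run above.
rows = [
    ("Float32 teacher",   78.4, 44.8, 18.30),
    ("Distilled student", 75.1,  8.8,  6.10),
    ("+ Pruned (30%)",    73.9,  8.7,  5.95),
    ("+ Int8 quantized",  74.8,  2.4,  4.40),
]

def best_under_budget(rows, max_mb=None, max_ms=None):
    """Most accurate row meeting every given limit, or None if nothing fits."""
    fits = [r for r in rows
            if (max_mb is None or r[2] <= max_mb)
            and (max_ms is None or r[3] <= max_ms)]
    return max(fits, key=lambda r: r[1]) if fits else None

print(best_under_budget(rows, max_ms=10.0))   # latency budget of 10 ms
print(best_under_budget(rows, max_mb=5.0))    # storage budget of 5 MB
print(best_under_budget(rows, max_ms=1.0))    # nothing is this fast
```

Swap in the rows from your own run; the answer to "which variant should I ship" changes with the budget, not with the technique you like best.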

Stretch Goals

  • Replace the masking-only pruning with the Torch-Pruning library so the pruned layers become genuinely smaller dense convolutions, then re-measure: the latency row should finally move, reconciling the storage, FLOP, and wall-clock numbers the way Exercise 28.1.2 asks.
  • Swap dynamic quantization for full static int8 with calibration (torch.ao.quantization.prepare then convert after running a few hundred calibration images), and compare its accuracy against the dynamic variant.
  • Add a Library Shortcut column: export the distilled student to ONNX (for example with torch.onnx.export) and run it through ONNX Runtime (the topic of Section 28.2), then measure that latency against the eager PyTorch numbers to preview what a compiled runtime adds.
Complete Solution
import os, time, copy, statistics
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision as tv
import torch.nn.utils.prune as prune
from torchvision.transforms import v2
from torch.ao.quantization import quantize_dynamic

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Step 1: data and teacher ---
tf = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
                 v2.Normalize([0.49, 0.48, 0.45], [0.25, 0.24, 0.26])])
train = tv.datasets.CIFAR10("./data", train=True,  download=True, transform=tf)
test  = tv.datasets.CIFAR10("./data", train=False, download=True, transform=tf)
train_dl = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=2)
test_dl  = torch.utils.data.DataLoader(test,  batch_size=256, num_workers=2)

def train_model(model, loader, epochs=2, loss_fn=None):
    model.train().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            out = model(x)
            loss = loss_fn(out, y) if loss_fn else F.cross_entropy(out, y)
            loss.backward(); opt.step()
    return model

def evaluate(model, loader):
    model.eval(); dev = next(model.parameters()).device
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            pred = model(x.to(dev)).argmax(1).cpu()
            correct += (pred == y).sum().item(); total += y.numel()
    return correct / total

teacher = tv.models.resnet18(weights="IMAGENET1K_V1")
teacher.fc = nn.Linear(teacher.fc.in_features, 10)
train_model(teacher, train_dl, epochs=2)

# --- Step 2: measurement harness ---
def size_mb(model):
    torch.save(model.state_dict(), "tmp.pt")
    mb = os.path.getsize("tmp.pt") / 1e6
    os.remove("tmp.pt"); return mb

def cpu_latency_ms(model, runs=50):
    model = model.to("cpu").eval()
    x = torch.randn(1, 3, 32, 32)
    times = []
    with torch.no_grad():
        for _ in range(5): model(x)
        for _ in range(runs):
            t0 = time.perf_counter(); model(x)
            times.append((time.perf_counter() - t0) * 1000)
    return statistics.median(times)

# --- Step 3: distillation ---
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean") * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

student = tv.models.mobilenet_v2(weights=None, num_classes=10).to(device)
teacher.eval()
opt = torch.optim.Adam(student.parameters(), lr=1e-3)
for _ in range(3):
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            t_logits = teacher(x)
        opt.zero_grad()
        loss = distillation_loss(student(x), t_logits, y, T=4.0, alpha=0.7)
        loss.backward(); opt.step()

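# --- Step 4: structured pruning + recovery ---
pruned = copy.deepcopy(student)
pruned_convs = []
for module in pruned.modules():
    if isinstance(module, nn.Conv2d) and module.out_channels > 8:
        prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)
        pruned_convs.append(module)
# Fine-tune with the masks still attached so the pruned channels stay zero.
train_model(pruned, train_dl, epochs=2)
# Only now fold each mask into its weight and drop the reparametrization.
for module in pruned_convs:
    prune.remove(module, "weight")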

# --- Step 5: dynamic int8 quantization ---
q_student = copy.deepcopy(student).to("cpu").eval()
# The default eager-mode mapping swaps nn.Linear; the nn.Conv2d entry is accepted but stays float.
q_student = quantize_dynamic(q_student, {nn.Linear, nn.Conv2d}, dtype=torch.qint8)

# --- Step 6: comparison table ---
rows = []
for name, model in [("Float32 teacher", teacher),
                    ("Distilled student", student),
                    ("+ Pruned (30%)", pruned),
                    ("+ Int8 quantized", q_student)]:
    rows.append((name, evaluate(model, test_dl), size_mb(model), cpu_latency_ms(model)))

print(f"{'Variant':22} {'Acc%':>6} {'MB':>7} {'ms':>7}")
for n, acc, mb, ms in rows:
    print(f"{n:22} {acc*100:6.1f} {mb:7.1f} {ms:7.2f}")