
Transfer Learning (4): Few-Shot Learning
Learn new concepts from a handful of examples. Covers the N-way K-shot protocol, metric learning (Siamese, Prototypical, Matching, Relation networks), meta-learning (MAML, Reptile), episodic training, miniImageNet benchmarks, and a complete Prototypical Network implementation.
Show a child one photograph of a pangolin and they will spot pangolins for life. Show a deep learning model one photograph and it will give you a uniformly random guess. Few-shot learning is the field that closes that gap: building classifiers that work with only one to ten labeled examples per class.
The trick is not to memorize individual classes harder. It is to learn how to learn from very few examples, then carry that ability over to brand-new classes at test time. This article covers the two families that dominate the field today: metric learning, which learns a good distance function, and meta-learning, which learns a good initialization.
What You Will Learn#
- The N-way K-shot evaluation protocol and why standard training fails on it
- Metric learning: Siamese, Prototypical, Matching, and Relation networks
- Meta-learning: MAML and its first-order cousins (FOMAML, Reptile)
- Episodic training: matching training-time difficulty to test-time difficulty
- A clean, end-to-end Prototypical Network implementation in PyTorch
Prerequisites: Parts 1-2 of this series; comfort with PyTorch and basic optimization.
The Few-Shot Challenge#

Problem Setup: N-way K-shot#
The community uses a single, shared evaluation protocol so that papers are comparable:
- N-way: the model must classify among $N$ classes.
- K-shot: for each class, only $K$ labeled examples are available.
A “5-way 1-shot” task is therefore: here is one labeled image from each of five classes you have never seen before; now classify a batch of new query images.
Each evaluation episode consists of:
- a support set $\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{NK}$ — the $N \times K$ labeled examples,
- a query set $\mathcal{Q} = \{(x_j, y_j)\}_{j=1}^{NQ}$, with $Q$ further images per class to classify (their labels are hidden from the model and used only to measure accuracy).
Reported numbers are averages over hundreds or thousands of episodes drawn from a held-out novel-class set, with 95% confidence intervals because the per-episode variance is large.
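The interval itself is a one-line formula. Here is a minimal sketch that assumes episodes are independent and uses the normal approximation (mean ± 1.96 standard errors). The helper name mean_and_ci95 is hypothetical:

```python
import math
import statistics


def mean_and_ci95(episode_accuracies: list[float]) -> tuple[float, float]:
    """Mean accuracy and the half-width of a normal-approximation 95% CI.

    Assumes i.i.d. episodes, so the standard error is sample_std / sqrt(n);
    1.96 is the two-sided 95% z-value.
    """
    n = len(episode_accuracies)
    if n < 2:
        raise ValueError("need at least two episodes to estimate a CI")
    mean = statistics.fmean(episode_accuracies)
    std = statistics.stdev(episode_accuracies)  # sample std, n - 1 denominator
    return mean, 1.96 * std / math.sqrt(n)


# Toy example: 600 episodes whose accuracies alternate between 40% and 60%.
mean, half_width = mean_and_ci95([0.4, 0.6] * 300)
print(f"{100 * mean:.2f} +/- {100 * half_width:.2f}%")
```

Because the half-width shrinks only as $1/\sqrt{n}$, halving it takes four times as many episodes.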
Why Standard Training Fails#
Three forces conspire against a vanilla classifier:
- Data scarcity. With $K = 1$ you literally cannot estimate a within-class variance. With $K = 5$ you can, but barely.
- Overfitting. A high-capacity network will memorize the support examples instead of learning a class-discriminative rule.
- Inter-class similarity. Novel classes drawn from the same domain (e.g. two breeds of dog) often differ only in subtle features that a randomly initialized classifier has no reason to attend to.
Empirical risk minimization with weight decay is not enough: regularization stops parameters from blowing up, but it does not inject the inductive bias required to generalize from a single image.
The Core Insight#
To learn from few samples you need prior knowledge. Few-shot learning gets that prior by training on a large set of base classes (with many examples each), then evaluating on disjoint novel classes (with few). The two main routes are:
- Metric learning — train a backbone whose embedding space already separates classes, so a fresh class can be characterized by the location of its few support points. Classify queries by their distance in this space.
- Meta-learning — train across many simulated few-shot tasks so the network learns to be adapted by a few gradient steps. Treat “fast adaptation” itself as the thing to optimize.
Both share the same data split (base vs. novel) but invest the prior knowledge differently: metric learning bakes it into the embedding; meta-learning bakes it into the optimization initialization.
Metric Learning: Classification by Distance#

The metric-learning recipe is one sentence long: learn an embedding $f_\theta$ such that same-class samples cluster together and different-class samples lie far apart, then classify a query by its proximity to the support points.
Siamese Networks#
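A Siamese network passes both inputs through the same encoder $f_\theta$ and compares the embeddings by their distance:
$$d(x_1, x_2) = \|f_\theta(x_1) - f_\theta(x_2)\|_2.$$
It is commonly trained with the contrastive loss
$$\mathcal{L} = y \cdot d^2 + (1 - y) \cdot \max(0, m - d)^2,$$
with $y = 1$ for same-class pairs (pull together) and $y = 0$ for different-class pairs (push apart until the distance exceeds margin $m$). At test time, classify a query by the label of its nearest support sample.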
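Here is a minimal PyTorch sketch of the contrastive objective and the nearest-support rule. It assumes a generic shared encoder. The function names, the margin default of 1.0, and the small MLP stand-in for $f_\theta$ are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrastive_loss(z1, z2, same, margin=1.0):
    """z1, z2: (B, D) embeddings from the shared encoder; same: (B,) 1.0 or 0.0."""
    d = F.pairwise_distance(z1, z2)                # Euclidean distance per pair
    pull = same * d.pow(2)                         # same class: pull together
    push = (1 - same) * F.relu(margin - d).pow(2)  # different class: push past the margin
    return (pull + push).mean()


def nearest_support_label(query_emb, support_emb, support_labels):
    """Test-time rule: each query takes the label of its closest support embedding."""
    dists = torch.cdist(query_emb, support_emb)    # (Q, S) Euclidean distances
    return support_labels[dists.argmin(dim=1)]


# One training step on random pairs, with a small MLP standing in for f_theta.
torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x1, x2 = torch.randn(32, 8), torch.randn(32, 8)
same = torch.randint(0, 2, (32,)).float()
loss = contrastive_loss(encoder(x1), encoder(x2), same)
loss.backward()
```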
Prototypical Networks#
Prototypical networks improve on the pairwise picture by collapsing each support class into a single point.
Computing prototypes#
The prototype of class $c$ is the mean embedding of its $K$ support examples:
$$\mathbf{c}_c = \frac{1}{K} \sum_{k=1}^{K} f_\theta(x_k^c).$$
Geometrically it is the centroid of the class in embedding space.
Classification#
A query is then classified by a softmax over its negative squared distances to the prototypes:
$$P(y = c \mid x_q) = \frac{\exp\bigl(-d(f_\theta(x_q), \mathbf{c}_c)\bigr)}{\sum_{c'} \exp\bigl(-d(f_\theta(x_q), \mathbf{c}_{c'})\bigr)}, \qquad d(u, v) = \|u - v\|_2^2.$$
Train end-to-end with cross-entropy on the query predictions of each episode.
Why prototypes are principled#
If we model class-conditional embeddings as Gaussians with shared isotropic covariance, $P(x \mid y = c) = \mathcal{N}(\mu_c, \sigma^2 I)$, and give every class the same prior, then the most probable class is exactly the nearest centroid. Prototypical networks are therefore the deep-learning incarnation of the Bayes-optimal classifier under that (admittedly strong) assumption, which helps explain why they work so well in practice.
A second, cleaner observation: under squared-Euclidean distance the decision boundary between any two classes is a hyperplane in embedding space, because expanding the square gives
$$-\|u - \mathbf{c}_c\|_2^2 = 2\,\mathbf{c}_c^\top u - \|\mathbf{c}_c\|_2^2 - \|u\|_2^2,$$
and the last term is the same for every class. So Prototypical networks are equivalent to a linear classifier in the learned space, with weights $w_c = 2\mathbf{c}_c$ and biases $b_c = -\|\mathbf{c}_c\|_2^2$ tied to the prototype geometry.
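Here is a quick numerical check that uses random stand-ins for the prototypes and query embeddings. A linear layer with weights $w_c = 2\mathbf{c}_c$ and biases $b_c = -\|\mathbf{c}_c\|^2$ differs from the negative squared distance only by $-\|u\|^2$. That term is a per-query constant, so the softmax ignores it:

```python
import torch

torch.manual_seed(0)
N, D = 5, 16
protos = torch.randn(N, D)    # prototypes c_c (random stand-ins)
queries = torch.randn(10, D)  # query embeddings u = f_theta(x_q)

# Prototype logits: negative squared Euclidean distance to each prototype.
proto_logits = -((queries[:, None, :] - protos[None, :, :]) ** 2).sum(-1)  # (10, N)

# The same rule written as a linear layer: w_c = 2 c_c, b_c = -||c_c||^2.
W, b = 2 * protos, -(protos ** 2).sum(-1)
linear_logits = queries @ W.T + b  # (10, N)

# The two differ by -||u||^2, identical across classes for a given query,
# so the softmax probabilities (and therefore the decisions) coincide.
shift = proto_logits - linear_logits
expected_shift = -(queries ** 2).sum(-1, keepdim=True).expand_as(shift)
same_probs = torch.allclose(proto_logits.softmax(-1), linear_logits.softmax(-1), atol=1e-4)
print(torch.allclose(shift, expected_shift, atol=1e-4), same_probs)
```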
Matching Networks#
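Matching networks replace the hard nearest-prototype rule with a soft attention over the entire support set. The prediction for a query $x_q$ is an attention-weighted vote over the support labels, with attention given by a softmax over cosine similarities:
$$\hat{y}_q = \sum_{i=1}^{NK} a(x_q, x_i)\, y_i, \qquad a(x_q, x_i) = \frac{\exp\bigl(\cos(f_\theta(x_q), f_\theta(x_i))\bigr)}{\sum_{j=1}^{NK} \exp\bigl(\cos(f_\theta(x_q), f_\theta(x_j))\bigr)}.$$
Here $y_i$ is a one-hot label vector, so the prediction is a convex combination of one-hots weighted by attention.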
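The attention rule translates almost line for line into PyTorch. This is a minimal sketch without full context embeddings: one shared encoder is assumed, so support and query embeddings come from the same $f_\theta$. The name matching_predict is hypothetical:

```python
import torch
import torch.nn.functional as F


def matching_predict(query_emb, support_emb, support_labels, n_way):
    """Attention-weighted sum of one-hot support labels.

    query_emb: (Q, D); support_emb: (S, D); support_labels: (S,) ints in [0, n_way).
    Returns (Q, n_way) class probabilities; each row sums to 1.
    """
    # Cosine similarity between every query and every support embedding.
    sims = F.normalize(query_emb, dim=1) @ F.normalize(support_emb, dim=1).T  # (Q, S)
    attn = sims.softmax(dim=1)                          # a(x_q, x_i), rows sum to 1
    one_hot = F.one_hot(support_labels, n_way).float()  # (S, n_way)
    return attn @ one_hot                               # convex combination of one-hots


# A 5-way 2-shot episode with random embeddings and 6 queries.
torch.manual_seed(0)
probs = matching_predict(torch.randn(6, 8), torch.randn(10, 8), torch.arange(5).repeat_interleave(2), n_way=5)
```

Training typically applies a negative log-likelihood to these probabilities, for example F.nll_loss(torch.log(probs + 1e-8), query_labels).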
The other contribution of the paper is full context embeddings: a bidirectional LSTM is run over the support set so each support embedding is aware of every other support sample. The intuition is that what counts as a discriminative feature depends on the other classes you are trying to separate from — and the LSTM lets the network express that.
Relation Networks#
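Relation networks go one step further and learn the metric itself. A small relation module $g_\phi$ takes a query embedding concatenated with a class representation and outputs a similarity score:
$$r_{q, c} = g_\phi\bigl(\mathrm{concat}(f_\theta(x_q),\, \mathbf{c}_c)\bigr) \in [0, 1].$$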
The training target is $r_{q, c} = \mathbb{1}\{y_q = c\}$ with mean-squared-error loss; both modules are trained jointly. Why bother? Fixed metrics implicitly assume the embedding space is isotropic — every dimension counts equally. A learned metric lets the network downweight dimensions that turn out to be uninformative for the task at hand.
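Here is a minimal sketch of the relation module and its MSE objective. It assumes vector embeddings. The original paper instead concatenates convolutional feature maps and uses a small CNN as $g_\phi$, so this MLP version is a simplification. RelationModule and relation_loss are illustrative names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationModule(nn.Module):
    """g_phi: scores each (query embedding, class representation) pair in [0, 1]."""

    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * dim, hidden), nn.ReLU(), nn.Linear(hidden, 1), nn.Sigmoid()
        )

    def forward(self, q_emb, protos):
        # q_emb: (Q, D), protos: (N, D) -> relation scores (Q, N)
        Q, N = q_emb.size(0), protos.size(0)
        pairs = torch.cat(
            [q_emb.unsqueeze(1).expand(Q, N, -1), protos.unsqueeze(0).expand(Q, N, -1)],
            dim=-1,
        )  # (Q, N, 2D): every query paired with every class
        return self.net(pairs).squeeze(-1)


def relation_loss(scores, query_labels):
    # Target r_{q,c} = 1 if y_q == c else 0, regressed with mean-squared error.
    targets = F.one_hot(query_labels, scores.size(1)).float()
    return F.mse_loss(scores, targets)
```

In a full model, the encoder $f_\theta$ that produces q_emb and protos would be optimized jointly with the relation module through the same loss.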
Meta-Learning: Learning to Learn#
Where metric learning bakes the prior into the embedding, meta-learning bakes it into the optimization process itself. The model is trained across many tasks so that adapting it to a new task takes only a handful of gradient steps.
MAML: Model-Agnostic Meta-Learning#
MAML’s idea is simple and surprisingly effective: search for an initialization $\theta$ such that one or two gradient steps on any new task’s support set already produce a good model.

Algorithm#
For each sampled task $\mathcal{T}_i$ (with its own support and query sets):
1. Inner loop (per-task adaptation). Take one (or a few) gradient steps on the support loss:
$$ \theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}^{\text{support}}(\theta). $$
2. Outer loop (meta-update). Evaluate the adapted parameters on the query set and update the initialization:
$$ \theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{\mathcal{T}_i}^{\text{query}}(\theta_i'). $$
The outer-loop gradient differentiates through the inner-loop update, which involves second derivatives of the support loss with respect to $\theta$ (a Hessian-vector product).
First-order approximation (FOMAML)#
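FOMAML drops the second-order term. For a single inner step the exact meta-gradient is
$$\nabla_\theta \mathcal{L}_{\mathcal{T}_i}^{\text{query}}(\theta_i') = \bigl(I - \alpha \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}^{\text{support}}(\theta)\bigr)\, \nabla_{\theta_i'} \mathcal{L}_{\mathcal{T}_i}^{\text{query}}(\theta_i'),$$
and FOMAML replaces the Jacobian factor with the identity:
$$\nabla_\theta \mathcal{L}(\theta_i') \approx \nabla_{\theta_i'} \mathcal{L}(\theta_i'),$$
which is just the gradient at the adapted point, evaluated as if $\theta_i'$ did not depend on $\theta$. No second derivatives are needed and the inner loop does not have to stay in the autograd graph, so the cost per step is $O(d)$ for $d$ parameters, and reported accuracies barely change.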
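A scalar toy shows what FOMAML throws away. Take a quadratic support loss $\frac{a}{2}(\theta - s)^2$ and a quadratic query loss $\frac{b}{2}(\theta - q)^2$, with all constants hypothetical. One inner step gives $\theta' = \theta - \alpha a(\theta - s)$. The exact meta-gradient is therefore $(1 - \alpha a)$ times the first-order gradient $b(\theta' - q)$. Autograd reproduces this when the inner gradient is built with create_graph=True:

```python
import torch

# Hypothetical constants: support loss a/2 (theta - s)^2, query loss b/2 (theta - q)^2.
a, b, s, q, alpha = 3.0, 1.0, 2.0, -1.0, 0.1
theta = torch.tensor(0.5, requires_grad=True)

support_loss = 0.5 * a * (theta - s) ** 2
(g,) = torch.autograd.grad(support_loss, theta, create_graph=True)  # keep g differentiable
theta_prime = theta - alpha * g                                     # one inner step
query_loss = 0.5 * b * (theta_prime - q) ** 2

(maml_grad,) = torch.autograd.grad(query_loss, theta)  # exact: differentiates through g
fomaml_grad = b * (theta_prime.detach() - q)           # gradient at theta', taken as constant
# The first and last values agree up to float rounding: maml = (1 - alpha * a) * fomaml.
print(maml_grad.item(), fomaml_grad.item(), (1 - alpha * a) * fomaml_grad.item())
```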
Geometric intuition#
MAML pushes $\theta$ toward a region of the loss landscape that is flat with respect to fast adaptation: from this point a few steps in any task-specific direction reach a low loss. Think of $\theta$ as a universal launching pad rather than a universally-good model.
Reptile: Even Simpler#
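Reptile drops the meta-gradient altogether. For each sampled task, run $k$ steps of plain SGD on that task's data starting from $\theta$ to get task-adapted weights $\tilde{\theta}$. Then move the initialization a fraction $\epsilon$ of the way toward them:
$$\theta \leftarrow \theta + \epsilon \,(\tilde{\theta} - \theta).$$
That’s the whole algorithm. Despite its simplicity it is competitive with MAML, because moving the meta-parameters toward task-specific solutions across many tasks ends up locating $\theta$ near a shared sweet spot.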
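Here is a minimal Reptile sketch on a hypothetical sine-regression task family, the same kind of toy used later in this article. The inner learning rate, the number of inner steps, $\epsilon$, and the network size are all assumptions:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def sample_sine_task(n=10):
    # Hypothetical task family: y = sin(w x + phi), w ~ U[0.5, 2], phi ~ U[0, 2pi].
    w = torch.empty(()).uniform_(0.5, 2.0)
    phi = torch.empty(()).uniform_(0.0, 2 * math.pi)
    x = torch.empty(n, 1).uniform_(-5.0, 5.0)
    return x, torch.sin(w * x + phi)


def reptile(model, meta_steps=1000, inner_steps=5, inner_lr=0.02, epsilon=0.1):
    for _ in range(meta_steps):
        x, y = sample_sine_task()
        theta = [p.detach().clone() for p in model.parameters()]  # current initialization
        opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
        for _ in range(inner_steps):  # k plain SGD steps on this task -> theta_tilde
            opt.zero_grad()
            F.mse_loss(model(x), y).backward()
            opt.step()
        with torch.no_grad():  # theta <- theta + epsilon * (theta_tilde - theta)
            for p, p0 in zip(model.parameters(), theta):
                p.copy_(p0 + epsilon * (p - p0))
    return model


torch.manual_seed(0)
model = reptile(nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1)), meta_steps=100)
```

No second derivatives and no support/query split appear anywhere. The only meta-level operation is the interpolation in the no_grad block.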
| Method | Gradient order | Per-step cost | Implementation | miniImageNet (5w-5s)* |
|---|---|---|---|---|
| MAML | Second-order | High (Hessian) | Hard | ~63% |
| FOMAML | First-order | Medium | Easy | ~62% |
| Reptile | First-order | Low | Trivial | ~66% |
*Reported in the original papers; numbers vary across implementations.
Episodic Training#
Standard supervised training shows the network the entire base-class dataset and asks it to classify. Episodic training reframes the entire training loop to look like the test loop.

How an episode is built#
Each iteration:
- Sample $N$ classes from the base-class pool.
- Sample $K$ examples per class for the support set.
- Sample $Q$ additional examples per class for the query set.
- Train the model to classify the queries given only that support set.
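The four steps above can be written as a minimal pure-Python sketch that works on dataset indices rather than tensors. Here sample_episode is a hypothetical helper, and it only draws from classes that have at least $K + Q$ examples:

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, q_queries, rng=random):
    """Return (support, query) lists of (dataset_index, episode_label) pairs.

    labels[i] is the base-class id of example i. Episode labels are remapped
    to 0..n_way-1 so that they index the N logits directly.
    """
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Only classes with enough examples for both support and query are eligible.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + q_queries]
    classes = rng.sample(eligible, n_way)                     # step 1: N classes
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + q_queries)  # without replacement
        support += [(i, episode_label) for i in picked[:k_shot]]  # step 2: K per class
        query += [(i, episode_label) for i in picked[k_shot:]]    # step 3: Q per class
    return support, query


labels = [i // 20 for i in range(64 * 20)]  # toy: 64 base classes x 20 examples
support, query = sample_episode(labels, n_way=5, k_shot=1, q_queries=15, rng=random.Random(0))
```

Rebuilding by_class on every call keeps the sketch short. A real sampler would build the index once and reuse it across episodes.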
Why this matters#
The model never sees the full base-class dataset at once. Every gradient update simulates a few-shot task, so the inductive bias the network develops is precisely the one needed at test time: the training conditions are made to match the test-time conditions.
A good sanity check: turn off episodic training and just train a flat $|C_{\text{base}}|$-way classifier on the base classes, then fit a fresh linear head on the frozen features using each episode's support set. With a strong backbone (deep ResNet, heavy augmentation) this “Baseline” recipe, and its cosine-classifier variant “Baseline++”, is competitive with metric- and meta-learning approaches. Chen et al. (ICLR 2019) used this result to argue that episodic training matters less than people thought, and that backbone capacity matters more.
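Here is a minimal sketch of the fine-tuning stage, assuming the frozen backbone has already extracted the support features. CosineClassifier and fit_head are hypothetical names. The scale of 10 and the Adam settings are assumptions, not the values used by Chen et al.:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineClassifier(nn.Module):
    """Baseline++-style head: logits are scaled cosine similarities to class weight vectors."""

    def __init__(self, dim, n_classes, scale=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_classes, dim))
        self.scale = scale  # temperature; 10.0 is an assumed value, tune per backbone

    def forward(self, feats):
        return self.scale * F.normalize(feats, dim=1) @ F.normalize(self.weight, dim=1).T


def fit_head(support_feats, support_labels, n_way, steps=100, lr=0.05):
    """Fit a fresh head on frozen support features; the backbone is never updated."""
    head = CosineClassifier(support_feats.size(1), n_way)
    opt = torch.optim.Adam(head.parameters(), lr=lr)
    feats = support_feats.detach()  # frozen features: no gradient reaches the backbone
    for _ in range(steps):
        opt.zero_grad()
        F.cross_entropy(head(feats), support_labels).backward()
        opt.step()
    return head
```

Swapping CosineClassifier for nn.Linear(dim, n_way) gives the plain “Baseline” variant.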
How Well Does Any of This Work?#

The numbers above are from the original papers (with later work routinely surpassing them by using larger backbones and pre-training tricks). Two things to take away:
- The 1-shot vs. 5-shot gap is huge. Going from one example to five typically adds 10-20 percentage points — a reminder that even a tiny amount of data dominates clever architecture choices.
- Methods cluster. Once the backbone is held fixed, Prototypical, Matching, Relation, and MAML-family numbers land within a few points of each other. Pick by engineering taste (simplicity, compute budget, tooling) rather than chasing the last point of accuracy.
Complete Implementation: Prototypical Networks#
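The sketch below covers the full pipeline, using the component names from the walkthrough table. The Conv-4 layout (64 channels, 3×3 convolutions, 2×2 pooling) is the common choice. The training hyperparameters and the random-tensor smoke test are assumptions, not the settings behind any number in this article:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_block(in_ch, out_ch):
    # Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool (the usual "Conv-4" block).
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class ProtoNetEncoder(nn.Module):
    """4-block CNN; the flattened final feature map is the embedding."""

    def __init__(self, in_ch=3, hidden=64):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_block(in_ch, hidden), conv_block(hidden, hidden),
            conv_block(hidden, hidden), conv_block(hidden, hidden),
        )

    def forward(self, x):
        return self.blocks(x).flatten(1)


class ProtoNet(nn.Module):
    def __init__(self, encoder, squared=False):
        super().__init__()
        self.encoder = encoder
        self.squared = squared  # True matches the Gaussian (Bayes) derivation exactly

    @staticmethod
    def compute_prototypes(support_emb, support_y, n_way):
        # Average the support embeddings of each class -> (n_way, D).
        return torch.stack([support_emb[support_y == c].mean(0) for c in range(n_way)])

    def forward(self, support_x, support_y, query_x, n_way):
        protos = self.compute_prototypes(self.encoder(support_x), support_y, n_way)
        dists = torch.cdist(self.encoder(query_x), protos, p=2)  # Euclidean, (Q, n_way)
        return -(dists.pow(2) if self.squared else dists)        # negated distances = logits


class EpisodeSampler:
    """Draws N-way K-shot episodes; assumes every class has >= K + Q examples."""

    def __init__(self, images, labels, n_way, k_shot, q_queries):
        self.images, self.n_way, self.k, self.q = images, n_way, k_shot, q_queries
        self.by_class = [torch.nonzero(labels == c).flatten() for c in labels.unique()]

    def sample(self):
        s_idx, q_idx = [], []
        for c in torch.randperm(len(self.by_class))[: self.n_way].tolist():
            perm = self.by_class[c][torch.randperm(len(self.by_class[c]))]
            s_idx.append(perm[: self.k])
            q_idx.append(perm[self.k : self.k + self.q])
        # Episode labels are 0..N-1 in sampling order, so targets index the N logits.
        s_y = torch.arange(self.n_way).repeat_interleave(self.k)
        q_y = torch.arange(self.n_way).repeat_interleave(self.q)
        return self.images[torch.cat(s_idx)], s_y, self.images[torch.cat(q_idx)], q_y


def train(model, train_sampler, val_sampler, episodes=1000, val_every=100,
          val_episodes=50, lr=1e-3):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for ep in range(1, episodes + 1):
        model.train()
        s_x, s_y, q_x, q_y = train_sampler.sample()
        loss = F.cross_entropy(model(s_x, s_y, q_x, train_sampler.n_way), q_y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if ep % val_every == 0:  # periodic validation on held-out classes
            model.eval()
            accs = []
            with torch.no_grad():
                for _ in range(val_episodes):
                    s_x, s_y, q_x, q_y = val_sampler.sample()
                    pred = model(s_x, s_y, q_x, val_sampler.n_way).argmax(dim=1)
                    accs.append((pred == q_y).float().mean().item())
            print(f"episode {ep}: loss {loss.item():.3f}, val acc {sum(accs) / len(accs):.3f}")
    return model


if __name__ == "__main__":
    # Smoke test on random tensors standing in for real images.
    torch.manual_seed(0)
    images = torch.randn(10 * 20, 3, 28, 28)  # 10 classes x 20 images
    labels = torch.arange(10).repeat_interleave(20)
    sampler = EpisodeSampler(images, labels, n_way=5, k_shot=1, q_queries=5)
    train(ProtoNet(ProtoNetEncoder()), sampler, sampler, episodes=20, val_every=10, val_episodes=5)
```

For real experiments, the validation sampler should draw from the held-out validation classes rather than reuse the training sampler.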
Code Walkthrough#
| Component | Role |
|---|---|
| ProtoNetEncoder | 4-block CNN, the standard backbone for miniImageNet experiments |
| compute_prototypes | Averages support embeddings per class |
| forward | Returns negative Euclidean distances as logits |
| EpisodeSampler | Builds an N-way K-shot episode each iteration |
| train | Episodic training loop with periodic validation |
Two implementation notes worth highlighting:
- torch.cdist(..., p=2) returns the Euclidean (not squared) distance. The square root is monotone, so negating it as logits gives the same argmax. The softmax probabilities and training gradients, however, differ from the Bayes-optimal Gaussian-mean derivation, which uses squared distance. In practice the difference rarely matters; if you want exact correspondence, square it.
- Always relabel the support classes to $0, \ldots, N-1$ inside the sampler so that the cross-entropy targets index into the $N$ logits.
Prototypical Networks: Bayes-Optimality and Variance-Adjusted Prototypes#

The “use the class mean” rule is not folklore. It falls out of Bayes’ theorem the moment you commit to a Gaussian model of the embedding space, and the derivation is short enough to do by hand.
The theorem#
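Assume the class-conditional embeddings (written $x$ here) are isotropic Gaussians with a shared variance, and that all $N$ classes are equally likely a priori:
$$p(x \mid y = c) = \mathcal{N}(\mu_c, \sigma^2 I), \qquad p(y = c) = \frac{1}{N}.$$
Then the MAP class is the nearest class mean:
$$\arg\max_c p(y = c \mid x) = \arg\min_c \frac{\|x - \mu_c\|^2}{2\sigma^2}.$$
Sketch of the proof#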
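By Bayes' rule, $\log p(y = c \mid x) = \log p(x \mid y = c) + \log p(y = c) - \log p(x)$, and under a uniform prior the last two terms do not depend on $c$. For a $d$-dimensional embedding the Gaussian log-likelihood is
$$\log p(x \mid y = c) = -\frac{d}{2}\log(2\pi\sigma^2) - \frac{\|x - \mu_c\|^2}{2\sigma^2}.$$
The first term is independent of $c$ and drops out of the argmax. What remains is precisely $-\|x - \mu_c\|^2 / (2\sigma^2)$, and maximizing it is the same as minimizing squared distance to the mean.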
So Prototypical Networks are not a heuristic. They are the deep-learning realization of an MAP classifier under one specific (and admittedly strong) generative assumption.
When the shared-variance assumption breaks#
Now give each class its own variance:
$$p(x \mid y = c) = \mathcal{N}(\mu_c, \sigma_c^2 I).$$
The normalizing term no longer cancels, so with a uniform prior
$$\log p(y = c \mid x) = -\frac{d}{2}\log\sigma_c^2 - \frac{\|x - \mu_c\|^2}{2\sigma_c^2} + \text{const}.$$
The classifier becomes a variance-weighted prototype rule: closer-knit clusters get a higher implicit confidence, and scattered ones get downweighted.
Variance-adjusted PyTorch#
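Here is a minimal sketch of the per-class-variance rule above. It makes three assumptions: each class has an isotropic variance estimated from its support set, the variance falls back to 1 when $K = 1$, and var_c is clamped to a floor of 0.01. Both weighted_proto_logits and its min_var parameter are hypothetical:

```python
import torch


def weighted_proto_logits(support_emb, support_y, query_emb, n_way, min_var=1e-2):
    """log p(y=c|x) up to a constant, under N(mu_c, var_c I) with a uniform class prior.

    logit_c(x) = -(D/2) log var_c - ||x - mu_c||^2 / (2 var_c)
    """
    D = support_emb.size(1)
    mus, variances = [], []
    for c in range(n_way):
        emb_c = support_emb[support_y == c]  # (K, D)
        mu_c = emb_c.mean(0)
        K = emb_c.size(0)
        if K > 1:
            # Isotropic variance: squared deviations pooled over dims, K - 1 dof per dim.
            var_c = (emb_c - mu_c).pow(2).sum() / ((K - 1) * D)
        else:
            var_c = emb_c.new_ones(())  # K = 1: variance undefined, fall back to 1
        # Floor the variance so a collapsed class cannot get unbounded confidence.
        variances.append(var_c.clamp(min=min_var))
        mus.append(mu_c)
    mu, var = torch.stack(mus), torch.stack(variances)  # (N, D), (N,)
    sq_dist = (query_emb.unsqueeze(1) - mu.unsqueeze(0)).pow(2).sum(-1)  # (Q, N)
    return -0.5 * D * var.log() - sq_dist / (2 * var)
```

With every var_c equal, the log-variance term is the same for all classes and the rule reduces to ProtoNet with squared distance, scaled by a temperature.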
The clamp on var_c matters more than it looks. Without it, a class whose support points happen to land on top of each other gets infinite confidence and the softmax saturates.
Numbers on miniImageNet#
| Method | 5-way 1-shot | 5-way 5-shot |
|---|---|---|
| ProtoNet (Euclidean) | 49.4% | 68.2% |
| WeightedProtoNet (per-class $\sigma$ ) | 50.8% | 69.7% |
The lift is small but consistent across seeds. It is not free: with $K = 1$ the per-class variance is undefined, so the model falls back to the unit-variance prior and you should expect zero benefit. We see real gains from $K = 3$ onward, where each class has enough support points to estimate a sensible scale.
Caveat: noise in the variance estimate#
A variance computed from $K = 5$ samples has a relative error on the order of $\sqrt{2/(K-1)} \approx 70\%$. That is enormous. The reason variance adjustment helps at all is that even a noisy estimate is better than the implicit “all classes have the same variance” prior. As $K$ drops to 2, though, the noise dominates and accuracy degrades below the baseline (at $K = 1$ there is no estimate at all).
One remedy is to shrink each class variance toward the average variance $\bar{\sigma}^2$ across the episode's classes:
$$\hat{\sigma}_c^2 = \lambda \cdot \sigma_c^2 + (1 - \lambda) \cdot \bar{\sigma}^2,$$
with $\lambda \in [0, 1]$ scheduled by $K$ (small when $K$ is small). This is a James-Stein-style shrinkage estimator. It recovers vanilla ProtoNet at $\lambda = 0$, where every class shares the same variance.
A second, less-discussed pitfall: per-class variance is a single scalar in the isotropic model above. Real embedding clusters are anisotropic, with elongated directions and tight ones. Letting $\sigma_c^2$ be a vector (per-feature) instead of a scalar gives a small additional bump but also amplifies the noise problem: each of the $D$ per-feature variance estimates carries the same $\sqrt{2/(K-1)}$ relative error on its own.
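The $\sqrt{2/(K-1)}$ figure is easy to check by simulation. Assuming Gaussian data with unit variance, the spread of the Bessel-corrected sample variance over many trials should match the formula:

```python
import torch

torch.manual_seed(0)
K, trials = 5, 200_000
samples = torch.randn(trials, K)  # K draws per trial from N(0, 1): true variance is 1
var_hat = samples.var(dim=1)      # Bessel-corrected sample variance (K - 1 denominator)
rel_err = var_hat.std().item()    # std of the estimate divided by the true variance (1)
predicted = (2 / (K - 1)) ** 0.5  # sqrt(2 / (K - 1)) is about 0.707 for K = 5
print(f"empirical {rel_err:.3f}, predicted {predicted:.3f}")
```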
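Once the per-class variances exist, the shrinkage step takes a couple of lines. This sketch assumes $K \ge 2$ and uses the mean of the class variances as $\bar{\sigma}^2$. It also assumes a hypothetical schedule $\lambda = K/(K + \tau)$, which moves from the pooled variance toward the per-class variance as $K$ grows:

```python
import torch


def shrink_class_variances(class_vars, k_shot, tau=5.0):
    """James-Stein-style shrinkage of per-class variances toward their mean.

    class_vars: (N,) isotropic variances estimated from the support set.
    lambda = k_shot / (k_shot + tau) is an assumed schedule: near 0 for tiny K
    (all classes share the pooled variance, as in vanilla ProtoNet) and near 1 for large K.
    """
    lam = k_shot / (k_shot + tau)
    pooled = class_vars.mean()
    return lam * class_vars + (1 - lam) * pooled
```

The pooled mean is preserved, so shrinkage changes only how far each class departs from it.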
But what if the underlying metric isn’t Euclidean at all?
MAML Inner-Outer Loop on a 2D Toy#

MAML’s two-loop structure is easy to write down and surprisingly hard to see. A 2D toy regression problem makes the geometry concrete: you can plot the meta-initialization, the per-task adapted points, and the trajectories the inner loop walks along.
The task family#
Each task $t$ is a sinusoid with its own frequency and phase:
$$y = \sin(\omega_t x + \phi_t), \qquad \omega_t \sim \mathcal{U}[0.5, 2.0], \quad \phi_t \sim \mathcal{U}[0, 2\pi].$$
The model is a tiny MLP $f_\theta: \mathbb{R} \to \mathbb{R}$. We treat the parameter $\theta$ as the thing being meta-learned: a single $\theta_0$ that adapts in 5 inner-loop steps to any $(\omega_t, \phi_t)$.
For visualization we sometimes restrict $\theta$ to two scalars, the model's own frequency and phase $\theta = (\omega, \phi)$, so that the per-task loss surface $\mathcal{L}_t(\omega, \phi)$ can be plotted directly. The inner loop is then a literal walk along the surface, and the outer loop is the choice of starting point.
Full MAML in PyTorch#
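Here is a minimal sketch of second-order MAML on the sine family above. The network is an MLP whose weights live in a plain list and are evaluated functionally, so it can run at the adapted parameters. Several settings are assumptions, not the settings behind the numbers below: the 40-unit hidden width, the $x \sim \mathcal{U}[-5, 5]$ input range, the inner learning rate of 0.01, the Adam outer optimizer, and 4 tasks per meta-batch:

```python
import math
import torch
import torch.nn.functional as F


def init_params(hidden=40):
    """Weights and biases of a 1 -> hidden -> hidden -> 1 MLP, as a plain list."""
    params = []
    for fan_in, fan_out in [(1, hidden), (hidden, hidden), (hidden, 1)]:
        w = torch.randn(fan_out, fan_in) * math.sqrt(2.0 / fan_in)  # He init
        params += [w.requires_grad_(), torch.zeros(fan_out, requires_grad=True)]
    return params


def mlp(x, params):
    """Functional forward pass, so the net can be evaluated at adapted parameters."""
    for i in range(0, len(params) - 2, 2):
        x = F.relu(F.linear(x, params[i], params[i + 1]))
    return F.linear(x, params[-2], params[-1])


def sample_task(n_support=10, n_query=10):
    """One task y = sin(w x + phi); the x-range [-5, 5] is an assumption."""
    w = torch.empty(()).uniform_(0.5, 2.0)
    phi = torch.empty(()).uniform_(0.0, 2 * math.pi)
    x = torch.empty(n_support + n_query, 1).uniform_(-5.0, 5.0)
    y = torch.sin(w * x + phi)
    return x[:n_support], y[:n_support], x[n_support:], y[n_support:]


def adapt(params, x_s, y_s, inner_lr, n_inner, first_order=False):
    """Inner loop: n_inner SGD steps on the support loss, returning theta_i'."""
    fast = params
    for _ in range(n_inner):
        loss = F.mse_loss(mlp(x_s, fast), y_s)
        # create_graph=True keeps this gradient in the autograd graph, so the outer
        # backward differentiates through the update (second-order MAML).
        # With first_order=True the gradient is treated as a constant (FOMAML).
        grads = torch.autograd.grad(loss, fast, create_graph=not first_order)
        fast = [p - inner_lr * g for p, g in zip(fast, grads)]
    return fast


def maml(outer_steps=10_000, tasks_per_batch=4, inner_lr=0.01, outer_lr=1e-3,
         n_inner=5, first_order=False):
    params = init_params()
    opt = torch.optim.Adam(params, lr=outer_lr)
    for _ in range(outer_steps):
        task_losses = []
        for _ in range(tasks_per_batch):
            x_s, y_s, x_q, y_q = sample_task()
            fast = adapt(params, x_s, y_s, inner_lr, n_inner, first_order)
            task_losses.append(F.mse_loss(mlp(x_q, fast), y_q))  # query loss at theta_i'
        meta_loss = torch.stack(task_losses).mean()
        opt.zero_grad()
        meta_loss.backward()  # backpropagates through the unrolled inner loop
        opt.step()
    return params
```

Calling maml() with its defaults runs 10k outer steps with 5 inner steps; passing first_order=True turns the same loop into FOMAML.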
The key line is create_graph=True in torch.autograd.grad. Without it, the inner-loop gradient is detached and the outer-loop backward has nothing to differentiate through — you’ve silently fallen back to FOMAML. With it, PyTorch keeps the inner graph alive and the second derivative (Hessian-vector product) is computed when meta_loss.backward() runs.
Cost: the Hessian-vector product#
Each outer step costs roughly $2 \times n_{\text{inner}}$ forward passes plus the backward through the unrolled inner loop. The peak memory scales as $O(n_{\text{inner}} \cdot |\theta|)$ because every intermediate parameter set must be retained for the backward. For our tiny MLP this is invisible. For a ResNet-12 it is the reason FOMAML and Reptile exist.
Numbers#
After 10k outer steps, evaluating on 100 held-out tasks with 5 inner steps:
| Method | Held-out MSE |
|---|---|
| MAML $\theta_0$ + 5 inner steps | 0.04 |
| Random init + 5 SGD steps (same compute) | 0.31 |
| Random init + 50 SGD steps | 0.06 |
The MAML initialization gets in 5 steps what a cold start needs 50 steps to match. That ratio — not the absolute accuracy — is the thing meta-learning buys you.
If you ever plot the loss surface for a single task and overlay the inner-loop path starting from $\theta_0$ , you’ll see that $\theta_0$ sits in a valley that is flat with respect to all task-specific minima. It is not the best place for any single task, but it is the best place from which to walk to any task in 5 steps.
Why second-order matters here (and when it doesn’t)#
In our toy, second-order MAML and FOMAML give nearly identical held-out MSE — 0.04 vs. 0.05. The Hessian-vector product is real but small in magnitude because the inner loss landscape is well-behaved: the per-task surfaces are smooth bowls and the linearization that FOMAML implicitly performs is tight. On harder problems (image classification with deep encoders, RL with sharp reward landscapes) the second-order term is what tells the meta-update how the inner-loop trajectory itself bends in response to a parameter change — and dropping it costs a few percentage points.
The practical rule: start with FOMAML or Reptile because they fit on a single GPU and are easy to debug; only enable second-order MAML if you have evidence that the inner loss surface is curved enough to matter.
Bridge: prototypes plus a learned init still both assume Euclidean geometry. The next step is to learn the geometry too.
Beyond Euclidean: Learned Task-Conditional Metrics#
Squared-Euclidean distance treats every embedding dimension as equally important. That is fine when the network has been trained long enough to whiten its features — and a problem the moment two classes differ along a low-variance direction the embedding has chosen to compress away.
Mahalanobis distance#
A Mahalanobis-style distance replaces the identity in the Euclidean norm with a positive semi-definite (PSD) matrix $M$:
$$d_M(x, \mu) = \sqrt{(x - \mu)^T M (x - \mu)}.$$
Setting $M = I$ recovers Euclidean. Setting $M = \Sigma^{-1}$ for a class-conditional covariance recovers Mahalanobis distance in its statistical sense. Learning $M$ end-to-end interpolates between the two and lets the network rescale dimensions to whatever the task demands.
To keep $M$ PSD, parametrize as $M = L L^T$ with $L$ unconstrained — gradients flow through $L$ , and $M$ is PSD by construction.
Task-conditional metrics#
A fixed $M$ is still a global commitment. A more flexible move is to let $M$ depend on the episode: a small hypernetwork reads the support set’s statistics and emits the per-episode metric.
$$M = M(\mathcal{S}) = h_\psi\bigl(\text{summary}(\mathcal{S})\bigr).$$
The summary can be as light as the mean and per-dimension variance of the support embeddings, $2D$ numbers in total. The hypernetwork is then a small MLP that outputs the entries of $L$.
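Here is a minimal sketch of the task-conditional metric. It assumes a low-rank factor $L \in \mathbb{R}^{D \times r}$ and a summary made of the mean and per-dimension variance of all support embeddings. TaskConditionalMetric and its defaults are hypothetical. mahalanobis_logits returns the negative squared distance, which matches the squared-Euclidean convention of ProtoNet logits:

```python
import torch
import torch.nn as nn


def mahalanobis_logits(q_emb, protos, M):
    """Logits = -(x - c)^T M (x - c) for every (query, prototype) pair -> (Q, N)."""
    diff = q_emb.unsqueeze(1) - protos.unsqueeze(0)  # (Q, N, D)
    return -torch.einsum("qnd,de,qne->qn", diff, M, diff)


class TaskConditionalMetric(nn.Module):
    """Hypernetwork h_psi: support summary -> low-rank L, metric M = L L^T (PSD)."""

    def __init__(self, dim, rank=8, hidden=64, dropout=0.1):
        super().__init__()
        self.dim, self.rank = dim, rank
        self.hyper = nn.Sequential(
            nn.Dropout(dropout),  # dropout on the summary before the hyper-MLP
            nn.Linear(2 * dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim * rank),
        )

    def forward(self, support_emb):
        # Summary: mean and per-dimension variance of the support embeddings (2D numbers).
        summary = torch.cat([support_emb.mean(0), support_emb.var(0, unbiased=False)])
        L = self.hyper(summary).view(self.dim, self.rank)
        return L @ L.T  # (D, D), PSD by construction


# 5-way 5-shot episode with random embeddings, support grouped by class.
torch.manual_seed(0)
metric = TaskConditionalMetric(dim=64)
support_emb, q_emb = torch.randn(25, 64), torch.randn(75, 64)
protos = support_emb.view(5, 5, 64).mean(1)  # (N, D) class means
logits = mahalanobis_logits(q_emb, protos, metric(support_emb))  # (75, 5)
```

Weight decay on the hypernetwork can be applied through the optimizer, for example torch.optim.AdamW(metric.parameters(), weight_decay=...).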
Drop this module into a Prototypical Network and replace torch.cdist with mahalanobis_logits(q_emb, protos, M).
Numbers on miniImageNet (5-way 1-shot)#
| Metric | Accuracy |
|---|---|
| Euclidean (vanilla ProtoNet) | 49.4% |
| Fixed Mahalanobis (learned $M$ ) | 51.1% |
| Task-conditional $M(\mathcal{S})$ | 52.8% |
The trend is consistent: more flexibility in the metric helps, up to the point where parameter count starts to overfit the meta-train set.
Practical caveat: overfitting and weight decay#
The hypernetwork adds parameters that see only the support summary, a bottleneck of just $2D$ scalars per episode. With a small meta-train set (the standard 64 base classes of miniImageNet) it is easy to memorize a metric that works for the training tasks and generalizes poorly. Three things help in practice: (1) a low-rank parametrization (rank $\ll D$), (2) explicit weight decay on the hypernet, (3) dropout on the summary itself before it enters the hyper-MLP.
Cross-domain wrinkle#
Task-conditional metrics learned on one domain rarely transfer to another. The Cross-Domain Few-Shot Learning benchmark (Guo et al. 2020) showed that methods which beat ProtoNet on miniImageNet routinely underperform it when evaluated on CropDisease, EuroSAT, ISIC, or ChestX. The metric module learns to exploit domain-specific feature statistics — which is exactly what makes it brittle when the target domain shifts.
A safe rule: use a learned metric only when meta-train and meta-test distributions are well-matched. Across domains, fall back to plain Euclidean and rely on a stronger backbone instead.
Diagnostic: when does the metric actually help?#
A quick check before you commit to a learned metric: compute the between-class to within-class variance ratio in your frozen embedding space, on a held-out set of base classes. If the ratio is high (say, > 5), the embedding has already done the geometric work and Euclidean will be hard to beat. If the ratio is low (< 2), there is room for a learned $M$ to rotate and rescale dimensions and you should see clear gains.
This same ratio also predicts when variance-adjusted prototypes pay off: low ratio means classes overlap, which means per-class scale information is informative; high ratio means classes are already well-separated, and the variance estimate is just adding noise.
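Here is a sketch of the diagnostic. It assumes a trace-style definition: the mean squared distance of class means from the global mean, divided by the mean squared distance of points from their own class mean. Other normalizations shift the scale, so treat the 2 and 5 thresholds as rough guides. between_within_ratio is a hypothetical name:

```python
import torch


def between_within_ratio(emb, labels):
    """Between-class over within-class scatter, in trace form.

    emb: (M, D) frozen embeddings of held-out base-class examples; labels: (M,).
    """
    mu = emb.mean(0)  # global mean
    between, within = [], []
    for c in labels.unique():
        e = emb[labels == c]
        mu_c = e.mean(0)
        between.append((mu_c - mu).pow(2).sum())          # spread of class means
        within.append((e - mu_c).pow(2).sum(1).mean())    # spread inside the class
    return (torch.stack(between).mean() / torch.stack(within).mean()).item()
```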
Bridge: with the metric, the prototypes, and the initialization all learnable, the remaining practical question is which of these levers to reach for first.
FAQ#
How is few-shot learning different from ordinary transfer learning?#
It is the limit case. Transfer learning assumes you have at least hundreds of target labels, so a fine-tuned head can do most of the work. Few-shot learning has 1-10. That gap is large enough that you need training-time machinery — episodic sampling, metric or meta objectives — not just a downstream training trick.
Why do Prototypical networks use the mean as the prototype?#
Under Gaussian class-conditionals with shared isotropic covariance and a uniform class prior, classifying by the nearest class mean is Bayes-optimal. The mean is also robust enough to be useful even when that assumption fails, especially for $K \ge 5$.
MAML or Prototypical Networks — which should I use?#
Default to Prototypical Networks: simpler, faster, the prototypes are interpretable, and they tend to match or beat MAML on standard image benchmarks. Reach for MAML when (a) the tasks are diverse and look qualitatively different from one another, (b) the data is non-image and you do not have a great pretrained embedding, or (c) you specifically need adaptation that updates the entire network rather than just a final classifier.
How many base classes do I need?#
More is generally better for generalization. Standard benchmarks use 64 base classes (miniImageNet) up to 1200+ (Omniglot). With fewer than ~30 base classes you tend to see severe overfitting to the base set itself, and novel-class accuracy collapses.
Does any of this work for non-image data?#
Yes. Prototypical Networks work for anything with a meaningful embedding — text (use a transformer encoder), graphs (use a GNN), audio (use a spectrogram CNN). MAML and Reptile are model-agnostic by design. The episodic protocol does not care about modality.
Why are confidence intervals always reported?#
The per-episode accuracy variance is large — a single hard episode can swing 10-20 points. Reporting the mean over a few hundred episodes plus a 95% CI is the only way to make numbers comparable across papers.
Summary#
Few-shot learning attacks deep learning’s biggest practical bottleneck: data scarcity in the long tail.
- Metric learning (Siamese, Prototypical, Matching, Relation Networks) learns an embedding space where distance equals dissimilarity. Simple, fast, interpretable. Prototypical Networks are the workhorse default.
- Meta-learning (MAML, FOMAML, Reptile) learns an initialization from which a few gradient steps reach a good solution for a new task. More flexible, costlier, less interpretable.
- Episodic training is the unifying training paradigm: each iteration is a fresh few-shot task, so training-time difficulty matches test-time difficulty.
Across families, accuracies converge once the backbone is held fixed — a reminder that backbone capacity and pretraining quality matter at least as much as the few-shot algorithm on top.
Next: Part 5 — Knowledge Distillation , where we compress a large teacher model into a small student that mimics it.
References#
- Snell et al. (2017). Prototypical Networks for Few-shot Learning. NeurIPS. arXiv:1703.05175
- Finn et al. (2017). Model-Agnostic Meta-Learning (MAML). ICML. arXiv:1703.03400
- Vinyals et al. (2016). Matching Networks for One Shot Learning. NeurIPS. arXiv:1606.04080
- Sung et al. (2018). Learning to Compare: Relation Network for Few-Shot Learning. CVPR. arXiv:1711.06025
- Nichol et al. (2018). On First-Order Meta-Learning Algorithms (Reptile). arXiv:1803.02999
- Koch et al. (2015). Siamese Neural Networks for One-shot Image Recognition. ICML Deep Learning Workshop.
- Chen et al. (2019). A Closer Look at Few-shot Classification. ICLR. arXiv:1904.04232
- Wang et al. (2020). Generalizing from a Few Examples: A Survey on Few-Shot Learning. ACM Computing Surveys. arXiv:1904.05046
Transfer Learning 12 parts
- 01 Transfer Learning (1): Fundamentals and Core Concepts
- 02 Transfer Learning (2): Pre-training and Fine-tuning
- 03 Transfer Learning (3): Domain Adaptation
- 04 Transfer Learning (4): Few-Shot Learning
- 05 Transfer Learning (5): Knowledge Distillation
- 06 Transfer Learning (6): Multi-Task Learning
- 07 Transfer Learning (7): Zero-Shot Learning
- 08 Transfer Learning (8): Multimodal Transfer
- 09 Transfer Learning (9): Parameter-Efficient Fine-Tuning
- 10 Transfer Learning (10): Continual Learning
- 11 Transfer Learning (11): Cross-Lingual Transfer
- 12 Transfer Learning (12): Industrial Applications and Best Practices