Transfer Learning (10): Continual Learning

Derive catastrophic forgetting from gradient interference and the Fisher information matrix. Covers EWC, MAS, LwF, replay (ER/A-GEM), dynamic architectures, the three CL scenarios, FWT/BWT metrics, and a from-scratch EWC implementation.

You can teach yourself to play guitar this year and you will still remember how to ride a bike. A neural network cannot. Fine-tune a vision model on CIFAR-10 then on SVHN, evaluate it on CIFAR-10 again, and accuracy collapses to barely above chance. The phenomenon is called catastrophic forgetting, and overcoming it is the central problem of continual learning (CL): a learner that absorbs a stream of tasks $\mathcal{T}_1, \mathcal{T}_2, \ldots$ without re-accessing past data and without losing what it already knew.

This post derives why forgetting happens (it is not a bug, it is the structure of SGD on overparameterised networks), then walks through the four families of solutions — regularisation, replay, dynamic architectures, meta-learning — with the math, the intuition, and a from-scratch EWC implementation.


What You Will Learn#

  • The CL problem statement and the three scenarios (Task-IL, Domain-IL, Class-IL)
  • Why SGD on a new task destroys old-task knowledge: gradient interference and the loss-landscape view
  • Fisher information as a principled measure of parameter importance
  • Regularisation methods — EWC, MAS, SI, LwF — and how they differ
  • Replay methods — Experience Replay, GEM, A-GEM — and the projection geometry of A-GEM
  • Dynamic architectures — Progressive Networks, PackNet — and their trade-offs
  • The standard metrics: average accuracy, average forgetting, and forward/backward transfer
  • A self-contained EWC implementation evaluated on Permuted MNIST

Prerequisites#

  • Neural network training, gradients, the cross-entropy loss
  • Basic familiarity with the Fisher information matrix
  • Transfer-learning fundamentals (Parts 1-6 of this series)

Problem Setup#

Tasks arrive sequentially: $\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_T$. When the learner trains on $\mathcal{T}_t$ it sees $\mathcal{D}_t = \{(x_i, y_i)\}$, but the earlier data $\mathcal{D}_{<t}$ is no longer available. After all $T$ tasks the model is tested on every task it has seen.

Three scenarios make the difficulty concrete (van de Ven & Tolias, 2019):

Three CL Scenarios: Task-IL, Domain-IL, Class-IL

  • Task-IL. The task identity is known at test time. The model can use a per-task head — only the shared trunk competes for capacity.
  • Domain-IL. The label space is fixed but the input distribution shifts (clean -> rotated -> noisy MNIST). One head; no test-time task ID.
  • Class-IL. Each task introduces new classes and the learner must classify across all classes seen so far without knowing which task a sample came from. This is the hardest setting and the one most relevant to deployment.

Metrics are read off the transfer matrix $R$, where $R_{i,j}$ is the test accuracy on task $j$ after training on tasks $1, \ldots, i$. Average accuracy and average forgetting are

$$ \mathrm{Avg} \;=\; \frac{1}{T}\sum_{j=1}^{T} R_{T,j}, \qquad \mathrm{Forgetting} \;=\; \frac{1}{T-1}\sum_{j=1}^{T-1}\!\left(\max_{t \le T} R_{t,j} - R_{T,j}\right), $$

and backward and forward transfer are

$$ \mathrm{BWT} \;=\; \frac{1}{T-1}\sum_{j=1}^{T-1} (R_{T,j} - R_{j,j}), \qquad \mathrm{FWT} \;=\; \frac{1}{T-1}\sum_{j=2}^{T} (R_{j-1,j} - b_j), $$

where $b_j$ is the accuracy of a random/untrained baseline on task $j$. BWT < 0 is forgetting; BWT > 0 is the rare and desirable phenomenon of positive backward transfer (learning later tasks helps earlier ones). FWT > 0 means that training on earlier tasks already lifts accuracy on a task before the model has trained on it.
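
The four metrics are easy to get wrong by one index, so here is a minimal sketch in plain Python. The function name `cl_metrics` and the 3-task matrix are illustrative assumptions; the numbers are placeholders, not measurements.

```python
def cl_metrics(R, b):
    """Continual-learning metrics from a transfer matrix.

    R[i][j] is the accuracy on task j after training through task i
    (0-indexed, so R has T rows and T columns, with T >= 2).
    b[j] is the accuracy of an untrained model on task j.
    """
    T = len(R)
    final = R[T - 1]
    avg = sum(final) / T
    # Forgetting: best accuracy ever reached on task j minus the final one.
    forgetting = sum(
        max(R[t][j] for t in range(T)) - final[j] for j in range(T - 1)
    ) / (T - 1)
    # BWT: final accuracy on task j minus accuracy right after learning it.
    bwt = sum(final[j] - R[j][j] for j in range(T - 1)) / (T - 1)
    # FWT: accuracy on task j just before training on it, minus the baseline.
    fwt = sum(R[j - 1][j] - b[j] for j in range(1, T)) / (T - 1)
    return {"avg": avg, "forgetting": forgetting, "bwt": bwt, "fwt": fwt}


# Placeholder numbers for a 3-task run, not measurements.
R = [
    [0.90, 0.20, 0.10],
    [0.70, 0.85, 0.15],
    [0.60, 0.75, 0.80],
]
b = [0.10, 0.10, 0.10]
metrics = cl_metrics(R, b)
print(metrics)
```

In this toy matrix the diagonal is the accuracy right after learning each task, the last row is the final accuracy, and the entries just above the diagonal feed FWT.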

Transfer matrix R[i,j] with FWT and BWT regions


Why Forgetting Happens#

Gradient interference#

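Let $\mathbf{g}_1 = \nabla_\theta \mathcal{L}_1$ and $\mathbf{g}_2 = \nabla_\theta \mathcal{L}_2$ be the gradients of the two task losses at the current $\theta$. One SGD step on task 2 moves the parameters by $\Delta\theta = -\eta\, \mathbf{g}_2$, so to first order the task-1 loss changes by

$$\Delta \mathcal{L}_1 \approx -\eta\, \mathbf{g}_1 \cdot \mathbf{g}_2.$$

If $\mathbf{g}_1 \cdot \mathbf{g}_2 < 0$, the step on task 2 increases the task-1 loss. In high-dimensional networks the gradients of unrelated tasks are typically close to orthogonal, but enough steps have a negative inner product that the damage accumulates over thousands of updates.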
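
The first-order estimate can be checked numerically. The sketch below assumes two hypothetical quadratic task losses $\tfrac{1}{2}\|\theta - c_k\|^2$ with hand-picked centres; it takes one SGD step on task 2 and compares the predicted and actual change in the task-1 loss.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def loss(theta, centre):
    """Quadratic task loss 0.5 * ||theta - centre||^2."""
    return 0.5 * sum((t - c) ** 2 for t, c in zip(theta, centre))


def grad(theta, centre):
    """Gradient of the quadratic loss with respect to theta."""
    return [t - c for t, c in zip(theta, centre)]


theta = [0.0, 0.0]
centre_1 = [1.0, 0.0]    # optimum of task 1 (hypothetical)
centre_2 = [-1.0, 0.5]   # optimum of task 2 (hypothetical)
eta = 0.01

g1, g2 = grad(theta, centre_1), grad(theta, centre_2)
theta_next = [t - eta * g for t, g in zip(theta, g2)]  # one SGD step on task 2

predicted = -eta * dot(g1, g2)                         # first-order estimate
actual = loss(theta_next, centre_1) - loss(theta, centre_1)
print(f"g1.g2 = {dot(g1, g2):+.3f}  predicted = {predicted:+.6f}  actual = {actual:+.6f}")
```

The gradients conflict here, so the task-1 loss goes up. The small gap between the two numbers is the second-order term that the approximation drops.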

Loss-landscape view#

The optima $\theta_1^{*}$ and $\theta_2^{*}$ live in different low-loss basins. SGD on task 2 starting from $\theta_1^{*}$ walks out of basin 1 unless something pulls it back. The figure below shows a vanilla baseline doing exactly that, alongside two repairs we will derive in the next sections.

Catastrophic forgetting on a 5-task sequence: baseline vs EWC vs replay

Fisher information = parameter importance#

$$F(\theta) \;=\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim p_\theta(\cdot \mid x)}\!\left[\nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^{\top}\right].$$

When the model fits the data well, the Fisher approximates the (positive semi-definite) Hessian of the negative log-likelihood at a local optimum, so the diagonal entry $F_{ii}$ (written $F_i$ below) measures how steeply the loss rises when $\theta_i$ is perturbed. A large $F_i$ means $\theta_i$ is load-bearing for the task, so protect it. A small $F_i$ means the loss is flat in that direction, so the parameter can safely be repurposed for a new task. Every regularisation method below is a specific answer to the question “which parameters should we protect?”.
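
A small case where the Fisher-Hessian link can be verified by hand is logistic regression. There the identity is exact at every $\theta$, not only at an optimum; for deep networks it is only an approximation. The sketch below uses a made-up four-point dataset and illustrative helper names. It computes the diagonal Fisher by taking the expectation over $y \sim p_\theta(\cdot \mid x)$ exactly, then compares it with a finite-difference Hessian diagonal of the negative log-likelihood.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def nll(w, data):
    """Mean negative log-likelihood of a logistic-regression model."""
    total = 0.0
    for x, y in data:
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        total -= math.log(p) if y == 1 else math.log(1.0 - p)
    return total / len(data)


def fisher_diag(w, data):
    """Diagonal Fisher with the expectation over y ~ p(y|x) taken exactly."""
    diag = [0.0] * len(w)
    for x, _ in data:                       # dataset labels are not used
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        for y, prob_y in ((1, p), (0, 1.0 - p)):
            for i, xi in enumerate(x):
                g_i = (y - p) * xi          # d log p(y|x) / d w_i
                diag[i] += prob_y * g_i ** 2
    return [d / len(data) for d in diag]


def hessian_diag_fd(w, data, h=1e-3):
    """Central finite-difference estimate of the Hessian diagonal of nll."""
    base = nll(w, data)
    out = []
    for i in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[i] += h
        w_minus[i] -= h
        out.append((nll(w_plus, data) - 2.0 * base + nll(w_minus, data)) / h ** 2)
    return out


# Made-up toy data: (features, label).
data = [([1.0, 2.0], 1), ([-1.0, 0.5], 0), ([0.5, -1.5], 1), ([2.0, 1.0], 0)]
w = [0.3, -0.2]
fisher = fisher_diag(w, data)
hessian = hessian_diag_fd(w, data)
print(fisher, hessian)
```

The two lists agree up to finite-difference error, which is the sense in which a large $F_i$ marks a steep direction of the loss.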


Regularisation Methods#

Elastic Weight Consolidation (EWC)#

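EWC (Kirkpatrick et al., 2017) approximates the task-$A$ loss around its optimum $\theta_A^{*}$ by a second-order expansion. The gradient vanishes at the optimum and the Hessian is replaced by the Fisher $F_A$:

$$\mathcal{L}_A(\theta) \;\approx\; \mathcal{L}_A(\theta_A^{*}) + \tfrac{1}{2} (\theta - \theta_A^{*})^{\top} F_A\, (\theta - \theta_A^{*}).$$

Keeping only the diagonal of $F_A$ and adding the result as a penalty to the task-$B$ loss gives the EWC objective:

$$\boxed{\;\mathcal{L}(\theta) \;=\; \mathcal{L}_B(\theta) \;+\; \frac{\lambda}{2} \sum_i F_{A,i}\, (\theta_i - \theta_{A,i}^{*})^{2}\;}$$

Geometrically, EWC anchors a quadratic well at the old optimum whose per-parameter curvature approximates that of the old loss. Moving a parameter is cheap where $F_i$ is small (the loss was flat anyway) and expensive where $F_i$ is large.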

EWC penalty as a quadratic well in parameter space

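With many tasks, storing one Fisher and one anchor per task grows linearly in $T$. Online EWC (Schwarz et al., 2018) instead keeps a single running Fisher with decay $\gamma \le 1$ and anchors at the most recent optimum:

$$\tilde F_t \;=\; \gamma\, \tilde F_{t-1} + F_t, \qquad \theta^{*}_{1:t} = \theta^{*}_t.$$

Picking $\lambda$ matters. Too small and forgetting wins; too large and the model loses plasticity and cannot learn the new task (“rigidity”). Typical ranges are $\lambda \in [10^2, 10^4]$ for Permuted MNIST and $\lambda \in [1, 10]$ for Split CIFAR.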
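
The running-Fisher recursion above is one line per parameter. A minimal sketch in plain Python, with parameters as flat lists of floats; the function names `online_ewc_update` and `online_ewc_penalty` are illustrative and not from any library.

```python
def online_ewc_update(running_fisher, new_fisher, gamma=0.95):
    """Return gamma * F_running + F_new, element-wise.

    Pass running_fisher=None after the first task.
    """
    if running_fisher is None:
        return list(new_fisher)
    return [gamma * f_old + f_new
            for f_old, f_new in zip(running_fisher, new_fisher)]


def online_ewc_penalty(theta, anchor, running_fisher, lam):
    """(lam / 2) * sum_i F_i * (theta_i - anchor_i)^2 with a single anchor."""
    return 0.5 * lam * sum(
        f * (t - a) ** 2 for f, t, a in zip(running_fisher, theta, anchor)
    )


# With a constant per-task Fisher of 1.0, a plain sum grows without bound,
# while the discounted sum converges to 1 / (1 - gamma).
plain, discounted = 0.0, None
for _ in range(200):
    plain += 1.0
    discounted = online_ewc_update(discounted, [1.0], gamma=0.5)
print(plain, discounted)
```

The loop at the end is only there to show the design choice: with $\gamma < 1$ the effective penalty strength stays bounded however many tasks arrive.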

Memory Aware Synapses (MAS)#

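Aljundi et al. (2018) measure importance by how sensitive the network output $f(x;\theta)$ is to each parameter, rather than how sensitive the loss is:

$$\Omega_i \;=\; \mathbb{E}_{x}\!\left[\, \left| \frac{\partial \, \tfrac{1}{2}\|f(x;\theta)\|_2^{2}}{\partial \theta_i} \right| \, \right].$$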

This is unsupervised — you can compute it on unlabelled data, even on the test stream — which is a real advantage in deployed settings.
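
To make the label-free property concrete, here is a sketch for a linear model $f(x) = Wx$, chosen only because its gradient can be written by hand: $\partial\, \tfrac{1}{2}\|Wx\|^2 / \partial W_{ij} = (Wx)_i\, x_j$. In a real network the same quantity would come from autograd. The function name `mas_importance` is illustrative.

```python
def mas_importance(W, xs):
    """MAS importance for a linear model f(x) = W x.

    Omega[i][j] = mean over x of |d(0.5 * ||W x||^2) / dW[i][j]|
                = mean over x of |(W x)[i] * x[j]|.
    Only inputs are needed; no labels appear anywhere.
    """
    rows, cols = len(W), len(W[0])
    omega = [[0.0] * cols for _ in range(rows)]
    for x in xs:
        f = [sum(W[i][j] * x[j] for j in range(cols)) for i in range(rows)]
        for i in range(rows):
            for j in range(cols):
                omega[i][j] += abs(f[i] * x[j])
    return [[o / len(xs) for o in row] for row in omega]


W = [[1.0, 0.0], [0.0, 2.0]]
xs = [[1.0, 0.0], [0.0, 1.0]]          # unlabelled inputs
omega = mas_importance(W, xs)
print(omega)
```

Weights that never influence the output on the observed inputs receive zero importance and are free to change on later tasks.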

Synaptic Intelligence (SI)#

Zenke et al. (2017) compute importance online during training as the path integral of $-g_i \cdot \dot\theta_i$ along the SGD trajectory. No second pass over data is needed; the cost is folded into the optimiser.

Learning without Forgetting (LwF)#

Li & Hoiem (2017) record the old model’s logits $z^{\text{old}}$ on the new-task inputs before training starts, then train the new task while distilling those outputs back into the old heads:

$$\mathcal{L} \;=\; \underbrace{\mathcal{L}_{\text{CE}}\bigl(y,\, z^{\text{new}}_{\text{new heads}}\bigr)}_{\text{learn new task}} \;+\; \alpha\, \underbrace{T^{2}\, \mathrm{KL}\!\bigl(\sigma(z^{\text{old}}/T)\,\Vert\,\sigma(z^{\text{new}}_{\text{old heads}}/T)\bigr)}_{\text{don't move old outputs}}.$$

LwF needs no old data and no Fisher matrix, only the old model. The temperature $T$ (typically 2-4; not to be confused with the number of tasks) softens the distributions so the distillation signal carries shape information beyond the argmax.
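
The distillation term is short enough to write out by hand. A minimal sketch for a single sample in plain Python; the names `distill_loss` and `lwf_loss` are illustrative, and `ce_new` stands for a cross-entropy value computed elsewhere on the new heads.

```python
import math


def softmax(z, T=1.0):
    """Temperature-scaled softmax, shifted by max(z) for stability."""
    m = max(z)
    e = [math.exp((v - m) / T) for v in z]
    s = sum(e)
    return [v / s for v in e]


def distill_loss(z_old, z_new, T=2.0):
    """T^2 * KL(softmax(z_old / T) || softmax(z_new / T))."""
    p = softmax(z_old, T)
    q = softmax(z_new, T)
    return T * T * sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def lwf_loss(ce_new, z_old, z_new_old_heads, alpha=1.0, T=2.0):
    """New-task cross-entropy plus the weighted distillation term."""
    return ce_new + alpha * distill_loss(z_old, z_new_old_heads, T)


z_old = [2.0, 0.0, -1.0]        # logits recorded from the frozen old model
z_same = [2.0, 0.0, -1.0]       # old heads unchanged: no penalty
z_drift = [0.0, 2.0, -1.0]      # old heads drifted: positive penalty
print(distill_loss(z_old, z_same), distill_loss(z_old, z_drift))
```

The $T^2$ factor keeps the gradient scale of the distillation term roughly comparable across temperatures, which is why it appears in the loss.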

LwF: knowledge distillation from frozen old model


Replay Methods#

A different philosophy: keep a small slice of the past around. Even with a tiny memory buffer, mixing old samples into each mini-batch is one of the strongest baselines known.

Experience replay pipeline

Experience Replay (ER)#

Keep a fixed-size memory $\mathcal{M}$ of past samples. At each step, draw a batch $B_{\text{new}}$ from the current task and a batch $B_{\text{mem}}$ from $\mathcal{M}$, and minimise

$$\mathcal{L} \;=\; \mathcal{L}_{\text{new}}(B_{\text{new}}) \;+\; \alpha\, \mathcal{L}_{\text{mem}}(B_{\text{mem}}),$$

then write some new samples back into $\mathcal{M}$. Reservoir sampling keeps a uniform sample over the entire past stream with a fixed-size buffer (Vitter, 1985); class-balanced sampling guarantees coverage of every class. Empirically $|B_{\text{mem}}| = |B_{\text{new}}|$ already recovers most of the joint-training accuracy on Split-CIFAR-style benchmarks.
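
Reservoir sampling is the piece of this pipeline most often implemented incorrectly, so a minimal sketch follows. The class name `ReservoirBuffer` and its methods are illustrative assumptions; items can be any Python objects, for example `(x, y)` pairs.

```python
import random


class ReservoirBuffer:
    """Fixed-size buffer holding a uniform sample of everything seen so far."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.items = []
        self.n_seen = 0
        self.rng = random.Random(seed)

    def add(self, item):
        self.n_seen += 1
        if len(self.items) < self.capacity:
            self.items.append(item)              # buffer not full yet
        else:
            # Keep the new item with probability capacity / n_seen,
            # overwriting a uniformly chosen slot.
            j = self.rng.randrange(self.n_seen)
            if j < self.capacity:
                self.items[j] = item

    def sample(self, k):
        """Draw up to k distinct items for the replay mini-batch."""
        return self.rng.sample(self.items, min(k, len(self.items)))


buffer = ReservoirBuffer(capacity=10)
for i in range(1000):                            # stand-in for a data stream
    buffer.add(i)
replay_batch = buffer.sample(4)
print(len(buffer.items), replay_batch)
```

Because the acceptance probability shrinks as the stream grows, early and late samples end up equally likely to be in the buffer, with no need to know the stream length in advance.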

GEM and A-GEM#

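GEM (Lopez-Paz & Ranzato, 2017) treats the memory as a set of constraints rather than as extra training data: the update must not increase the loss on the memory of any past task. With $\mathbf{g}_k$ the gradient on the memory of task $k$, it projects the new-task gradient by solving a quadratic program:

$$ \min_{\tilde{\mathbf{g}}} \tfrac{1}{2}\|\tilde{\mathbf{g}} - \mathbf{g}_{\text{new}}\|^{2} \quad \text{s.t.} \quad \tilde{\mathbf{g}} \cdot \mathbf{g}_{k} \;\ge\; 0 \quad \forall k = 1, \ldots, t-1. $$

A-GEM (Chaudhry et al., 2019) replaces the $t-1$ constraints with a single one on $\mathbf{g}_{\text{ref}}$, the gradient on a random batch drawn from the whole memory, which gives a closed form:

$$\tilde{\mathbf{g}} \;=\; \mathbf{g}_{\text{new}} \;-\; \frac{\mathbf{g}_{\text{new}} \cdot \mathbf{g}_{\text{ref}}}{\|\mathbf{g}_{\text{ref}}\|^{2}}\, \mathbf{g}_{\text{ref}} \quad \text{if } \mathbf{g}_{\text{new}} \cdot \mathbf{g}_{\text{ref}} < 0,$$

otherwise $\tilde{\mathbf{g}} = \mathbf{g}_{\text{new}}$. The cost is one extra forward/backward pass on the reference batch and a single dot product, which is far cheaper than solving GEM's quadratic program at every step, and it is almost as accurate.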
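
The A-GEM projection is a few lines once both gradients are flattened into vectors. A minimal sketch in plain Python; `agem_project` is an illustrative name, and in practice the inputs would be flattened parameter gradients from the framework.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def agem_project(g_new, g_ref):
    """Remove from g_new its component along g_ref when the two conflict."""
    d = dot(g_new, g_ref)
    if d >= 0:
        return list(g_new)               # no conflict: leave the gradient alone
    ref_sq = dot(g_ref, g_ref)           # > 0 here, because d < 0 implies g_ref != 0
    return [gn - (d / ref_sq) * gr for gn, gr in zip(g_new, g_ref)]


conflicting = agem_project([1.0, -1.0], [0.0, 1.0])
compatible = agem_project([1.0, 1.0], [0.0, 1.0])
print(conflicting, compatible)
```

After projection the update is orthogonal to the reference gradient, so to first order it no longer increases the loss on the memory batch.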

DER and DER++#

Buzzega et al. (2020) store both the input and the model’s logits at the time the sample was added. The replay loss becomes a logit-matching MSE, optionally combined with the original label cross-entropy. DER++ is currently among the strongest single-model baselines on most CL benchmarks.


Dynamic Architectures#

Instead of squeezing all tasks into a fixed parameter budget, grow the model.

  • Progressive Networks (Rusu et al., 2016): freeze the network after each task and add a new column for the next task, with lateral connections from frozen columns into the new one. Forgetting becomes zero by construction, but parameters and inference cost grow linearly with $T$.
  • PackNet (Mallya & Lazebnik, 2018): after each task, prune the low-magnitude weights, then freeze the weights that survive; future tasks train only the freed weights and may read the frozen ones. Model size is fixed but the free capacity shrinks with each task, so performance degrades once it runs out.
  • Supermasks in Superposition (Wortsman et al., 2020): keep parameters random and frozen; learn a binary mask per task. Storage per task is one bit per parameter, and surprisingly, performance rivals trained baselines.

The trade-off is universal: zero forgetting costs either growing parameters or shrinking capacity. Hybrid approaches, such as a fixed trunk with lightweight per-task adapters (cf. Part 9), are how this technology actually ships.


Implementation: EWC From Scratch#

Below is a clean PyTorch implementation that (i) estimates the diagonal Fisher after each task, (ii) stores it together with $\theta^{*}$, and (iii) adds the EWC penalty to the next task’s loss. It is written for Permuted MNIST and expects a model and data loaders to be supplied.

import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

class EWC:
    """Elastic Weight Consolidation.

    After each task, call `consolidate(dataloader)` to snapshot theta*
    and a diagonal Fisher estimate. During subsequent training, add
    `lambda * ewc.penalty()` to the loss.
    """

    def __init__(self, model: nn.Module, device: str = "cpu"):
        self.model = model
        self.device = device
        self.fisher: list[dict[str, torch.Tensor]] = []
        self.opt_params: list[dict[str, torch.Tensor]] = []

    @torch.enable_grad()
    def _diag_fisher(self, dataloader, n_samples: int = 1024
                     ) -> dict[str, torch.Tensor]:
        """Diagonal Fisher: E[(d log p(y|x; theta) / d theta)^2].

        Gradients are squared per sample. Squaring a batch-averaged
        gradient would estimate (E[g])^2 instead of E[g^2].
        """
        self.model.eval()
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()
                  if p.requires_grad}

        seen = 0
        for x, _ in dataloader:
            x = x.to(self.device)
            for i in range(x.size(0)):
                self.model.zero_grad()
                logits = self.model(x[i:i + 1])
                # Sample y from the model's predictive distribution: a Monte
                # Carlo estimate of the *true* Fisher. Using the dataset
                # labels instead would give the empirical Fisher.
                probs = F.softmax(logits.detach(), dim=-1)
                y = torch.multinomial(probs, 1).squeeze(-1)
                F.cross_entropy(logits, y).backward()
                for n, p in self.model.named_parameters():
                    if p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2
                seen += 1
                if seen >= n_samples:
                    break
            if seen >= n_samples:
                break

        for n in fisher:
            fisher[n] /= max(seen, 1)
        return fisher

    def consolidate(self, dataloader, n_samples: int = 1024) -> None:
        """Call at the END of training a task."""
        self.fisher.append(self._diag_fisher(dataloader, n_samples))
        self.opt_params.append(
            {n: p.detach().clone() for n, p in self.model.named_parameters()
             if p.requires_grad}
        )

    def penalty(self) -> torch.Tensor:
        """Sum_t Sum_i F_{t,i} * (theta_i - theta*_{t,i})^2."""
        if not self.fisher:
            return torch.tensor(0.0, device=self.device)
        loss = torch.tensor(0.0, device=self.device)
        for F_t, theta_t in zip(self.fisher, self.opt_params):
            for n, p in self.model.named_parameters():
                if n in F_t:
                    loss = loss + (F_t[n] * (p - theta_t[n]) ** 2).sum()
        return 0.5 * loss

def train_task(model, ewc, loader, optimiser, *, ewc_lambda: float,
               epochs: int, device: str) -> None:
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(model(x), y) + ewc_lambda * ewc.penalty()
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

@torch.no_grad()
def evaluate(model, loader, device: str) -> float:
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(-1) == y).sum().item()
        total += y.size(0)
    return 100.0 * correct / total

Usage sketch (Permuted MNIST):

ewc = EWC(model, device=device)
for t, (train_loader, test_loader) in enumerate(tasks):
    opt = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    train_task(model, ewc, train_loader, opt,
               ewc_lambda=400.0 if t > 0 else 0.0,
               epochs=5, device=device)
    ewc.consolidate(train_loader)               # snapshot theta* and F
    accs = [evaluate(model, t_loader, device)   # check all tasks so far
            for _, t_loader in tasks[:t + 1]]
    print(f"After task {t + 1}: {accs}")

Two implementation details that matter:

  1. True vs empirical Fisher. Sampling $y$ from $p_\theta(\cdot \mid x)$ (as above) gives a Monte Carlo estimate of the true Fisher, which is what the EWC derivation uses. Plugging in the dataset labels gives the empirical Fisher. Both are used in practice, and the two get closer as the model's predictive distribution approaches the label distribution.
  2. Where to compute Fisher. Compute it after you finish training the task — that is when $\theta \approx \theta_t^{*}$ and the quadratic approximation is tight.

Empirical Comparison#

The figure below shows representative numbers on the two canonical CL benchmarks. Replay dominates regularisation when memory is allowed; with no memory, EWC and LwF still substantially beat naive SGD; nothing yet matches the joint upper bound.

CL benchmarks: Permuted MNIST and Split CIFAR

Three takeaways:

  • Replay is the strongest single trick. Even 200 stored samples per task usually beats every regularisation-only method on class-IL.
  • Class-IL is much harder than task-IL. Methods that score 80% on Permuted MNIST routinely drop to 40-50% on Split CIFAR-100.
  • Combine, don’t choose. A common recipe is ER + LwF (or ER + DER) plus a small EWC term. Each addresses a different failure mode.

The Fisher Spectrum View of Forgetting#

Continual learning: Split-CIFAR-100 forgetting matrix + 8-method comparison (avg/FWT/BWT).

The diagonal Fisher tells one story; the full Fisher tells a richer one. EWC keeps only $F_{ii} = \mathbb{E}\!\left[(\partial \log p_\theta(y \mid x) / \partial \theta_i)^2\right]$ — a single scalar per parameter, ignoring every off-diagonal coupling. That works in practice, and the spectrum of the full Fisher is the reason why.

$$F = \frac{1}{N}\sum_{n=1}^{N} \mathbf{g}_n \mathbf{g}_n^{\top}, \qquad \mathbf{g}_n = \nabla_\theta \log p_\theta(y_n \mid x_n)$$

is the $P \times P$ second-moment matrix of per-sample log-likelihood gradients. For any modern network $P$ is in the millions, so $F$ never gets materialised. But its eigenstructure is observable through matrix-vector products, and the picture that comes back is consistent across architectures: a heavy-tailed spectrum where a tiny number of directions hold almost all the curvature.

Top-$k$ eigenvalues by Lanczos iteration#

You do not need $F$ explicitly. Lanczos only needs matrix-vector products, and a Fisher-vector product follows directly from the definition: $F\mathbf{v} = \mathbb{E}[\mathbf{g}(\mathbf{g}^{\top}\mathbf{v})]$. It can stand in for a Hessian-vector product because near a minimum the Fisher approximates the Hessian (the Gauss-Newton view). The cost is a single forward pass, a single backward pass, and a dot product per sample.

import torch
import torch.nn as nn
import torch.nn.functional as F

def flat_grad(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.grad.detach().flatten() for p in model.parameters()
                      if p.grad is not None])

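def fisher_vector_product(model, loader, v, n_samples=500, device="cpu", seed=0):
    """Compute F v for a Monte Carlo estimate of the Fisher of p_theta(y|x).

    Labels are sampled from the model with a fixed seed so that repeated
    calls apply the same operator F, which Lanczos requires. For the same
    reason the loader should not shuffle.
    """
    model.eval()
    gen = torch.Generator().manual_seed(seed)   # CPU generator, reset per call
    Fv = torch.zeros_like(v)
    seen = 0
    for x, _ in loader:
        x = x.to(device)
        bs = x.size(0)
        for i in range(bs):
            model.zero_grad()
            logits = model(x[i:i+1])
            # Sample on the CPU so the seeded generator works on any device.
            probs = F.softmax(logits.detach(), dim=-1).cpu()
            y = torch.multinomial(probs, 1, generator=gen).squeeze(-1).to(device)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            g = flat_grad(model)
            Fv += g * (g @ v)
            seen += 1
            if seen >= n_samples:
                return Fv / seen
    return Fv / max(seen, 1)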

def lanczos_topk(model, loader, k=20, n_samples=500, device="cpu", iters=40):
    """Lanczos tridiagonalisation -> top-k eigenvalues of the Fisher."""
    P = sum(p.numel() for p in model.parameters() if p.requires_grad)
    V = torch.zeros(iters + 1, P, device=device)
    alphas = torch.zeros(iters, device=device)
    betas = torch.zeros(iters, device=device)
    V[0] = torch.randn(P, device=device)
    V[0] /= V[0].norm()
    m = iters                      # number of completed Lanczos steps
    for j in range(iters):
        w = fisher_vector_product(model, loader, V[j], n_samples, device)
        alphas[j] = w @ V[j]
        w = w - alphas[j] * V[j] - (betas[j-1] * V[j-1] if j > 0 else 0)
        # full reorthogonalisation -- numerically essential
        w = w - V[:j+1].T @ (V[:j+1] @ w)
        betas[j] = w.norm()
        if betas[j] < 1e-10:
            m = j + 1              # Krylov space exhausted; ignore unused steps
            break
        V[j+1] = w / betas[j]
    # Build the tridiagonal matrix from the m completed steps only, so an
    # early break does not pad the spectrum with spurious zeros.
    off = betas[:m - 1]
    T = torch.diag(alphas[:m]) + torch.diag(off, 1) + torch.diag(off, -1)
    eigs = torch.linalg.eigvalsh(T)
    return eigs.sort(descending=True).values[:k]

What the spectrum looks like#

Run this on a small CNN (two conv + two FC, ~50k params) trained on CIFAR-10 to ~70% accuracy and the top 20 eigenvalues come back roughly geometric: $\lambda_1 \approx 12.4$, $\lambda_2 \approx 7.1$, $\lambda_5 \approx 1.9$, $\lambda_{20} \approx 0.08$. Compare that to the trace $\mathrm{tr}(F) = \sum_i F_{ii}$, which is the sum of all eigenvalues. The top 5% of directions account for roughly 80% of the trace. The Fisher has low effective rank.

That is the structural fact behind EWC: most of the curvature lives in a few directions, and those directions correlate with high diagonal entries (since $F_{ii} = \sum_k \lambda_k v_{k,i}^2$). Diagonal EWC throws away the eigenvectors but keeps roughly the right total mass per coordinate. It is a coarse but cheap proxy for the rank-$k$ truncation that would be optimal.
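
The identity $F_{ii} = \sum_k \lambda_k v_{k,i}^2$ and the cost of dropping the eigenvectors can both be seen on a hand-built 3-parameter example. The eigenvalues and eigenvectors below are made up for illustration and are not taken from the CNN experiment.

```python
import math


def build_matrix(eigvals, eigvecs):
    """F = sum_k lambda_k * v_k v_k^T from an orthonormal eigenbasis."""
    n = len(eigvecs[0])
    F = [[0.0] * n for _ in range(n)]
    for lam, v in zip(eigvals, eigvecs):
        for i in range(n):
            for j in range(n):
                F[i][j] += lam * v[i] * v[j]
    return F


def full_penalty(F, d):
    """d^T F d: the quadratic form the full Fisher would charge."""
    n = len(d)
    return sum(d[i] * F[i][j] * d[j] for i in range(n) for j in range(n))


def diag_penalty(F, d):
    """sum_i F_ii d_i^2: what diagonal EWC charges for the same move."""
    return sum(F[i][i] * d[i] ** 2 for i in range(len(d)))


s = 1.0 / math.sqrt(2.0)
eigvals = [10.0, 1.0, 0.1]                                   # made-up spectrum
eigvecs = [[s, s, 0.0], [s, -s, 0.0], [0.0, 0.0, 1.0]]
F = build_matrix(eigvals, eigvecs)

diag = [F[i][i] for i in range(3)]
diag_from_spectrum = [sum(lam * v[i] ** 2 for lam, v in zip(eigvals, eigvecs))
                      for i in range(3)]
top1_share = eigvals[0] / sum(diag)          # share of the trace in direction 1

flat_direction = eigvecs[1]                  # eigenvalue 1.0: a fairly flat direction
print(diag, top1_share)
print(full_penalty(F, flat_direction), diag_penalty(F, flat_direction))
```

The diagonal reproduces the per-coordinate mass exactly, but it charges the fairly flat direction $v_2$ several times more than the full quadratic form does, because it cannot see that the first two coordinates are coupled.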

Where the high-eigenvalue mass lives#

Group the parameters by layer and ask: which layers contribute most to the top eigenvectors? Tracking $\sum_{k \le 20} v_k^{\top} P_\ell v_k$, where $P_\ell$ projects onto layer $\ell$’s coordinates, gives a clean per-layer importance signal across training. On the CIFAR CNN the pattern is consistent: conv1 carries little top-eigenvector mass (its Fisher is diffuse), conv2 and fc1 carry most of it (sharp directions concentrate there), and the final classifier fc2 is dominated by a handful of class-specific directions. Forgetting hits fc2 first, exactly as the spectrum predicts.

Bridge#

This points at an obvious EWC upgrade: instead of penalising along the diagonal, project the parameter delta onto the top-$k$ Fisher eigenvectors and penalise there. Memory cost is $kP$ (twenty floats per parameter for $k=20$) rather than $P$, but the penalty is much sharper. K-FAC EWC and its block-diagonal cousins pursue the same goal, a richer-than-diagonal curvature model, through Kronecker-factored blocks rather than an explicit eigenbasis; we will not implement them here, but the spectral picture above is the justification. A different route keeps diagonal EWC and fixes the accumulation problem with a discount instead, as in Online EWC.


FAQ#

How should I pick EWC’s $\lambda$?#

Start at 100 for MNIST-scale problems and 1-10 for CIFAR-scale. Run a sweep on the average accuracy and forgetting metric; the right $\lambda$ is the one that maximises Avg subject to Forgetting under your tolerance.
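
The selection rule is a constrained argmax over the sweep. A minimal sketch, assuming each sweep entry is a dict with hypothetical keys `lam`, `avg` and `forgetting`; the numbers are placeholders, not results.

```python
def pick_lambda(sweep, max_forgetting):
    """Return the lambda with the highest Avg among runs within tolerance.

    Returns None when no run satisfies the forgetting tolerance.
    """
    feasible = [run for run in sweep if run["forgetting"] <= max_forgetting]
    if not feasible:
        return None
    return max(feasible, key=lambda run: run["avg"])["lam"]


# Placeholder sweep results, not measurements.
sweep = [
    {"lam": 1.0, "avg": 0.70, "forgetting": 0.25},
    {"lam": 100.0, "avg": 0.82, "forgetting": 0.08},
    {"lam": 10000.0, "avg": 0.74, "forgetting": 0.02},
]
print(pick_lambda(sweep, max_forgetting=0.10))
```

Tightening the tolerance pushes the choice towards larger $\lambda$ at the price of average accuracy, which is the stability-plasticity trade-off in one line.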

Why does my EWC degenerate to “freeze everything” after many tasks?#

Accumulated Fisher matrices keep growing — every parameter eventually gets a large $\sum_t F_{t,i}$ . Use Online EWC with $\gamma \approx 0.95$ to forget old Fisher contributions exponentially.

EWC vs MAS vs SI — which one in practice?#

EWC for clean supervised tasks. MAS when you have unlabelled streams (it does not need labels). SI when you cannot afford a second pass over data after each task — it is the cheapest because it is computed online.

How big should the replay buffer be?#

On Split-CIFAR-style benchmarks the curve typically saturates around 200-500 samples per task. The interesting regime is “as small as you can afford” — if you can afford more, replay just keeps winning.

Continual learning vs multi-task learning — aren’t they the same?#

No. Multi-task has all data simultaneously, so you optimise a fixed objective; the only challenge is task balancing. CL has tasks one at a time and forbids re-access to past data; the challenge is forgetting. CL with infinite memory and no order constraint reduces to multi-task — this is exactly the joint upper bound in the benchmark figure.

Does replay leak data?#

Yes: the buffer is literal training data. In privacy-sensitive deployments use generative replay (train a generator on past data, then sample from it for replay). Dark experience replay does not help here, because it stores the inputs alongside the logits.

Why is Class-IL so much harder than Task-IL?#

Class-IL requires cross-task discrimination at inference time. Even with perfect retention on each task, the per-task softmax heads have not seen each other’s classes during training, so their logits are not calibrated against one another — new-class outputs typically swamp old-class ones. iCaRL (Rebuffi et al., 2017) addresses exactly this with a nearest-class-mean classifier on top of the learned features.
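
A nearest-class-mean classifier sidesteps the calibration problem because it compares distances in feature space instead of logits from separately trained heads. The sketch below shows only that classification rule on made-up 2-D features; it is not the full iCaRL method, which also covers exemplar selection and representation learning. The function names are illustrative.

```python
def class_means(features, labels):
    """Mean feature vector per class."""
    sums, counts = {}, {}
    for x, y in zip(features, labels):
        if y not in sums:
            sums[y] = [0.0] * len(x)
            counts[y] = 0
        sums[y] = [s + v for s, v in zip(sums[y], x)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def ncm_predict(x, means):
    """Predict the class whose mean is closest in squared Euclidean distance."""
    def sq_dist(a, b):
        return sum((u - v) ** 2 for u, v in zip(a, b))
    return min(means, key=lambda c: sq_dist(x, means[c]))


# Made-up features: class 0 from an old task, class 1 from a new task.
features = [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [12.0, 10.0]]
labels = [0, 0, 1, 1]
means = class_means(features, labels)
print(means, ncm_predict([1.0, 1.0], means), ncm_predict([9.0, 9.0], means))
```

Adding a class only adds one mean vector, so old and new classes are always scored on the same scale.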


Summary#

Catastrophic forgetting is a structural property of SGD on a single shared parameter vector, derivable from gradient interference and the geometry of high-dimensional loss landscapes. Solutions cluster into four families:

| Family | Mechanism | Typical example | Trade-off |
|---|---|---|---|
| Regularisation | Anchor important parameters | EWC, MAS, SI, LwF | No memory cost, weakest on class-IL |
| Replay | Re-train on past samples | ER, A-GEM, DER++ | Strongest in practice; needs storage |
| Dynamic architectures | Add capacity per task | Progressive Nets, PackNet, SupSup | Zero forgetting; growing model |
| Meta-learning | Learn how to continue learning | OML, MER | Powerful but costly to meta-train |

The takeaway for practitioners is direct: if you can store any data at all, run a small reservoir buffer with experience replay; layer in LwF for free regularisation through the old model snapshot; only reach for EWC/MAS when memory is impossible. The next part picks up cross-lingual transfer, where the “tasks” are languages and the same machinery — careful sharing, careful protection — carries over.


References#

  • Kirkpatrick, J., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS, 114(13), 3521-3526.
  • Schwarz, J., et al. (2018). Progress & Compress: A scalable framework for continual learning. ICML.
  • Aljundi, R., et al. (2018). Memory Aware Synapses: Learning what (not) to forget. ECCV.
  • Zenke, F., Poole, B., & Ganguli, S. (2017). Continual learning through synaptic intelligence. ICML.
  • Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI, 40(12), 2935-2947.
  • Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.
  • Chaudhry, A., et al. (2019). Efficient lifelong learning with A-GEM. ICLR.
  • Buzzega, P., et al. (2020). Dark experience for general continual learning. NeurIPS.
  • Rusu, A. A., et al. (2016). Progressive neural networks. arXiv:1606.04671.
  • Mallya, A., & Lazebnik, S. (2018). PackNet: Adding multiple tasks to a single network by iterative pruning. CVPR.
  • Wortsman, M., et al. (2020). Supermasks in superposition. NeurIPS.
  • Rebuffi, S.-A., et al. (2017). iCaRL: Incremental classifier and representation learning. CVPR.
  • van de Ven, G. M., & Tolias, A. S. (2019). Three scenarios for continual learning. arXiv:1904.07734.
  • Vitter, J. S. (1985). Random sampling with a reservoir. ACM TOMS, 11(1), 37-57.