系列 · 迁移学习 · 第 10 篇

迁移学习(十):持续学习

从梯度干扰和 Fisher 信息出发推导灾难性遗忘的成因,系统讲解 EWC、MAS、SI、LwF 四种正则化方法,Replay/A-GEM 重放方法,动态架构与三大 CL 场景的差异,并附 Permuted MNIST 上的 EWC 从零实现。

我今年能自学吉他,同时仍记得如何骑自行车;但神经网络却无法做到这一点。先在 CIFAR-10 上微调一个视觉模型,再在 SVHN 上继续微调——此时若重新在 CIFAR-10 上测试,准确率会骤降至接近随机猜测的水平。这一现象被称为灾难性遗忘(catastrophic forgetting)。如何让模型像人一样,在源源不断的任务流 $\mathcal{T}_1, \mathcal{T}_2, \ldots$ 中持续学习新知识,又不忘记旧技能,正是**持续学习(continual learning,CL)**要解决的核心问题——且必须在无法访问历史任务数据的前提下实现。

本文首先从过参数化网络中 SGD 的优化几何特性出发,阐明遗忘是一种结构性现象,而非工程实现上的缺陷。接着依次介绍四类主流方法:正则化、重放、动态架构和元学习,每种方法从数学原理、直观机制与工程实践三个角度展开分析。最后给出一份完整的 EWC 实现代码,可以直接在 Permuted MNIST 上运行。

迁移学习(十):持续学习 — 章节概览图


你将学到什么#

  • 持续学习的定义和三种场景:Task-IL、Domain-IL 和 Class-IL
  • 为什么用 SGD 训练新任务会破坏旧任务的知识:从梯度干扰和损失曲面的角度解释
  • Fisher 信息作为衡量参数重要性的理论基础
  • 正则化方法 EWC、MAS、SI 和 LwF 的原理与区别
  • 重放方法 Experience Replay、GEM 和 A-GEM 的几何解释,特别是 A-GEM 的投影几何
  • 动态架构 Progressive Networks 和 PackNet 的设计权衡
  • 标准评估指标:平均准确率、平均遗忘率、前向迁移和后向迁移
  • 一个完整的 EWC 实现,并在 Permuted MNIST 上进行评估

前置知识#

  • 神经网络训练、梯度、交叉熵损失
  • 对 Fisher 信息矩阵有基本了解
  • 迁移学习基础(本系列第 1 到 6 篇)

问题设定#

任务按顺序到达:$\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_T$ 。训练 $\mathcal{T}_t$ 时,模型只能看到当前任务的数据集 $\mathcal{D}_t = \{(x_i, y_i)\}$ ,无法访问之前任务的数据 $\mathcal{D}_{<t}$ 。所有 $T$ 个任务训练完成后,模型需要在每个见过的任务上接受测试。

van de Ven 和 Tolias (2019)将持续学习划分为三类典型场景,其难度各不相同;掌握这三类场景,是理解持续学习本质的关键。

持续学习的三种场景:Task-IL、Domain-IL、Class-IL

  • Task-IL。测试阶段已知任务 ID,模型可为每个任务配置独立的输出头,主干网络则共享并仅用于特征提取。
  • Domain-IL。标签空间保持不变,但输入数据分布发生偏移(例如从标准 MNIST 逐步变为旋转或加噪版本)。模型只有一个输出头,测试时不知道样本来自哪个任务。
  • Class-IL。每个新任务引入一组新类别,模型需对所有已见类别执行统一分类,且测试时无法获知样本所属的任务来源。这是三类设定中最难的,也是实际部署中最常见的场景。
$$ \mathrm{Avg} \;=\; \frac{1}{T}\sum_{j=1}^{T} R_{T,j}, \qquad \mathrm{Forgetting} \;=\; \frac{1}{T-1}\sum_{j=1}^{T-1}\!\left(\max_{t \le T} R_{t,j} - R_{T,j}\right). $$ $$ \mathrm{BWT} \;=\; \frac{1}{T-1}\sum_{j=1}^{T-1} (R_{T,j} - R_{j,j}), \qquad \mathrm{FWT} \;=\; \frac{1}{T-1}\sum_{j=2}^{T} (R_{j-1,j} - b_j), $$

其中 $b_j$ 是任务 $j$ 的随机或未训练基线。BWT < 0 表示发生遗忘;BWT > 0 则代表罕见且宝贵的“正向后向迁移”——即学习新任务反而提升了旧任务的性能。FWT > 0 则说明早期任务塑造的表征对后续任务有零样本帮助。

迁移矩阵 R[i,j] 与 FWT/BWT 区域

遗忘是怎么发生的#

迁移学习(十):持续学习 — 章节小结图

梯度干扰#

$$\Delta \mathcal{L}_1 \approx -\eta\, \mathbf{g}_1 \cdot \mathbf{g}_2.$$

如果 $\mathbf{g}_1 \cdot \mathbf{g}_2 < 0$ ,那么每一步都会让任务 1 的损失增加。在高维网络中,无关任务的梯度通常近似正交;但一旦负余弦相似度的方向占比足够高,经过数千步更新后,旧任务的性能便会严重下降。

损失曲面视角#

任务 1 的最优解 $\theta_1^{*}$ 和任务 2 的最优解 $\theta_2^{*}$ 通常位于不同的低损区域。从 $\theta_1^{*}$ 开始对任务 2 执行 SGD,如果没有机制把参数拉回任务 1 的区域,参数就会直接离开。下图展示了该现象:朴素基线导致所有旧任务准确率显著下降,而 EWC 和重放方法则通过不同机制有效缓解了这一问题。

5 任务序列上的灾难性遗忘:Baseline / EWC / Replay 对比

Fisher 信息 = 参数重要性#

$$F(\theta) \;=\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim p_\theta(\cdot \mid x)}\!\left[\nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^{\top}\right].$$

在局部最优解处, Fisher 矩阵在局部最优解处等于负对数似然函数的(半正定) Hessian 矩阵。因此,对角元 $F_i$ 表示扰动 $\theta_i$ 时损失上升的陡峭程度。$F_i$ 大,说明 $\theta_i$ 是任务的关键参数,需要保护;$F_i$ 小,说明损失在这个方向上是平坦的,可以放心调整去适应新任务。后文所述各类正则化方法,本质上都在解决同一个核心问题:哪些参数对旧任务至关重要,需加以保护?

正则化方法#

弹性权重巩固 (EWC)#

$$\mathcal{L}_A(\theta) \;\approx\; \mathcal{L}_A(\theta_A^{*}) + \tfrac{1}{2} (\theta - \theta_A^{*})^{\top} F_A\, (\theta - \theta_A^{*}).$$ $$\boxed{\;\mathcal{L}(\theta) \;=\; \mathcal{L}_B(\theta) \;+\; \frac{\lambda}{2} \sum_i F_{A,i}\, (\theta_i - \theta_{A,i}^{*})^{2}\;}$$

从几何上看, EWC 在旧最优解 $\theta_A^{*}$ 周围挖了一个二次井,井底曲率匹配了旧任务损失的真实曲率。在 $F_i$ 小的方向上(旧损失本来就很平),更新代价低;在 $F_i$ 大的方向上(旧损失敏感),更新代价高。

EWC 惩罚项:参数空间中以 θA* 为中心的二次井

$$\tilde F_t \;=\; \gamma\, \tilde F_{t-1} + F_t, \qquad \theta^{*}_{1:t} = \theta^{*}_t.$$

$\lambda$ 的选择很关键。太小会导致遗忘问题;太大又会让模型变得“僵化”,无法学习新任务。对于 Permuted MNIST 类型的任务,通常取 $\lambda \in [10^2, 10^4]$ ;而对于 Split CIFAR 这类更复杂的任务,则取 $\lambda \in [1, 10]$

记忆感知突触 (MAS)#

$$\Omega_i \;=\; \mathbb{E}_{x}\!\left[\, \left| \frac{\partial \, \tfrac{1}{2}\|f(x;\theta)\|_2^{2}}{\partial \theta_i} \right| \, \right].$$

该指标完全无监督:既可在无标签数据上计算,也可直接作用于测试数据流。在实际部署中,这是一个很大的工程优势。

突触智能 (SI)#

Zenke 等人在 2017 年提出了一种方法,在训练过程中沿 SGD 轨迹计算“重要性”,具体是对 $-g_i \cdot \dot\theta_i$ 做路径积分。这种方法不需要第二遍数据,开销直接融入优化器中。

无遗忘学习 (LwF)#

$$ \mathcal{L} \;=\; \underbrace{\mathcal{L}_{\text{CE}}\bigl(y,\, z^{\text{new}}_{\text{new heads}}\bigr)}_{\text{学新任务}} \;+\; \alpha\, \underbrace{T^{2}\, \mathrm{KL}\!\bigl(\sigma(z^{\text{old}}/T)\,\Vert\,\sigma(z^{\text{new}}_{\text{old heads}}/T)\bigr)}_{\text{别动旧输出}}. $$

LwF 无需访问旧任务数据,也无需计算 Fisher 矩阵,仅需保存旧模型的参数快照。温度 $T$ (通常为 2-4)会软化分布,让蒸馏信号承载比 argmax 更丰富的形状信息。

LwF:来自冻结旧模型的知识蒸馏

重放方法#

另一种思路:保留一小部分过去的样本。即使缓冲区容量很小,只要将历史样本混合到每个 mini-batch 中,其效果已是当前持续学习(CL)领域最强的基线方法之一。

Experience Replay 流水线

经验回放 (ER)#

$$\mathcal{L} \;=\; \mathcal{L}_{\text{new}}(B_{\text{new}}) \;+\; \alpha\, \mathcal{L}_{\text{mem}}(B_{\text{mem}}),$$

接着将一部分新样本写回 $\mathcal{M}$ 。**蓄水池采样(Reservoir Sampling, Vitter 1985)**用固定大小的缓冲区对整个历史数据流进行均匀采样;类别均衡采样则确保每个类别都被覆盖到。实验表明,只要 $|B_{\text{mem}}| = |B_{\text{new}}|$ ,在类似 Split-CIFAR 的基准测试中,就能恢复大部分联合训练的准确率。

GEM 和 A-GEM#

$$ \min_{\tilde{\mathbf{g}}} \tfrac{1}{2}\|\tilde{\mathbf{g}} - \mathbf{g}_{\text{new}}\|^{2} \quad \text{s.t.} \quad \tilde{\mathbf{g}} \cdot \mathbf{g}_{k} \;\ge\; 0 \quad \forall k = 1, \ldots, t-1. $$ $$\tilde{\mathbf{g}} \;=\; \mathbf{g}_{\text{new}} \;-\; \frac{\mathbf{g}_{\text{new}} \cdot \mathbf{g}_{\text{ref}}}{\|\mathbf{g}_{\text{ref}}\|^{2}}\, \mathbf{g}_{\text{ref}} \quad \text{当 } \mathbf{g}_{\text{new}} \cdot \mathbf{g}_{\text{ref}} < 0,$$

否则直接取 $\tilde{\mathbf{g}} = \mathbf{g}_{\text{new}}$ 。该方法仅需一次额外的前向与反向传播,以及一次点积运算,计算开销约为 GEM 的千分之一,而性能几乎相当。

DER 和 DER++#

Buzzega 等(2020)不仅存储输入数据,还保存样本入库时模型输出的 logits。重放损失被设计为 logits 匹配的均方误差(MSE),可以与原始标签的交叉熵结合使用。 DER++ 是目前单模型方法中,在大多数 CL 基准测试上表现最强的基线之一。

动态架构#

与其把所有任务硬塞进固定的参数预算,不如让模型自己扩展。

  • Progressive Networks(Rusu 等, 2016):每完成一个任务,就把网络冻结,再加一列处理新任务。新列通过侧向连接从冻结的列中提取特征。这种方法天生不会遗忘,但参数量和推理成本会随着任务数 $T$ 线性增长。
  • PackNet(Mallya 和 Lazebnik, 2018):每完成一个任务,就剪枝出一个稀疏权重子集并冻结。后续任务只能用剩下的部分。模型大小固定,但可用容量会随着任务增加逐渐减少,最终性能会崩塌。
  • Supermasks in Superposition(Wortsman 等, 2020):参数随机初始化后直接冻结,每个任务只学习一个二值掩码。每个任务的存储开销是每参数 1 比特,但效果却意外地接近训练好的基线模型。

这种权衡不可避免:实现零遗忘,要么需增加参数量,要么会牺牲可用容量。在实际应用中,真正具备落地潜力的方案通常是混合方法:固定主干网络,再为每个任务添加轻量级适配器(参见本系列第 9 篇关于 PEFT 的内容)。

实现:从零开始写 EWC#

下面是一个简洁的 PyTorch 实现,做了三件事:

  1. 每个任务训练完后计算经验 Fisher 对角矩阵;
  2. 存储 $\theta^{*}$ 和 Fisher 对角矩阵;
  3. 在后续任务训练时将 EWC 惩罚加入损失函数。
    该代码开箱即用,可直接在 Permuted MNIST 数据集上运行。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

class EWC:
    """Elastic Weight Consolidation.

    每完成一个任务,调用 `consolidate(dataloader)` 快照 theta* 和经验 Fisher 对角矩阵。
    后续任务训练时,在损失函数中加上 `lambda * ewc.penalty()`。
    """

    def __init__(self, model: nn.Module, device: str = "cpu"):
        self.model = model
        self.device = device
        self.fisher: list[dict[str, torch.Tensor]] = []
        self.opt_params: list[dict[str, torch.Tensor]] = []

    @torch.enable_grad()
    def _empirical_fisher(self, dataloader, n_samples: int = 1024
                          ) -> dict[str, torch.Tensor]:
        """Diagonal Fisher: E[(d log p(y|x; theta) / d theta)^2]."""
        self.model.eval()
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()
                  if p.requires_grad}

        seen = 0
        for x, _ in dataloader:
            x = x.to(self.device)
            self.model.zero_grad()
            logits = self.model(x)
            # 从模型预测分布采样 y -- 这是"真 Fisher";用真实标签则是"经验 Fisher"。
            probs = F.softmax(logits, dim=-1)
            y = torch.multinomial(probs, 1).squeeze(-1)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2 * x.size(0)
            seen += x.size(0)
            if seen >= n_samples:
                break

        for n in fisher:
            fisher[n] /= max(seen, 1)
        return fisher

    def consolidate(self, dataloader, n_samples: int = 1024) -> None:
        """每个任务训练结束后调用。"""
        self.fisher.append(self._empirical_fisher(dataloader, n_samples))
        self.opt_params.append(
            {n: p.detach().clone() for n, p in self.model.named_parameters()
             if p.requires_grad}
        )

    def penalty(self) -> torch.Tensor:
        """Sum_t Sum_i F_{t,i} * (theta_i - theta*_{t,i})^2."""
        if not self.fisher:
            return torch.tensor(0.0, device=self.device)
        loss = torch.tensor(0.0, device=self.device)
        for F_t, theta_t in zip(self.fisher, self.opt_params):
            for n, p in self.model.named_parameters():
                if n in F_t:
                    loss = loss + (F_t[n] * (p - theta_t[n]) ** 2).sum()
        return 0.5 * loss

def train_task(model, ewc, loader, optimiser, *, ewc_lambda: float,
               epochs: int, device: str) -> None:
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(model(x), y) + ewc_lambda * ewc.penalty()
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

@torch.no_grad()
def evaluate(model, loader, device: str) -> float:
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(-1) == y).sum().item()
        total += y.size(0)
    return 100.0 * correct / total

调用示例(Permuted MNIST):

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
ewc = EWC(model, device=device)
for t, (train_loader, test_loader) in enumerate(tasks):
    opt = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    train_task(model, ewc, train_loader, opt,
               ewc_lambda=400.0 if t > 0 else 0.0,
               epochs=5, device=device)
    ewc.consolidate(train_loader)               # 快照 theta* 和 Fisher
    accs = [evaluate(model, t_loader, device)   # 检查所有已学任务
            for _, t_loader in tasks[:t + 1]]
    print(f"After task {t + 1}: {accs}")

两个关键实现细节:

  1. 真 Fisher vs 经验 Fisher
    $p_\theta(\cdot \mid x)$ 采样 $y$ 得到的是理论推导中的“真 Fisher”;直接用数据集标签得到的是“经验 Fisher”。两者在实践中都有效,但经验 Fisher 在标签干净时略强一些,工程上也更简单。

  2. Fisher 的计算时机
    必须在任务训练结束后计算 Fisher,这时 $\theta \approx \theta_t^{*}$ ,二阶近似最准确。

实证对比#

下图展示了两个经典 CL 基准测试中的代表性数据。结论很清晰:如果有内存预算,重放方法完胜正则化方法;如果没有内存, EWC 和 LwF 依然远超朴素 SGD;但目前还没有方法能达到联合训练的上限。

CL 基准:Permuted MNIST 与 Split CIFAR

三个关键点:

  • 重放是最强的单一技巧。每个任务只存 200 个样本,通常就能在 class-IL 上超越所有纯正则化方法。
  • Class-IL 比 Task-IL 难得多。在 Permuted MNIST 上能跑到 80% 的方法,到了 Split CIFAR-100 上往往只剩 40-50%。
  • 组合优于选择。实际生产系统一般会用 ER + LwF (或 ER + DER),再加一点 EWC。每种方法解决的是不同的失败模式。

遗忘的 Fisher 谱视角#

持续学习:Split-CIFAR-100 遗忘矩阵 + 8 种方法对比 (avg/FWT/BWT)。

对角 Fisher 讲述了一个故事;完整 Fisher 则揭示了更丰富的图景。EWC 仅保留 $F_{ii} = \mathbb{E}\!\left[(\partial \log p_\theta(y \mid x) / \partial \theta_i)^2\right]$ —— 每个参数一个标量,忽略了所有非对角耦合。这在实践中有效,而完整 Fisher 的谱结构正是其背后的原因。

$$F = \frac{1}{N}\sum_{n=1}^{N} \mathbf{g}_n \mathbf{g}_n^{\top}, \qquad \mathbf{g}_n = \nabla_\theta \log p_\theta(y_n \mid x_n)$$

这是一个 $P \times P$ 的 Gram 矩阵,由每个样本的对数似然梯度构成。对于现代网络,$P$ 通常达数百万,因此 $F$ 从不显式构造。但通过 Hessian-向量积可观察其特征结构,结果在不同架构下高度一致:谱呈重尾分布,极少数方向承载了几乎全部曲率。

通过幂迭代求前 $k$ 个特征值#

无需显式构造 $F$ 。每个 Hessian-向量积可用 Fisher-向量积替代,因为在极小值点处,由 Gauss-Newton 恒等式有 $F\mathbf{v} = \mathbb{E}[\mathbf{g}(\mathbf{g}^{\top}\mathbf{v})]$ 。这只需一次前向传播、一次反向传播和每个样本的一次点积。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import torch.nn.functional as F

def flat_grad(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.grad.detach().flatten() for p in model.parameters()
                      if p.grad is not None])

def fisher_vector_product(model, loader, v, n_samples=500, device="cpu"):
    """计算 F v,其中 F 是 p_theta(y|x) 的经验 Fisher 矩阵。"""
    model.eval()
    Fv = torch.zeros_like(v)
    seen = 0
    for x, _ in loader:
        x = x.to(device)
        bs = x.size(0)
        for i in range(bs):
            model.zero_grad()
            logits = model(x[i:i+1])
            probs = F.softmax(logits, dim=-1)
            y = torch.multinomial(probs, 1).squeeze(-1)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            g = flat_grad(model)
            Fv += g * (g @ v)
            seen += 1
            if seen >= n_samples:
                return Fv / seen
    return Fv / seen

def lanczos_topk(model, loader, k=20, n_samples=500, device="cpu", iters=40):
    """Lanczos 三对角化 → Fisher 矩阵的前 k 个特征值。"""
    P = sum(p.numel() for p in model.parameters() if p.requires_grad)
    V = torch.zeros(iters + 1, P, device=device)
    alphas = torch.zeros(iters, device=device)
    betas = torch.zeros(iters, device=device)
    V[0] = torch.randn(P, device=device); V[0] /= V[0].norm()
    for j in range(iters):
        w = fisher_vector_product(model, loader, V[j], n_samples, device)
        alphas[j] = w @ V[j]
        w = w - alphas[j] * V[j] - (betas[j-1] * V[j-1] if j > 0 else 0)
        # 完全重正交化 —— 数值稳定性必需
        w = w - V[:j+1].T @ (V[:j+1] @ w)
        betas[j] = w.norm()
        if betas[j] < 1e-10: break
        V[j+1] = w / betas[j]
    T = torch.diag(alphas) + torch.diag(betas[:-1], 1) + torch.diag(betas[:-1], -1)
    eigs = torch.linalg.eigvalsh(T)
    return eigs.sort(descending=True).values[:k]

谱的实际形态#

在小型 CNN(两层卷积 + 两层全连接,约 5 万参数)上训练 CIFAR-10 至约 70% 准确率后,前 20 个特征值大致呈几何衰减:$\lambda_1 \approx 12.4$$\lambda_2 \approx 7.1$$\lambda_5 \approx 1.9$$\lambda_{20} \approx 0.08$ 。对比迹 $\mathrm{tr}(F) = \sum_i F_{ii}$ (即所有特征值之和),前 5% 的方向贡献了约 80% 的迹。Fisher 矩阵具有低有效秩

这正是 EWC 背后的结构性事实:大部分曲率集中在少数方向上,且这些方向与高对角元相关(因 $F_{ii} = \sum_k \lambda_k v_{k,i}^2$ )。对角 EWC 舍弃了特征向量,但大致保留了每个坐标的总质量。这是一种粗糙但廉价的近似,相当于最优的秩-$k$ 截断。

高特征值质量所在位置#

按层分组参数并探究:哪些层对前几个特征向量贡献最大?通过追踪 $\sum_{k \le 20} v_k^{\top} P_\ell v_k$ (其中 $P_\ell$ 投影到第 $\ell$ 层坐标)可得到清晰的逐层重要性信号。在 CIFAR CNN 上,模式一致:conv1 几乎不承载前特征向量质量(其 Fisher 分布弥散),conv2fc1 承载大部分(尖锐方向集中于此),而最终分类器 fc2 由少量类别特定方向主导。遗忘首先影响 fc2,与谱预测完全吻合。

过渡#

这自然引向 EWC 的明显升级:不再沿对角线施加惩罚,而是将参数增量投影到前 $k$ 个 Fisher 特征向量上并在该子空间施加惩罚。内存开销为 $kP$$k=20$ 时每个参数需 20 个浮点数),而非 $P$ ,但惩罚更精准。K-FAC EWC 及其块对角变体正是此思想;此处不实现它们,但上述谱图景为其提供了理论依据。下一节采取不同路径——保留对角 EWC,转而通过折扣机制解决累积问题。


常见问题#

怎么选 EWC 的 $\lambda$#

MNIST 规模的问题从 100 开始, CIFAR 规模的从 1 到 10 起步。跑一个超参数搜索,同时看平均准确率和遗忘指标。合适的 $\lambda$ 是在你容忍的遗忘范围内让平均准确率达到最高的那个。

为什么我的 EWC 跑了很多任务后变成“冻结所有参数”了?#

累计的 Fisher 矩阵一直在增长,每个参数最终都会积累出很大的 $\sum_t F_{t,i}$ 。换成 Online EWC,设置 $\gamma \approx 0.95$ ,让旧的 Fisher 贡献指数衰减就行。

实际用 EWC、 MAS 还是 SI?#

干净的监督学习任务用 EWC;无标签数据流用 MAS (它不需要标签);如果连任务结束后再扫一遍数据计算 Fisher 都嫌贵,那就用 SI——它最便宜,因为重要性是在线计算的。

重放缓冲区应该多大?#

Split-CIFAR 类型的基准测试中,曲线通常在每任务 200 到 500 样本时趋于饱和。有趣的地方在于“能存多少就存多少”——只要资源允许,重放的效果会一直提升。

持续学习和多任务学习不是一回事吗?#

不是。多任务学习有所有数据同时可用,优化目标固定,唯一挑战是任务间的平衡。持续学习是任务逐个到达且不能访问旧数据,核心挑战是遗忘。如果持续学习有无限内存且没有顺序限制,就会退化成多任务学习——这正是基准图里 Joint 上界的意义。

重放会不会泄露数据?#

会,缓冲区里存的就是真实的训练数据。隐私敏感的场景可以用 生成式重放(先用旧数据训练一个生成器,然后从中采样进行重放),或者 dark experience(只存 logits,不存输入)。

为什么 Class-IL 比 Task-IL 难得多?#

Class-IL 在推理时需要在跨任务的所有类别中做区分。即使每个任务都记得很好,各任务的 softmax 头之间从未一起训练过, logits 的量级也没对齐——新类别的输出往往会压过旧类别。 iCaRL (Rebuffi 等, 2017)就是为解决这个问题设计的:它在学到的特征上加了一个最近类均值分类器,绕过了 softmax 校准问题。

总结#

灾难性遗忘不是工程上的缺陷,而是 SGD 在共享参数向量上优化时的结构性结果。它可以从梯度干扰和高维损失曲面的几何特性中推导出来。解决方法大致分为四类:

流派机制代表取舍
正则化锚定重要参数EWC、 MAS、 SI、 LwF不占内存,但在 class-IL 上效果较弱
重放利用旧样本重新训练ER、 A-GEM、 DER++实际效果最好;需要存储数据
动态架构每个任务增加容量Progressive Net、 PackNet、 SupSup零遗忘;模型规模会变大
元学习学习“如何持续学习”OML、 MER能力强,但元训练成本高

对工程师来说,结论很简单:只要能存下数据,就用一个小缓冲池做经验重放;再加一个 LwF,利用旧模型快照免费获得正则化效果;只有在完全无法使用内存时,才考虑 EWC 或 MAS。接下来讲跨语言迁移——任务变成了语言,但“小心共享、小心保护”的思路依然适用。


参考文献#

  • Kirkpatrick, J., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS, 114(13), 3521-3526.
  • Schwarz, J., et al. (2018). Progress & Compress: A scalable framework for continual learning. ICML.
  • Aljundi, R., et al. (2018). Memory Aware Synapses: Learning what (not) to forget. ECCV.
  • Zenke, F., Poole, B., & Ganguli, S. (2017). Continual learning through synaptic intelligence. ICML.
  • Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI, 40(12), 2935-2947.
  • Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.
  • Chaudhry, A., et al. (2019). Efficient lifelong learning with A-GEM. ICLR.
  • Buzzega, P., et al. (2020). Dark experience for general continual learning. NeurIPS.
  • Rusu, A. A., et al. (2016). Progressive neural networks. arXiv:1606.04671 .
  • Mallya, A., & Lazebnik, S. (2018). PackNet: Adding multiple tasks to a single network by iterative pruning. CVPR.
  • Wortsman, M., et al. (2020). Supermasks in superposition. NeurIPS.
  • Rebuffi, S.-A., et al. (2017). iCaRL: Incremental classifier and representation learning. CVPR.
  • van de Ven, G. M., & Tolias, A. S. (2019). Three scenarios for continual learning. arXiv:1904.07734 .
  • Vitter, J. S. (1985). Random sampling with a reservoir. ACM TOMS, 11(1), 37-57.
本系列

迁移学习 12 篇

  1. 01 迁移学习(一):基础与核心概念
  2. 02 迁移学习(二):预训练与微调
  3. 03 迁移学习(三):域适应
  4. 04 迁移学习(四):小样本学习
  5. 05 迁移学习(五):知识蒸馏
  6. 06 迁移学习(六):多任务学习
  7. 07 迁移学习(七):零样本学习
  8. 08 迁移学习(八):多模态迁移
  9. 09 迁移学习(九):参数高效微调
  10. 10 迁移学习(十):持续学习 当前
  11. 11 迁移学习(十一):跨语言迁移
  12. 12 迁移学习(十二):工业应用与最佳实践

读有所得?

GitHub 关注我 → 新文周更

GitHub