迁移学习(十):持续学习
从梯度干扰和 Fisher 信息出发推导灾难性遗忘的成因,系统讲解 EWC、MAS、SI、LwF 四种正则化方法,Replay/A-GEM 重放方法,动态架构与三大 CL 场景的差异,并附 Permuted MNIST 上的 EWC 从零实现。
人去年学会了弹吉他,今天还能骑自行车。神经网络做不到。让一个视觉模型先在 CIFAR-10 上微调,再拿同一个模型去微调 SVHN,回过头测 CIFAR-10——准确率会跌到接近随机猜的水平。这就是灾难性遗忘(catastrophic forgetting)。如何让模型像人一样在源源不断到来的任务流 $\mathcal{T}_1, \mathcal{T}_2, \ldots$ 中持续吸收新知识,又不丢掉旧本事,正是**持续学习(continual learning,CL)**这个领域要回答的问题——而且要在"看不到过去数据"这个硬约束下回答。
本文不打算停留在"列方法清单"。我们先从 SGD 在过参数化网络上的几何结构出发,推导遗忘为什么是结构性的、而不是工程 bug;再依次走过四类解法——正则化、重放、动态架构、元学习——把数学推导、几何直觉和工程权衡讲清楚;最后给出一份能直接跑在 Permuted MNIST 上的 EWC 实现。
你将学到
- 持续学习的问题定义和三种场景(Task-IL、Domain-IL、Class-IL)
- SGD 在新任务上为什么必然破坏旧任务知识:梯度干扰与损失曲面视角
- Fisher 信息为什么是衡量"参数重要性"的自然选择
- 正则化方法 EWC、MAS、SI、LwF 的推导与差别
- 重放方法 ER、GEM、A-GEM 的几何含义
- 动态架构 Progressive Networks、PackNet 的取舍
- 标准评估指标:平均准确率、平均遗忘度、前向/后向迁移
- 一份从零写出的、可直接复现的 EWC 实现
预备知识
- 神经网络训练、梯度、交叉熵
- Fisher 信息矩阵的基本概念
- 迁移学习基础(本系列前 6 篇)
问题设定
任务按顺序到达:$\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_T$。学习 $\mathcal{T}_t$ 时只能看到 $\mathcal{D}_t = \{(x_i, y_i)\}$,不能回访 $\mathcal{D}_{ van de Ven 和 Tolias(2019)把"难度"拆成三档,理解 CL 必先理解这三种场景: 评估指标。 记 $R_{i,j}$ 为"训完任务 $i$ 后在任务 $j$ 上的准确率"。$T$ 个任务全部训完后: Lopez-Paz 和 Ranzato(2017)又补了两个,专门衡量"任务之间的迁移"而非"是否记得住": 其中 $b_j$ 是任务 $j$ 的随机/未训练基线。BWT < 0 是遗忘;BWT > 0 是少见且珍贵的"正向后向迁移"——学新任务反而提升了旧任务表现。FWT > 0 则说明早期任务塑造的表征对后续任务有零样本帮助。 写出两个任务的梯度 $\mathbf{g}_1 = \nabla_\theta \mathcal{L}_1$,$\mathbf{g}_2 = \nabla_\theta \mathcal{L}_2$。在任务 2 上做一步 SGD,对任务 1 损失的一阶变化是 只要 $\mathbf{g}_1 \cdot \mathbf{g}_2 < 0$,每一步都在抬高任务 1 的损失。在高维网络里两个无关任务的梯度往往近似正交,但只要"夹角钝"的方向占比不可忽略,几千步下去旧任务就被推出低损区。 任务 1 的最优点 $\theta_1^{*}$ 和任务 2 的最优点 $\theta_2^{*}$ 通常分属不同的低损盆地。从 $\theta_1^{*}$ 出发对任务 2 做 SGD,如果没有任何机制把参数往任务 1 的盆地里拉,参数就会径直走出去。下面这张图就是典型现象:朴素基线把所有旧任务的准确率都拖了下去,而 EWC 和重放则各以不同方式扛住了。 模型预测分布 $p_\theta(y \mid x)$ 的 Fisher 信息矩阵是 在局部最优点,Fisher 等于负对数似然的(半正定)Hessian,所以对角元 $F_i$ 衡量的是"扰动 $\theta_i$ 时损失上升得多陡"。$F_i$ 大 ⇒ 参数 $\theta_i$ 是任务的承重墙,要保护;$F_i$ 小 ⇒ 损失在这个方向是平的,可以安心改去服务新任务。后面所有正则化方法的差异,本质上都是"如何挑出该保护的参数"这个问题的不同回答。 Kirkpatrick 等(2017)用一个高斯近似任务 A 训完后的参数后验:均值是 $\theta_A^{*}$,精度(precision)正比于 Fisher 对角。把任务 A 的负对数似然在 $\theta_A^{*}$ 处展开到二阶: 把这个二次项作为正则加到任务 B 的目标上,就得到 EWC: 几何上,EWC 在旧最优点 $\theta_A^{*}$ 周围挖了一口二次井,井底曲率正好等于旧任务损失的真实曲率:在 $F_i$ 小的方向(旧损失本来就平),更新便宜;在 $F_i$ 大的方向(旧损失敏感),更新昂贵。 多任务情形下有两种做法:要么把历次 Fisher 累加 $F_{1:t} = \sum_{k \le t} F_k$;要么用 Online EWC(Schwarz 等,2018),用一个折扣因子 $\gamma \in (0, 1)$ 让旧 Fisher 指数衰减: $\lambda$ 怎么选很关键。太小,遗忘照旧;太大,模型变得"过分稳重"——任何新任务都学不进去(rigidity)。Permuted MNIST 量级的任务一般取 $\lambda \in [10^2, 10^4]$,Split CIFAR 这种更难的任务则在 $\lambda \in [1, 10]$ 量级。 EWC 需要标签(出现在对数似然里)。Aljundi 等(2018)改用模型输出范数平方对参数的梯度的绝对值: 这个量完全无监督——可以在无标签数据甚至线上测试流上算。在已部署系统里这是真实的工程优势。 Zenke 等(2017)干脆把"重要性"沿 SGD 轨迹做路径积分 $-g_i \cdot \dot\theta_i$,一边训练一边累加。无需第二遍数据,开销折在优化器里就行。 Li 和 Hoiem(2017)换了思路:与其约束参数不漂,不如约束输出不漂。在开始新任务前先把旧模型 $f_{\text{old}}$ 拍快照存下来。新任务训练时,对每个新任务样本 $x$ 让旧模型给出软目标 $\sigma(z^{\text{old}}/T)$,然后在新模型的"旧任务输出头"上用 KL 蒸馏过去;新输出头按常规交叉熵学新任务标签: LwF 不需要旧数据、不需要 Fisher,只需要旧模型快照。温度 $T$(通常 2-4)软化两边分布,让蒸馏信号承载比 argmax 更多的形状信息。 另一种哲学:留一点过去的样本在身边。哪怕只存一小撮旧样本,把它们拌进每个 mini-batch,至今仍是 CL 中最强的单技。 维护大小为 $N$ 的记忆缓冲 $\mathcal{M}$。每步从新流采 $B_{\text{new}}$、从 $\mathcal{M}$ 采 $B_{\text{mem}}$,优化 之后再把部分新样本写回 $\mathcal{M}$。**蓄水池采样(reservoir sampling,Vitter 1985)**用固定大小缓冲在整条流上保持均匀采样;类别均衡采样则保证每个类都被覆盖到。经验上,只要 $|B_{\text{mem}}| = |B_{\text{new}}|$,Split-CIFAR 量级的基准上就能拿回大部分联合训练(Joint)的准确率。 Lopez-Paz 和 Ranzato(2017)把"梯度更新"本身写成约束优化问题:在所有"不会让任何旧任务损失上升"的方向中,挑一个离 $\mathbf{g}_{\text{new}}$ 最近的: 这是个二次规划,每个旧任务一个约束——任务一多就跑不动。A-GEM(Chaudhry 等,2019)把它简化成"只检查一个均值参考梯度"$\mathbf{g}_{\text{ref}}$(在 $\mathcal{M}$ 上随机一批的平均梯度),仅当夹角为钝时投影: 否则 $\tilde{\mathbf{g}} = \mathbf{g}_{\text{new}}$。代价只是一次额外前后向加一次点积——比 GEM 便宜约一千倍,效果却几乎不掉。 Buzzega 等(2020)不仅存输入,还把样本入库时模型的 logits 一并存下。重放损失变成 logit 匹配的 MSE,可与原始标签的交叉熵叠加。DER++ 是目前单模型基线里在大多数 CL 基准上最强的之一。 与其在固定参数预算里塞下所有任务,不如让模型本身长大。 权衡是普适的:零遗忘必然意味着要么参数膨胀,要么容量收缩。真正能上线的方案,几乎都是"固定主干 + 每任务轻量适配器"(参考第 9 篇 PEFT)这种混合形态。 下面这份 PyTorch 实现做了三件事:(1)每个任务训完时计算经验 Fisher 对角;(2)连同 $\theta^{*}$ 一起存档;(3)后续任务训练时把 EWC 惩罚加进损失里。直接能跑 Permuted MNIST。 调用方式(Permuted MNIST): 两个真正影响效果的工程细节: 下图是两个经典 CL 基准上的代表性数字。结论很明确:有内存预算时,重放方法压倒所有正则化方法;没有内存时,EWC 和 LwF 仍能把朴素 SGD 拉开一大截;但都还差联合训练的上界一段距离。 三个要带走的判断: MNIST 量级的任务从 100 起步,CIFAR 量级从 1-10 起步。在验证集上同时看 Avg 和 Forgetting 两条曲线,选"在你能接受的遗忘度上限下让 Avg 最高"的 $\lambda$。 Fisher 在累加,每个参数迟早都背着很大的 $\sum_t F_{t,i}$。换 Online EWC,$\gamma \approx 0.95$,让旧 Fisher 指数衰减就行。 有干净标签的监督任务用 EWC;面对无标签流式数据用 MAS(不需要标签);如果连"任务结束后再扫一遍数据算 Fisher"这一步都嫌贵,就用 SI——它在训练过程中顺手就把重要性算好了。 Split-CIFAR 量级的曲线通常在每任务 200-500 样本附近开始饱和。真正有趣的区间是"你能存多少就存多少"——只要还能加,重放就还在赢。 不是。多任务所有数据同时可见,挑战只是任务之间的平衡;CL 任务逐个到达且不能回访旧数据,挑战是遗忘。CL 在"无限内存 + 无序约束"下退化成多任务——这恰好就是基准图里 Joint 上界的含义。 会——缓冲里就是真实训练样本。隐私敏感的部署用 生成式重放(先训一个生成器学旧数据分布,再从生成器采样)或 dark experience(只存 logits 不存输入)。 Class-IL 在推理时要在跨任务的全部类别中做选择。即使每个任务都没怎么遗忘,每任务的 softmax 头之间从未一起被训练过,logits 量级互不校准——新类输出常常压过旧类。iCaRL(Rebuffi 等,2017)就是为这个问题设计的:它在学到的特征上套一个最近类均值分类器,绕开 softmax 校准问题。 灾难性遗忘不是工程缺陷,而是"在共享参数向量上做 SGD"这个范式的结构性后果,可以从梯度干扰和高维损失曲面几何里推导出来。解法分四类: 落到工程上结论很直接:只要还能存任何数据,就跑一个小蓄水池缓冲做经验重放;再叠一个 LwF(不需要任何额外存储,旧模型快照已经够用);只有在内存绝对不允许时才退而求其次用 EWC/MAS。下一篇讲跨语言迁移——“任务"换成"语言”,但"小心共享、小心保护"这套思路完全适用。
![迁移矩阵 R[i,j] 与 FWT/BWT 区域](./10-%e6%8c%81%e7%bb%ad%e5%ad%a6%e4%b9%a0/fig6_transfer_matrix.png)
遗忘是怎么发生的
梯度干扰
损失曲面视角

Fisher 信息 = 参数重要性
正则化方法
Elastic Weight Consolidation(EWC)

Memory Aware Synapses(MAS)
Synaptic Intelligence(SI)
Learning without Forgetting(LwF)

重放方法

Experience Replay(ER)
GEM 与 A-GEM
DER 与 DER++
动态架构
实现:从零写一遍 EWC
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
class EWC:
"""Elastic Weight Consolidation.
每完成一个任务,调一次 `consolidate(dataloader)` 来快照 theta* 和
经验 Fisher 对角;后续任务训练时在损失上加 `lambda * ewc.penalty()`。
"""
def __init__(self, model: nn.Module, device: str = "cpu"):
self.model = model
self.device = device
self.fisher: list[dict[str, torch.Tensor]] = []
self.opt_params: list[dict[str, torch.Tensor]] = []
@torch.enable_grad()
def _empirical_fisher(self, dataloader, n_samples: int = 1024
) -> dict[str, torch.Tensor]:
"""Diagonal Fisher: E[(d log p(y|x; theta) / d theta)^2]."""
self.model.eval()
fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()
if p.requires_grad}
seen = 0
for x, _ in dataloader:
x = x.to(self.device)
self.model.zero_grad()
logits = self.model(x)
# 从模型预测分布里采样 y -- 这是"真 Fisher";若直接用真实标签,
# 得到的是"经验 Fisher"。两者实践中都能用。
probs = F.softmax(logits, dim=-1)
y = torch.multinomial(probs, 1).squeeze(-1)
loss = F.cross_entropy(logits, y)
loss.backward()
for n, p in self.model.named_parameters():
if p.grad is not None:
fisher[n] += p.grad.detach() ** 2 * x.size(0)
seen += x.size(0)
if seen >= n_samples:
break
for n in fisher:
fisher[n] /= max(seen, 1)
return fisher
def consolidate(self, dataloader, n_samples: int = 1024) -> None:
"""在每个任务训完之后调用。"""
self.fisher.append(self._empirical_fisher(dataloader, n_samples))
self.opt_params.append(
{n: p.detach().clone() for n, p in self.model.named_parameters()
if p.requires_grad}
)
def penalty(self) -> torch.Tensor:
"""Sum_t Sum_i F_{t,i} * (theta_i - theta*_{t,i})^2."""
if not self.fisher:
return torch.tensor(0.0, device=self.device)
loss = torch.tensor(0.0, device=self.device)
for F_t, theta_t in zip(self.fisher, self.opt_params):
for n, p in self.model.named_parameters():
if n in F_t:
loss = loss + (F_t[n] * (p - theta_t[n]) ** 2).sum()
return 0.5 * loss
def train_task(model, ewc, loader, optimiser, *, ewc_lambda: float,
epochs: int, device: str) -> None:
model.train()
for _ in range(epochs):
for x, y in loader:
x, y = x.to(device), y.to(device)
loss = F.cross_entropy(model(x), y) + ewc_lambda * ewc.penalty()
optimiser.zero_grad()
loss.backward()
optimiser.step()
@torch.no_grad()
def evaluate(model, loader, device: str) -> float:
model.eval()
correct = total = 0
for x, y in loader:
x, y = x.to(device), y.to(device)
correct += (model(x).argmax(-1) == y).sum().item()
total += y.size(0)
return 100.0 * correct / total
1
2
3
4
5
6
7
8
9
10
ewc = EWC(model, device=device)
for t, (train_loader, test_loader) in enumerate(tasks):
opt = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
train_task(model, ewc, train_loader, opt,
ewc_lambda=400.0 if t > 0 else 0.0,
epochs=5, device=device)
ewc.consolidate(train_loader) # 快照 theta* 与 F
accs = [evaluate(model, t_loader, device) # 检查所有已学任务
for _, t_loader in tasks[:t + 1]]
print(f"After task {t + 1}: {accs}")
实证对比

常见问题
EWC 的 $\lambda$ 怎么选?
为什么我用了 EWC,跑了几十个任务后整个模型像被冻住了?
EWC、MAS、SI 实际怎么选?
重放缓冲应该多大?
持续学习和多任务学习不就是一回事吗?
重放会不会泄漏数据?
Class-IL 为什么比 Task-IL 难那么多?
小结
流派 机制 代表 取舍 正则化 锚定重要参数 EWC、MAS、SI、LwF 不耗内存,但在 class-IL 上偏弱 重放 用旧样本继续训练 ER、A-GEM、DER++ 实战最强;需要存储 动态架构 每任务加容量 Progressive Net、PackNet、SupSup 零遗忘;模型膨胀 元学习 学会"如何继续学习" OML、MER 强,但元训练昂贵 参考文献
系列导航
部分 主题 1 基础与核心概念 2 预训练与微调 3 域适应 4 小样本学习 5 知识蒸馏 6 多任务学习 7 零样本学习 8 多模态迁移 9 参数高效微调 10 持续学习(本文) 11 跨语言迁移 12 工业应用与最佳实践