Series · Transfer Learning · Chapter 6

迁移学习(六):多任务学习

多任务学习完全指南:硬/软参数共享、梯度冲突与 PCGrad/GradNorm/CAGrad、辅助任务设计,以及完整的多任务框架实现。

一辆自动驾驶汽车透过同一个摄像头要同时干三件事:检测车辆和行人、分割车道和可行驶区域、估计每个像素的距离。你完全可以训练三个独立的网络。代价是参数量乘以三、推理时多跑两次前向、并且白白浪费一个最显然的事实——这三个任务都需要同样的底层特征(边缘、表面、遮挡线索)。

多任务学习(Multi-Task Learning,MTL)走的是另一条路:一个共享主干、每个输出一个任务头,全部联合训练。做对了,参数减少 60% 每个任务的精度都提升,因为各任务之间互相起到了正则化作用;做砸了,三个任务里有两个反而退化,你会花一个礼拜苦想到底哪里出了问题。

本文要讲的是怎么"做对"。难点不在架构——架构就是一张图。真正的难点是:(1) 交叉熵损失和 L2 深度损失之间几十倍上百倍的尺度差;(2) 两任务方向不一致时占到 30-50% 比例的梯度冲突;(3) 怎么判断哪些任务该不该放进同一个模型。我们会过一遍架构(硬共享 vs 软共享、Cross-Stitch、MTAN),过一遍能在真实损失曲面上活下来的优化器(不确定性加权、GradNorm、PCGrad、CAGrad),最后给一个把这些方法都串起来的可运行 PyTorch 框架。

你将学到什么

  • 多任务学习为什么有效——正则化、数据增强和效率三个角度
  • 硬参数共享 vs 软参数共享,以及 Cross-Stitch / MTAN 的中间地带
  • 动手做之前怎么先量化任务相关性
  • 什么是梯度冲突、出现得有多频繁、PCGrad 和 CAGrad 怎么消除它
  • 用不确定性加权(Kendall et al.)和 GradNorm 平衡损失尺度
  • 三种平衡方法的完整 PyTorch 实现

前置知识: 本系列前两篇,以及 PyTorch 训练模型的基础经验。


为什么要做多任务学习

共享结构:需要相同特征的任务

多任务学习最有说服力的场景,是几个任务都依赖同一套底层表示:

任务特征需要编码什么
目标检测空间布局、物体边界、纹理
语义分割空间布局、物体边界、纹理
深度估计空间布局、纹理、几何线索

三个任务,一套底层特征。三个独立编码器意味着每个都要从零开始学边缘、表面和形状先验。共享编码器则强迫这套特征只学一次,并且来自三个任务的监督信号都在同一个方向推它。

正则化视角

设有 $T$ 个任务,损失为 $\mathcal{L}_1, \ldots, \mathcal{L}_T$,共享参数 $\theta_{\text{sh}}$,任务专属参数 $\theta_t$,联合目标为:

$$ \mathcal{L}_{\text{MTL}} \;=\; \sum_{t=1}^{T} w_t \cdot \mathcal{L}_t(\theta_{\text{sh}}, \theta_t). $$

共享参数 $\theta_{\text{sh}}$ 必须落在"对每个任务都好"的区域的交集里,而这个交集比任何单任务的可行区都小得多——这就相当于给 $\theta_{\text{sh}}$ 加了一个隐式先验。Caruana 1997 年的开山之作就用实验展示了这种正则化效应:模型对任何单一任务的噪声都更难过拟合。

数据增强视角

当主任务数据不足时,相关的辅助任务可以通过共享参数为主任务提供额外的监督信号。

举个具体的例子: 低资源机器翻译(英文→斯瓦希里语,约 10 万句对)。引入英文→法语作为辅助任务(约 1000 万句对)。共享的英语编码器突然多看了 100 倍的英文句子。斯瓦希里语这一侧没有直接受益,但它依赖的编码器变好了——文献中报告的低资源任务 BLEU 提升通常在 5-20%。

计算效率

方案参数量前向次数
三个独立 ResNet-5075 M3
一个共享编码器 + 三个轻量头部31 M1

参数减少约 60%、单次前向就能拿到三个预测——这正是图 6 右边所展示的。对实时系统(自动驾驶、AR、端侧推理)来说,这往往是上多任务的首要理由。

风险:负迁移

多任务学习不是免费午餐。当任务真正需要不同的特征时,联合训练会比分开训练更差。经典反例:

  • 人脸识别需要面部内部细粒度的纹理细节。
  • 场景分类需要全局粗粒度的布局。

强行让两者共享同一个主干,结果是这个主干两边都干得平庸。解决办法不是放弃 MTL,而是:(a) 训练前先量化冲突(下一节),(b) 用软共享让两个任务可以分开走,或者 (c) 用梯度手术阻止一个任务主动去拖另一个的后腿。


参数共享策略

硬参数共享与软参数共享

硬参数共享

最常见、也仍然是最强的基线。所有任务共享同一套主干,每个任务带一个自己的头:

$$ \text{features} \;=\; G_{\text{shared}}(x), \qquad \hat{y}_t \;=\; G_t^{\text{head}}(\text{features}) \quad \forall\, t. $$

经验设计原则:

  • 网络前 70-80% 的层共享(学通用特征)。
  • 后 20-30% 留给任务(每个头 1-3 层即可)。
  • 头部要给得够宽,让任务专属模式有地方放。

硬共享拥有最强的正则化、最少的参数,而且基本不会配错。永远先从这里开始。

软参数共享

每个任务有自己的主干,但通过一个耦合项鼓励同位置的参数保持相似:

$$ \mathcal{L} \;=\; \sum_t \mathcal{L}_t(\theta_t) \;+\; \lambda \!\!\sum_{i \neq j} \! \lVert \theta_i - \theta_j \rVert^2. $$

图 1 右半边那些黄色虚线就是这些耦合项。模型可以一层一层地打破对称——当任务需要相似但不完全相同的特征时很有用。

Cross-Stitch Networks

Cross-Stitch 网络

Cross-Stitch(Misra et al., 2016)介于硬共享和软共享之间。每个任务有自己的一列网络,但每一层都通过一个小小的 cross-stitch 单元让两路激活互相混合:

$$ \tilde{x}_A^{\,l} \;=\; \alpha_{AA}\, x_A^{\,l} + \alpha_{AB}\, x_B^{\,l}, \qquad \tilde{x}_B^{\,l} \;=\; \alpha_{BA}\, x_A^{\,l} + \alpha_{BB}\, x_B^{\,l}. $$

每层的四个标量 $\alpha_{\bullet\bullet}$ 是学出来的。它们的取值具有可解释性:在某一层 $\alpha_{AB}$ 很大,意味着任务 $A$ 在这个深度上严重依赖任务 $B$ 的特征——既能作为架构手段,也能当诊断工具用。

Multi-Task Attention Network(MTAN)

MTAN(Liu et al., 2019)共享一个主干,但让每个任务通过 sigmoid 注意力 mask 在共享特征里挑出自己想要的子集:

$$ \text{mask}_t = \sigma(W_t \cdot F_{\text{shared}} + b_t), \qquad F_t = \text{mask}_t \odot F_{\text{shared}}. $$

mask 是按任务、按层独立的,所以每个任务可以在不同深度"调谐"到不同的通道。在视觉 MTL 里,MTAN 通常是软共享变体里最强的一个。

怎么选

  • 硬共享。 默认。任务紧密相关(同模态、同抽象层级)。
  • Cross-Stitch / MTAN。 硬共享出现负迁移,但任务整体上仍有共性。
  • 完全独立或纯软共享。 任务几乎没什么共性——这种情况要先反问一句:MTL 真的是合适的工具吗?

动手做之前:先量化任务亲和性

任务亲和矩阵和分组

不要凭直觉决定哪些任务塞进一个 MTL 模型。下面三个量化检验都很便宜,开工前值得做。

  1. 迁移实验亲和性(Taskonomy 风格)。先在任务 $A$ 上训,再迁到任务 $B$ 微调,与从零训练 $B$ 比较。提升越大,亲和性越高。
  2. 梯度余弦相似度。 联合训练一个小模型一个 epoch,记录每步的 $\cos(\nabla_\theta \mathcal{L}_A, \nabla_\theta \mathcal{L}_B)$。持续为负 = 任务在打架。
  3. 特征相似性(CKA)。 比较各任务学到的表示。CKA 高 = 同一个主干能同时服务两者。

图 7 左边是七个视觉任务的典型亲和矩阵。注意 Detect / Segment / Edges 之间紧密成团(全部 0.78 以上),而 Caption 与其他任务联系松散得多。右边的树状图把这些数字翻译成了一个具体的分组建议:与其训一个巨型共享编码器,不如分组:

组 1:Detect + Segment + Edges       -> 共享编码器 A
组 2:Depth + Normals                -> 共享编码器 B
组 3:Pose + Caption                 -> 共享编码器 C

Standley et al. (2020) 的实验表明,这种自动分组(强化学习或层次聚类得到的)在多个数据集上都稳定优于手工分组和单一全局编码器。


梯度冲突与任务平衡

梯度冲突与 PCGrad 投影

绝大多数 MTL 项目的性能就是在这里悄悄漏掉的。架构没毛病,但优化器在背后拿一个任务去伤另一个任务。

“冲突"到底是什么

两个任务在共享参数上的梯度冲突,是指:

$$ \nabla \mathcal{L}_1 \cdot \nabla \mathcal{L}_2 \;<\; 0. $$

图 3 左边把几何关系画得一目了然。任务 1 的梯度 $g_1$ 指向右上,任务 2 的 $g_2$ 指向左上。两者的平均 $\bar{g}$ 在 $g_1$ 方向上的分量几乎为零——也就是说朴素的"损失求和"梯度对任务 1 几乎毫无帮助。这个例子里,$g_1$ 与 $g_2$ 的余弦相似度是 $-0.43$。

实际中冲突有多频繁?图 3 右边给了一个真实多任务训练里 $\cos$ 值分布的代表性例子:大约 45% 的更新存在冲突,而且常常会在多个 epoch 里持续存在。

静态权重(基线)

均匀 ($w_t = 1$):简单,偶尔够用,但损失尺度不一致时会崩。分类交叉熵 $\sim 1$ 和深度 MSE $\sim 100$ 直接相加,等于优化器基本只在做单任务深度学习。

手工调权重: 一两个任务、有一个礼拜空闲时还能玩。多了以后行不通。

不确定性加权(Kendall et al., 2018)

不确定性加权

把每个任务的输出建模成一个带可学习噪声 $\sigma_t$ 的高斯分布。负对数似然就变成:

$$ \mathcal{L} \;=\; \sum_t \frac{1}{2\sigma_t^2}\, \mathcal{L}_t \;+\; \log \sigma_t. $$

图 5 里要看的有两点:

  • 左图。 对一个固定的任务损失 $\mathcal{L}$,组合目标在 $\sigma^* = \sqrt{\mathcal{L}}$ 处有唯一极小值。如果没有 $\log \sigma$ 那一项,优化器会一路把 $\sigma_t$ 推到无穷大让加权损失趋零——正是这个正则项让整个把戏不至于退化。
  • 右图。 三个任务的原始损失横跨两个数量级(1500.3),经过学到的 $\sigma_t$ 一加权之后,贡献变得相当 (1.42.80.15)。

相对均匀权重的典型增益:2-5%。代价:每个任务多 1 个标量参数。便宜,几乎一定值得开。

GradNorm:按训练速度平衡梯度幅度

GradNorm 动态

不确定性加权平衡的是损失幅度。GradNorm(Chen et al., 2018)平衡的是梯度幅度,并且条件作用于每个任务的训练进度。

定义任务的相对逆训练速度:

$$ \tilde{r}_t \;=\; \frac{\mathcal{L}_t(t)\,/\,\mathcal{L}_t(0)}{\overline{\mathcal{L}(t)\,/\,\mathcal{L}(0)}}. $$

$\tilde{r}_t > 1$ 意味着任务 $t$ 的损失下降比平均要慢——它在掉队。GradNorm 调节 $w_t$ 让

$$ \lVert w_t \nabla \mathcal{L}_t \rVert \;\approx\; \overline{G}\cdot \tilde{r}_t^{\,\alpha}, $$

其中 $\overline{G}$ 是共享参数上的平均梯度范数,$\alpha \!\approx\! 1.5$ 控制再平衡的力度。

图 4 模拟了一个 60 个 epoch 的过程:

  • 左。 三个任务,损失尺度和收敛速度都差很多。
  • 中。 它们的相对训练速度 $\tilde{r}_t$。慢的回归任务慢慢漂到 1 以上;快的辅助任务掉到 1 以下。
  • 右。 GradNorm 把掉队任务的权重抬上去,把领先任务的权重压下来——全程不需要人工干预。

文献中相对均匀权重的增益:3-8%。

PCGrad:把冲突分量投影掉

PCGrad(Yu et al., 2020)解决的是方向问题,不是幅度问题。一旦两任务梯度冲突,把其中一个投影到另一个的法平面上,剔除主动伤害对方的那部分:

$$ g_i' \;=\; g_i \;-\; \frac{g_i \cdot g_j}{\lVert g_j \rVert^2}\, g_j \qquad \text{当 } g_i \cdot g_j < 0. $$

投影之后 $g_i' \cdot g_j = 0$——冲突彻底消除。图 3 左边那条绿色 $g_1^{PC}$ 就是结果:保留了 $g_1$ 中不会伤害 $g_2$ 的那部分。

伪代码:

for each task i:
    g_i = loss_i 的反向传播
    for each other task j:
        if g_i . g_j < 0:                     # 冲突
            g_i = g_i - proj(g_i, g_j)        # 去掉冲突分量
final_gradient = 所有修改后梯度的均值

理论保证: 最终梯度在一阶意义上不会增加任何任务的损失。

NYUv2 上的实测(语义分割 + 深度 + 表面法向):

  • 均匀权重:mIoU 40.2%,深度误差 0.61
  • PCGrad: mIoU 42.7%,深度误差 0.58

CAGrad:全局最优的冲突解决

CAGrad(Liu et al., 2021)把 PCGrad 推广为一个二次规划:找模长最小、且与所有任务梯度方向都对齐的更新向量:

$$ g^{*} \;=\; \arg\min_g \lVert g \rVert^2 \quad \text{s.t.}\quad g \cdot g_t \geq 0 \;\; \forall t. $$

这就是帕累托最优的下降方向——保证不伤害任何任务,并且是全局而非两两处理。代价是每步 $\mathcal{O}(T^2)$。任务数 $T \leq 5$ 时直接上 CAGrad。

比较与组合

方法控制什么代价
均匀什么也不控免费
不确定性加权损失幅度$T$ 个参数
GradNorm梯度幅度$T$ 个参数 + 一次反向
PCGrad梯度方向$T$ 次反向
CAGrad梯度方向(全局)$T$ 次反向 + QP

GradNorm(控幅度)和 PCGrad/CAGrad(控方向)是正交的。任务数 $\geq 3$ 且尺度差异大时,GradNorm + PCGrad 的组合是稳妥的默认配置。


这些方法到底能拿到多少收益

MTL vs 单任务的精度与成本

图 6 把这些方法在 NYUv2 风格三任务基准上摆到一起:

  • 左。 各任务相对单任务基线的提升。均匀 MTL 已经在每个任务上都赢了(正则化效应是真实的)。不确定性加权再加 1-2 个点。GradNorm 和 PCGrad 各再加 1-2 个点。CAGrad 排在最上面。
  • 右。 成本/效率收益与平衡方法无关:共享编码器 + 3 个头把参数从 75M 砍到 31M,前向次数从 3 次砍到 1 次。

实践要点:一个配置得当的 MTL(带 PCGrad 或 CAGrad)经常能给出一个更小、更快、更准的模型,同时取代它要替换的那一组单任务模型。这是工程和精度难得指向同一边的情形。


辅助任务设计

如果你真正在意的只有一个主任务,MTL 只是用作正则化手段,那么设计问题就变成:该添加哪些辅助任务?

自监督辅助任务(免费监督)

  • 旋转预测。 把输入旋转 0 / 90 / 180 / 270 度,预测角度。强迫模型理解方向和物体结构。
  • 拼图。 打乱图像 patch,预测正确排列。强迫模型理解空间布局。
  • 对比学习(SimCLR / MoCo)。 把同一样本的两个增强版本拉近,不同样本推远。强迫模型学到对增广不变的表示。

领域特定的辅助任务

主任务有效的辅助任务
目标检测边缘检测、深度估计
命名实体识别词性标注、依存句法分析
点击率预测转化率、停留时长、关注概率
语音识别说话人识别、语音活动检测、噪声分类

加多少个辅助任务?

  • 先从 1-2 个最相关的开始。
  • 只有在主任务验证集还在涨的时候才继续加。
  • 通常 2-4 个辅助任务就够了;超过 10 个就该改用任务聚类(图 7),而不是无脑往上叠。

完整代码实现

一个自包含的 PyTorch 框架,支持硬参数共享 + 三种平衡方法(uniform / PCGrad / GradNorm)。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
import numpy as np
from typing import List, Dict, Optional


# ===== 网络架构 =====

class SharedEncoder(nn.Module):
    """共享主干:ResNet-18 的前 3 个 block。"""
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=False)
        self.stem = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        return self.layer3(x)


class TaskHead(nn.Module):
    """任务专属的分类或回归头。"""
    def __init__(self, in_channels, num_outputs, task_type='classification'):
        super().__init__()
        self.task_type = task_type
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_outputs))

    def forward(self, x):
        return self.fc(self.pool(x).flatten(1))


class MultiTaskNet(nn.Module):
    """硬参数共享:共享编码器 + 任务专属头部。"""
    def __init__(self, task_configs):
        super().__init__()
        self.encoder = SharedEncoder()
        self.heads = nn.ModuleDict({
            cfg['name']: TaskHead(256, cfg['num_classes'], cfg['type'])
            for cfg in task_configs
        })

    def forward(self, x):
        features = self.encoder(x)
        return {name: head(features) for name, head in self.heads.items()}


# ===== PCGrad =====

class PCGrad:
    """投影冲突梯度(Yu et al., NeurIPS 2020)。"""
    def __init__(self, optimizer, task_names):
        self.optimizer = optimizer
        self.task_names = task_names

    @staticmethod
    def _project(g_i, g_j):
        dot = torch.dot(g_i, g_j)
        if dot < 0:
            g_i = g_i - (dot / (g_j.norm() ** 2 + 1e-8)) * g_j
        return g_i

    def step(self, losses):
        # 1. 各任务在共享参数上的扁平化梯度。
        grads = {}
        for name in self.task_names:
            self.optimizer.zero_grad()
            losses[name].backward(retain_graph=True)
            grads[name] = torch.cat([
                p.grad.flatten() for p in self.optimizer.param_groups[0]['params']
                if p.grad is not None
            ]).clone()

        # 2. 两两投影掉冲突分量。
        modified = {}
        for i, ni in enumerate(self.task_names):
            g = grads[ni].clone()
            for j, nj in enumerate(self.task_names):
                if i != j:
                    g = self._project(g, grads[nj])
            modified[ni] = g

        # 3. 用修改后梯度的均值进行优化器步骤。
        avg_grad = sum(modified.values()) / len(modified)
        self.optimizer.zero_grad()
        idx = 0
        for p in self.optimizer.param_groups[0]['params']:
            if p.grad is not None:
                n = p.numel()
                p.grad = avg_grad[idx:idx + n].view_as(p)
                idx += n
        self.optimizer.step()


# ===== GradNorm =====

class GradNorm:
    """梯度归一化以自适应平衡损失(Chen et al., ICML 2018)。"""
    def __init__(self, model, task_names, alpha=1.5, lr=0.025):
        self.model = model
        self.task_names = task_names
        self.alpha = alpha
        self.weights = nn.Parameter(torch.ones(len(task_names)))
        self.weight_optim = optim.Adam([self.weights], lr=lr)
        self.initial_losses = None

    def step(self, losses):
        if self.initial_losses is None:
            self.initial_losses = {n: l.item() for n, l in losses.items()}

        weighted = [self.weights[i] * losses[n]
                    for i, n in enumerate(self.task_names)]
        total = sum(weighted)

        # 只在共享编码器上计算各任务的梯度范数。
        shared_params = list(self.model.encoder.parameters())
        grad_norms = []
        for wl in weighted:
            grads = torch.autograd.grad(
                wl, shared_params, retain_graph=True, create_graph=True)
            grad_norms.append(
                torch.norm(torch.cat([g.flatten() for g in grads])))

        avg_norm = sum(grad_norms) / len(grad_norms)
        avg_ratio = sum(
            losses[n].item() / (self.initial_losses[n] + 1e-8)
            for n in self.task_names) / len(self.task_names)

        # GradNorm 损失:让 ||w_t * grad_t|| 趋向 avg_norm * r_t^alpha。
        gn_loss = sum(
            torch.abs(grad_norms[i] - avg_norm * (
                (losses[n].item() / (self.initial_losses[n] + 1e-8))
                / (avg_ratio + 1e-8)) ** self.alpha)
            for i, n in enumerate(self.task_names))

        self.weight_optim.zero_grad()
        gn_loss.backward()
        self.weight_optim.step()
        # 重新归一化让权重之和等于任务数。
        with torch.no_grad():
            self.weights.data *= len(self.task_names) / self.weights.sum()

        return total, {n: self.weights[i].item()
                       for i, n in enumerate(self.task_names)}


# ===== 训练器 =====

class MTLTrainer:
    """多任务训练器,支持 uniform / PCGrad / GradNorm。"""
    def __init__(self, model, task_configs, device='cpu', method='uniform'):
        self.model = model.to(device)
        self.device = device
        self.task_configs = {c['name']: c for c in task_configs}
        self.task_names = [c['name'] for c in task_configs]
        self.method = method
        self.optimizer = optim.Adam(model.parameters(), lr=1e-3)

        if method == 'pcgrad':
            self.pcgrad = PCGrad(self.optimizer, self.task_names)
        elif method == 'gradnorm':
            self.gradnorm = GradNorm(model, self.task_names)

    def _losses(self, outputs, targets):
        losses = {}
        for n in self.task_names:
            if self.task_configs[n]['type'] == 'classification':
                losses[n] = F.cross_entropy(outputs[n], targets[n])
            else:
                losses[n] = F.mse_loss(outputs[n], targets[n])
        return losses

    def train_epoch(self, loader, epoch):
        self.model.train()
        stats = {n: 0.0 for n in self.task_names + ['total']}
        for batch in loader:
            inputs = batch['input'].to(self.device)
            targets = {n: batch[n].to(self.device) for n in self.task_names}
            outputs = self.model(inputs)
            losses = self._losses(outputs, targets)

            if self.method == 'uniform':
                total = sum(losses.values())
                self.optimizer.zero_grad()
                total.backward()
                self.optimizer.step()
            elif self.method == 'pcgrad':
                self.pcgrad.step(losses)
                total = sum(l.item() for l in losses.values())
            elif self.method == 'gradnorm':
                total, _ = self.gradnorm.step(losses)
                self.optimizer.zero_grad()
                total.backward()
                self.optimizer.step()

            for n in self.task_names:
                stats[n] += (losses[n].item()
                             if isinstance(losses[n], torch.Tensor)
                             else losses[n])
            stats['total'] += (total.item()
                               if isinstance(total, torch.Tensor) else total)
        nb = len(loader)
        return {k: v / nb for k, v in stats.items()}

    @torch.no_grad()
    def evaluate(self, loader):
        self.model.eval()
        correct = {n: 0 for n in self.task_names}
        total = 0
        for batch in loader:
            inputs = batch['input'].to(self.device)
            targets = {n: batch[n].to(self.device) for n in self.task_names}
            outputs = self.model(inputs)
            total += inputs.size(0)
            for n in self.task_names:
                if self.task_configs[n]['type'] == 'classification':
                    correct[n] += (outputs[n].argmax(1) == targets[n]).sum().item()
        return {n: 100.0 * correct[n] / total for n in self.task_names}


# ===== 演示 =====

class DummyMTLDataset(Dataset):
    def __init__(self, n=1000):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, i):
        return {
            'input': torch.randn(3, 32, 32),
            'task1': torch.randint(0, 10, ()).item(),
            'task2': torch.randint(0, 5, ()).item(),
        }


def main():
    configs = [
        {'name': 'task1', 'num_classes': 10, 'type': 'classification'},
        {'name': 'task2', 'num_classes': 5,  'type': 'classification'},
    ]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loader = DataLoader(DummyMTLDataset(1000), batch_size=32, shuffle=True)
    test_loader = DataLoader(DummyMTLDataset(200), batch_size=32)

    for method in ['uniform', 'pcgrad', 'gradnorm']:
        print(f"\n{'=' * 50}\nMethod: {method}\n{'=' * 50}")
        model = MultiTaskNet(configs)
        trainer = MTLTrainer(model, configs, device, method=method)
        for epoch in range(10):
            stats = trainer.train_epoch(loader, epoch)
            metrics = trainer.evaluate(test_loader)
            print(f"Epoch {epoch+1}: "
                  + " ".join(f"{k}={v:.4f}" for k, v in stats.items())
                  + " | "
                  + " ".join(f"{k}={v:.1f}%" for k, v in metrics.items()))


if __name__ == '__main__':
    main()

代码架构

组件职责
SharedEncoderResNet-18 前 3 个 block 作为共享特征提取器。
TaskHead每个任务的分类或回归头。
MultiTaskNet硬参数共享:编码器 + ModuleDict 头部。
PCGrad在求平均之前,两两投影掉冲突梯度。
GradNorm学每个任务的权重,让梯度幅度跟随 $\tilde{r}_t^\alpha$。
MTLTrainer一个统一接口包住 uniform / PCGrad / GradNorm 三种方法。

FAQ

到底什么时候才该上多任务学习?

诚实的三个理由:(1) 任务共享底层特征,你想要正则化效应;(2) 主任务数据不足,希望辅助任务通过共享编码器把监督信号传过来;(3) 推理时需要便宜地同时输出多个预测。三条都不沾就别用 MTL。

怎么判断任务是不是在打架?

两个便宜的检查。(a) 在共享参数上记录 $\cos(\nabla \mathcal{L}_A, \nabla \mathcal{L}_B)$——持续为负就是冲突(图 3 右)。(b) 把多任务模型里每个任务的精度跟单任务基线比较——任何一个掉了,就发生了负迁移。

硬共享还是软共享?

先用硬共享——更简单、正则化更强、参数更少。只有当你观察到的负迁移在 PCGrad 和 GradNorm 之后还存在时,才换 Cross-Stitch / MTAN。

我的损失尺度差 100 倍,怎么办?

别手工调权重,永远调不出来。先用不确定性加权(图 5),这是成本最低的解决方案;如果还需要追踪任务间训练速度差异,再切到 GradNorm。

PCGrad 能和 GradNorm 一起用吗?

能——两者正交。GradNorm 控幅度,PCGrad 控方向。标准组合是:(1) 用 GradNorm 算出 $w_t$,(2) 形成加权的任务梯度 $w_t g_t$,(3) 在这之上做 PCGrad。三个以上任务、尺度差异大时,这是合理的默认。

辅助任务该加多少个?

开始 1-2 个,没量化过亲和矩阵之前不要超过 4 个。任务超过 10 个,就用图 7 的方式做聚类,每一组一个共享编码器。


总结

多任务学习能让你训出一个又小又快、还经常比单任务更准的模型。架构本身是容易的部分——硬参数共享 + 任务头几乎打不过。难的是在训练循环里保持诚实:

  • 损失尺度平衡——任务输出类型或量级不一样时,不确定性加权或 GradNorm 几乎是必选项。
  • 梯度冲突会影响 30-50% 的更新——PCGrad(便宜)或 CAGrad(更优)能阻止它们悄悄拖累某些任务。
  • 任务选择比任何优化器技巧都更重要——动手前先用梯度余弦或迁移实验量化亲和性。
  • 先硬共享,只有在测量结果指向需要时才上软共享 / Cross-Stitch。

下一篇是零样本学习——让模型分类训练时从未见过的类别,借助属性或语言描述跨越鸿沟。


参考文献

  1. Caruana, R. (1997). Multitask Learning. Machine Learning.
  2. Misra et al. (2016). Cross-Stitch Networks for Multi-task Learning. CVPR. arXiv:1604.03539
  3. Kendall et al. (2018). Multi-Task Learning Using Uncertainty to Weigh Losses. CVPR. arXiv:1705.07115
  4. Chen et al. (2018). GradNorm: Gradient Normalization for Adaptive Loss Balancing. ICML. arXiv:1711.02257
  5. Liu et al. (2019). End-to-End Multi-Task Learning with Attention (MTAN). CVPR. arXiv:1803.10704
  6. Standley et al. (2020). Which Tasks Should Be Learned Together in Multi-task Learning? ICML. arXiv:1905.07553
  7. Yu et al. (2020). Gradient Surgery for Multi-Task Learning (PCGrad). NeurIPS. arXiv:2001.06782
  8. Liu et al. (2021). Conflict-Averse Gradient Descent (CAGrad). NeurIPS. arXiv:2110.14048

系列导航

部分主题
1基础与核心概念
2预训练与微调
3域适应
4小样本学习
5知识蒸馏
6多任务学习(本文)
7零样本学习
8多模态迁移
9参数高效微调
10持续学习
11跨语言迁移
12工业应用与最佳实践

Liked this piece?

Follow on GitHub for the next one — usually one a week.

GitHub