系列 · 迁移学习 · 第 3 篇

迁移学习(三):域适应

域适应实战指南:协变量偏移、标签偏移、DANN 梯度反转、MMD 对齐、CORAL、自训练、AdaBN,以及一份可运行的 DANN 完整实现。

我的自动驾驶模型在加州晴天的高速公路上表现堪称完美,但一到西雅图下雨,Top-1 准确率就从 95% 直接跌到 70%。问题不在于模型变差了,而是数据分布发生了偏移——训练集里压根没有傍晚湿滑沥青路面的样本。

这正是域适应(Domain Adaptation)要解决的日常难题:你拥有大量源域标注数据,目标域却只有无标签数据,但模型必须在目标域上表现良好。本文将带你从第一性原理出发,一步步理解域适应,并最终实现一个可运行的 DANN 模型。


你将学到什么#

  • 分布偏移的三种类型——协变量偏移、标签偏移、概念偏移——以及各自的应对策略
  • Ben-David 上界:为什么域适应可行,以及它具体能帮你减少哪个关键量
  • DANN:如何通过梯度反转层,在一次反向传播中完成对抗对齐
  • MMD 和 CORAL:显式、非对抗式的分布匹配损失函数
  • 自训练、AdaBN、CycleGAN、ADDA——现代域适应工具箱中的其他实用方法
  • 完整的 PyTorch 实现,涵盖 DANN 的所有细节
  • 方法选择决策树,附带 Office-31 和 DomainNet 数据集上的基准性能

前置要求:已阅读本系列第 1–2 篇;对 GAN 式对抗训练有基本了解。


分布偏移的三种类型#

一个(domain)由特征空间 $\mathcal{X}$ 及其边缘分布 $P(X)$ 定义;一个任务(task)则由标签空间 $\mathcal{Y}$ 及条件分布 $P(Y \mid X)$ 定义。域适应研究的是当源域与目标域在这两者之一上存在分歧时该如何处理。

设置源域 $\mathcal{D}_S$目标域 $\mathcal{D}_T$目标
数据大量带标注 $(x_i, y_i)$多数无标注 $x_j$学习一个在 $\mathcal{D}_T$ 上有效的函数 $f: \mathcal{X} \to \mathcal{Y}$

源域与目标域的特征对齐

这张图一目了然地展示了域适应的核心:适应前,源域训练出的决策边界穿过目标域数据稀疏的空白区域;适应后,两个域在特征流形上对齐,同一边界即可同时适用于两者。

协变量偏移——输入分布变了#

$$P_S(X) \neq P_T(X), \qquad P_S(Y \mid X) = P_T(Y \mid X)$$

标注规则没变,只是观测到的内容不同了。典型例子包括:

  • 用 2020 年邮件训练的垃圾邮件过滤器部署到 2026 年:话题变了,但“什么是垃圾邮件”的定义没变。
  • 在西门子 CT 扫描仪上训练的模型用于 GE 扫描仪的数据:成像特性不同,但放射科医生的判读标准一致。
$$\mathbb{E}_{P_T}[\ell(f(X), Y)] = \mathbb{E}_{P_S}\!\left[\frac{P_T(X)}{P_S(X)}\,\ell(f(X), Y)\right].$$

高维空间中直接估计概率密度几乎不可能,因此实践中通常直接估计密度比——比如用 KLIEP、uLSIF,或训练一个二分类器区分源域和目标域样本;此时,该分类器的贝叶斯最优输出本身就隐含了密度比。

标签偏移——类别先验变了#

$$P_S(Y) \neq P_T(Y), \qquad P_S(X \mid Y) = P_T(X \mid Y)$$

各类别的条件分布不变,只是整体比例变了。例如:

  • 在 ICU 场景训练的疾病预测模型部署到门诊,疾病流行率大幅下降。
  • 在年轻用户群体中训练的推荐系统推广至全年龄段用户。

标准解法:利用 EM 算法(BBSE / RLLS 效果不错)在无标签目标数据上估计目标先验 $P_T(Y)$ ,然后按 $P_T(y) / P_S(y)$ 对源模型输出的概率进行重缩放并归一化。

概念偏移——规则本身变了#

$$P_S(Y \mid X) \neq P_T(Y \mid X)$$

这是最棘手的情况。比如单词 “sick” 在音乐评论中是褒义(“这段 beat 太 sick 了”),在产品评论中却是贬义——词相同,但标签含义相反。若目标域完全没有标签,任何方法都无法可靠解开这种混淆。此时至少需要少量目标域标注样本,即进入半监督域适应(semi-supervised DA)设定。


理论:Ben-David 上界#

$$ \epsilon_T(h) \;\leq\; \epsilon_S(h) \;+\; \tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) \;+\; \lambda^{*}. $$
含义你能做什么
$\epsilon_S(h)$源域误差在源域上训练得更好
$d_{\mathcal{H}\Delta\mathcal{H}}$域间对称差散度这正是域适应要降低的部分
$\lambda^{*}$最优联合预测器的误差不可约减——若其值很大,再好的方法也无济于事

两点关键启示:

  1. 域适应的效果受理论上限约束。如果源任务和目标任务本质不同($\lambda^*$ 很大),那你需要的是新标签,而不是更复杂的损失函数。
  2. 域散度有可操作的代理指标。训练一个二分类器来区分源域和目标域的特征。若其准确率接近 50%,说明特征已实现域不变性——这正是 DANN 自动完成的机制。

DANN:一次反向传播完成的对抗对齐#

DANN 训练动画:目标域特征向源域类别簇迁移。

DANN 前后:域分布合并,类别仍可分。

域对抗神经网络(Domain-Adversarial Neural Network, DANN;Ganin 等,2016)是最具影响力的对抗方法,也是“最小化域散度代理”思想最简洁的实现。

DANN 架构与梯度反转层

三个子网,一个共享主干#

子网功能训练数据
特征提取器 $G_f$将输入 $x$ 映射为特征 $f = G_f(x)$源域 + 目标域
标签预测器 $G_y$根据特征 $f$ 预测标签 $\hat{y}$源域标签
域判别器 $G_d$判断特征 $f$ 来自源域还是目标域源域 + 目标域
$$ \min_{G_f,\, G_y}\; \max_{G_d}\quad \mathcal{L}_y(G_y \circ G_f) \;-\; \lambda\, \mathcal{L}_d(G_d \circ G_f). $$

$G_d$ 试图区分两个域,而 $G_f$ 则要欺骗 $G_d$ ,同时确保 $G_y$ 能正确分类源域数据。

梯度反转层(GRL)#

$$ \text{前向传播:}\; \text{GRL}(x) = x, \qquad \text{反向传播:}\; \frac{\partial\,\text{GRL}}{\partial x} = -\lambda\, I. $$

GRL 插在特征到域判别器的路径上。反向传播时,判别器的梯度在传回 $G_f$ 前会翻转符号。因此,同一次 loss.backward() 调用可以:

  • 用正常梯度更新 $G_y$ ,提升分类性能;
  • 用正常梯度更新 $G_d$ ,增强判别能力;
  • 翻转后的梯度更新 $G_f$ ,使其生成能迷惑 $G_d$ 的特征,同时仍支持 $G_y$ 的分类任务。

无需交替训练,无需多个优化器,也无需手动冻结参数。

对抗权重调度#

$$\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1, \qquad \gamma \approx 10,$$

其中 $p \in [0, 1]$ 表示训练进度。初期($\lambda \approx 0$ ),网络专注学习源域特征;后期($\lambda \to 1$ ),域对齐逐渐生效。忽略这一调度是“DANN 训练成功但效果不如仅用源域模型”的最常见原因


MMD:在 RKHS 中对齐均值#

对抗对齐虽强,但训练不稳定。非对抗替代方案是定义显式分布距离并直接最小化。最大均值差异(Maximum Mean Discrepancy, MMD;Gretton 等,2012)是标准选择。

最大均值差异:核均值嵌入

核心思想#

$$\mu_P = \mathbb{E}_{X \sim P}[\phi(X)] \;\in\; \mathcal{H}.$$ $$\text{MMD}^2(P_S, P_T) = \|\mu_{P_S} - \mu_{P_T}\|_{\mathcal{H}}^2.$$

图示清晰表明:即使原始直方图部分重叠,核均值嵌入也能凸显差距,阴影面积即为 $\text{MMD}^2$

实际计算的估计量#

$$ \widehat{\text{MMD}}^2 = \frac{1}{n_s^2}\sum_{i,j} k(x_i^s, x_j^s) + \frac{1}{n_t^2}\sum_{i,j} k(x_i^t, x_j^t) - \frac{2}{n_s n_t}\sum_{i,j} k(x_i^s, x_j^t). $$ $$\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda \cdot \widehat{\text{MMD}}^2\!\big(G_f(X_S),\, G_f(X_T)\big).$$

这就是 DAN / DDC(Long 等,2015;Tzeng 等,2014)。

实用技巧#

  • 使用多核 MMD:混合多个不同带宽的高斯核 $k = \sum_u \beta_u k_{\sigma_u}$ ,对带宽选择更鲁棒。
  • 中位数启发法选带宽:取 batch 内成对距离的中位数——简单、稳健,通常足够好。
  • 对深层特征应用 MMD:底层特征含域特定纹理,需对齐的是高层抽象表示。

MMD 与 DANN 快速对比#

MMDDANN
距离度量核 RKHS 范数Jensen–Shannon(通过判别器)
优化方式直接最小化对抗 minimax(GRL)
稳定性非常稳定偶尔震荡
表达能力受核选择限制更灵活
适用场景中小差距、数据较少差距大、数据充足

合理工作流:先试 MMD;若效果停滞,再切换到 DANN。


CORAL:对齐二阶统计量#

对齐均值很好,但同时对齐均值和协方差往往更优。CORAL(Sun & Saenko,2016)正是如此。

CORAL 协方差对齐

$$\mathcal{L}_{\text{CORAL}} = \frac{1}{4 d^2} \|C_S - C_T\|_F^2.$$

直观理解——白化 + 重新着色。将源特征乘以 $C_S^{-1/2} C_T^{1/2}$ ,先去除源域协方差“指纹”,再赋予目标域的协方差结构。Deep CORAL 直接将此损失加入网络,靠梯度隐式完成相同操作。

CORAL 极其高效(每 batch 仅需一个矩阵和一个 Frobenius 范数)、完全确定性,且在轻度偏移下表现惊人。在尝试 MMD 或 DANN 前,它是个极佳的基线。


AdaBN:永远该先尝试的免费午餐#

最简单的域适应技巧:在目标数据上重新计算 BatchNorm 统计量

标准 BN 在测试时使用源域训练累积的 running mean/variance。若目标分布不同,这些统计量就不准了,而它们恰好位于每个卷积层与非线性激活之间。AdaBN(Li 等,2017)的做法是:

  1. 正常在源域上训练模型。
  2. 冻结权重后,用无标签目标数据跑前向传播,重新计算每层 BN 的 $\mu_T, \sigma_T^2$
  3. 部署时,用目标统计量替换源统计量。

成本:几分钟;代码改动:仅需替换几个 BatchNorm 的 running stats;效果:在协变量偏移下,通常能挽回 2–10 个点的准确率。任何复杂方法之前,请务必先试这个


基于 GAN 和像素级的适配方法#

有时差距过于视觉化——如合成图 vs 真实图、白天 vs 夜晚——此时仅对齐特征为时已晚,需直接转换输入本身。

  • CycleGAN 学习两个生成器 $G: \mathcal{X}_S \to \mathcal{X}_T$$F: \mathcal{X}_T \to \mathcal{X}_S$ ,并施加循环一致性约束 $F(G(x)) \approx x$ 。先将源图像转为目标风格,再用原始标签训练分类器。注意:循环一致性不能保证语义保留,建议结合感知损失或恒等损失以保安全。
  • ADDA 解耦源域和目标域编码器。阶段一:正常训练源编码器+分类器。阶段二:用源编码器初始化目标编码器,通过对抗域判别器进行适应,同时冻结分类器。阶段三:测试时,目标输入经目标编码器源分类器处理。这种不对称设计赋予 ADDA 比 DANN 更强的容量,代价是多一个训练阶段。

自训练:为目标域生成标签#

对抗和统计对齐将目标域视为无差别的整体。自训练(伪标注)则更进一步:用当前模型为目标样本生成标签,并用这些标签继续训练。

自训练 / 伪标注

流程如下:

  1. 在源域上训练 $f$
  2. 对所有目标样本预测,仅保留置信度满足 $\max_y f(x)_y > \tau$ 的样本($\tau$ 为高置信度阈值)。
  3. 将保留的 (输入, 预测) 对视为新标注数据,重新训练。
  4. 迭代。

自训练强大却被低估,但有一个臭名昭著的失败模式:确认偏差——错误但高置信的预测被反复喂回训练,导致错误放大。标准缓解措施包括:

  • 设高阈值 $\tau$ (通常 ≥0.9);
  • 按类别平衡选择(限制每类保留数量);
  • 增强下的一致性正则化(FixMatch 风格);
  • 每轮从源模型重启,而非从前一轮自训练模型继续。

决策树:方法怎么选,何时用?#

实践中,强效 pipeline 往往组合多种方法:先用 AdaBN 拿下易得收益,再用 MMD 或 DANN 对齐特征,最后以自训练收尾。


基准测试:这到底有多大用?#

Office-31 和 DomainNet 基准

数据基于 ResNet-50 主干的文献平均值。两点值得注意:

  • 从“什么都不做”到“做点什么”,提升最显著。即使只用 AdaBN,也能弥补可观差距。动手做远比纠结“完美方法”重要。
  • DomainNet 比 Office-31 难得多。DomainNet 上 40% 的准确率已属强劲——该数据集含 345 类,横跨 6 种视觉风格迥异的域。解读 DA 准确率时,务必与 source-only 基线对比,而非看绝对值。

域适应真正发挥作用的地方#

  • 医学影像:西门子 vs GE 扫描仪、1.5T vs 3T MRI、医院 A vs 医院 B。
  • 自动驾驶:晴天 → 雨天、城市 A → 城市 B、仿真 → 真实。
  • 推荐系统:国家间、年份间、Web → 移动端。
  • NLP:影评 → 商品评、新闻 → 社交、正式 → 口语。
  • Sim-to-real:机器人与自动驾驶中,从合成数据迁移到真实传感器数据。

共同模式:源域标签充足,目标域标签昂贵或不可得,但模型必须上线


可视化效果——t-SNE 前后对比#

训练 DA 模型后的标准检查:用 t-SNE 投影源域和目标域特征。适应前,样本按聚类;适应后,按类别聚类。

t-SNE 域适应前后

若“之后”图中仍有两个分离的域块,说明对齐失败;若融合为一个块且呈现类别结构,则对齐成功。这张图比任何单一指标都更具诊断价值

完整实现:DANN#

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Function
import numpy as np
from sklearn.metrics import accuracy_score

class GradientReversalFunction(Function):
    """前向传播不变,反向传播时翻转梯度。"""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambda_, None

class GradientReversalLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.lambda_ = 1.0

    def set_lambda(self, val):
        self.lambda_ = val

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)

class FeatureExtractor(nn.Module):
    def __init__(self, input_dim=28 * 28, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.5),
        )

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

class LabelPredictor(nn.Module):
    def __init__(self, feature_dim=256, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

class DomainDiscriminator(nn.Module):
    def __init__(self, feature_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_dim, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

class DANN(nn.Module):
    """域对抗神经网络。"""

    def __init__(self, input_dim=28 * 28, hidden_dim=256, num_classes=10):
        super().__init__()
        self.feature_extractor = FeatureExtractor(input_dim, hidden_dim)
        self.label_predictor = LabelPredictor(hidden_dim, num_classes)
        self.domain_discriminator = DomainDiscriminator(hidden_dim)
        self.grl = GradientReversalLayer()

    def forward(self, x, alpha=1.0):
        features = self.feature_extractor(x)
        class_logits = self.label_predictor(features)
        self.grl.set_lambda(alpha)
        domain_logits = self.domain_discriminator(self.grl(features))
        return class_logits, domain_logits

class DANNTrainer:
    def __init__(self, model, source_loader, target_loader, test_loader,
                 num_epochs=100, lr=1e-3, device="cpu", gamma=10.0):
        self.model = model.to(device)
        self.source_loader = source_loader
        self.target_loader = target_loader
        self.test_loader = test_loader
        self.num_epochs = num_epochs
        self.device = device
        self.gamma = gamma
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.class_criterion = nn.CrossEntropyLoss()
        self.domain_criterion = nn.BCELoss()

    def _adaptive_lambda(self, epoch):
        # sigmoid 曲线控制对抗权重,训练初期小,后期接近 1
        p = epoch / self.num_epochs
        return 2.0 / (1.0 + np.exp(-self.gamma * p)) - 1.0

    def train_epoch(self, epoch):
        self.model.train()
        source_iter = iter(self.source_loader)
        target_iter = iter(self.target_loader)
        n_batches = min(len(self.source_loader), len(self.target_loader))
        total_loss = 0.0
        lambda_p = self._adaptive_lambda(epoch)

        for _ in range(n_batches):
            try:
                src_x, src_y = next(source_iter)
            except StopIteration:
                source_iter = iter(self.source_loader)
                src_x, src_y = next(source_iter)
            try:
                tgt_x, _ = next(target_iter)
            except StopIteration:
                target_iter = iter(self.target_loader)
                tgt_x, _ = next(target_iter)

            src_x = src_x.to(self.device)
            src_y = src_y.to(self.device)
            tgt_x = tgt_x.to(self.device)

            # 前向计算:源域和目标域都过两个分支
            src_class_logits, src_dom_logits = self.model(src_x, lambda_p)
            _, tgt_dom_logits = self.model(tgt_x, lambda_p)

            # 源域分类损失
            class_loss = self.class_criterion(src_class_logits, src_y)
            # 域判别损失:源域标记为 1,目标域标记为 0
            d_loss_s = self.domain_criterion(
                src_dom_logits, torch.ones_like(src_dom_logits))
            d_loss_t = self.domain_criterion(
                tgt_dom_logits, torch.zeros_like(tgt_dom_logits))
            domain_loss = d_loss_s + d_loss_t

            loss = class_loss + domain_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        return total_loss / n_batches

    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        preds, labels = [], []
        for x, y in self.test_loader:
            x = x.to(self.device)
            logits, _ = self.model(x, alpha=0.0)
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            labels.extend(y.numpy())
        return accuracy_score(labels, preds)

    def train(self):
        best = 0.0
        for epoch in range(self.num_epochs):
            loss = self.train_epoch(epoch)
            acc = self.evaluate()
            if (epoch + 1) % 10 == 0:
                lam = self._adaptive_lambda(epoch)
                print(f"epoch {epoch + 1:3d}  loss={loss:.4f}  "
                      f"target_acc={acc:.4f}  lambda={lam:.3f}")
            best = max(best, acc)
        print(f"best target accuracy: {best:.4f}")

def main():
    N, D, C = 10000, 28 * 28, 10
    # 模拟源域和目标域数据分布偏移
    src_x = torch.randn(N, 1, 28, 28)
    src_y = torch.randint(0, C, (N,))
    tgt_x = torch.randn(N, 1, 28, 28) + 0.5     # 目标域分布偏移
    tgt_y = torch.randint(0, C, (N,))           # 训练时不使用
    test_x = torch.randn(2000, 1, 28, 28) + 0.5
    test_y = torch.randint(0, C, (2000,))

    BS = 128
    src_loader = DataLoader(TensorDataset(src_x, src_y), BS, shuffle=True)
    tgt_loader = DataLoader(TensorDataset(tgt_x, tgt_y), BS, shuffle=True)
    test_loader = DataLoader(TensorDataset(test_x, test_y), BS)

    model = DANN(D, 256, C)
    trainer = DANNTrainer(model, src_loader, tgt_loader, test_loader,
                          num_epochs=100, lr=1e-3)
    trainer.train()

if __name__ == "__main__":
    main()

这段代码的核心逻辑#

组件功能
GradientReversalLayer前向不变,反向翻转梯度——将对抗训练简化为单次反向传播。
_adaptive_lambdaSigmoid 调度 $\frac{2}{1+e^{-\gamma p}} - 1$ ——先学特征,再加对抗。
class_loss标准交叉熵,仅用源域标签(目标域无标签)。
domain_lossBCE 损失:源域=1,目标域=0——训练域判别器。
GRL + 域分支反向时梯度翻转回 $G_f$ → 迫使特征提取器隐藏域信息。
evaluate(alpha=0)测试时设 $\lambda=0$ ,GRL 无效——仅用分类头预测。

CORAL vs MMD vs DANN:实证对比#

这三种对齐损失在纸面上看起来不同,但解决的是同一个问题——将源域和目标域的特征拉入表示空间的同一区域。为了具体比较它们的权衡,固定一个基准任务,并用相同的骨干网络运行全部三种方法。

实验设置: Office-31 数据集,Amazon → Webcam。源域 $D_S$ 包含 2817 张带标签图像(31 类);目标域 $D_T$ 包含 795 张无标签图像。使用 ImageNet 预训练的 ResNet-50,微调最后的 block,在分类器前加一个 256 维的瓶颈层。每个 batch 包含 32 张源域图像 + 32 张目标域图像,使用带动量 0.9 的 SGD 优化器,初始学习率 $10^{-3}$ ,训练 50 轮。唯一变化的是附加在瓶颈层上的对齐损失。

CORAL#

$$\mathcal{L}_{\text{CORAL}} = \frac{1}{4 d^2} \|C_S - C_T\|_F^2.$$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
def coral_loss(fs, ft):
    # fs, ft : (B, d) 源域和目标域的瓶颈特征
    d = fs.size(1)
    fs_c = fs - fs.mean(0, keepdim=True)
    ft_c = ft - ft.mean(0, keepdim=True)
    cs = (fs_c.t() @ fs_c) / (fs.size(0) - 1)
    ct = (ft_c.t() @ ft_c) / (ft.size(0) - 1)
    return ((cs - ct) ** 2).sum() / (4 * d * d)

def train_step_coral(model, src_x, src_y, tgt_x, opt, lam=1.0):
    fs = model.bottleneck(model.backbone(src_x))
    ft = model.bottleneck(model.backbone(tgt_x))
    logits = model.classifier(fs)
    ce = F.cross_entropy(logits, src_y)  # 分类损失
    align = coral_loss(fs, ft)           # 对齐损失
    loss = ce + lam * align
    opt.zero_grad(); loss.backward(); opt.step()
    return ce.item(), align.item()

MMD(多核)#

$$\widehat{\text{MMD}}^2 = \tfrac{1}{n_s^2}\!\sum k(x_i^s,x_j^s) + \tfrac{1}{n_t^2}\!\sum k(x_i^t,x_j^t) - \tfrac{2}{n_s n_t}\!\sum k(x_i^s,x_j^t).$$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def mk_mmd2(fs, ft, sigmas=(1, 2, 4, 8, 16)):
    def gram(a, b):
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return sum(torch.exp(-d2 / (2 * s * s)) for s in sigmas)
    Kss = gram(fs, fs); Ktt = gram(ft, ft); Kst = gram(fs, ft)
    return Kss.mean() + Ktt.mean() - 2 * Kst.mean()

def train_step_mmd(model, src_x, src_y, tgt_x, opt, lam=1.0):
    fs = model.bottleneck(model.backbone(src_x))
    ft = model.bottleneck(model.backbone(tgt_x))
    logits = model.classifier(fs)
    ce = F.cross_entropy(logits, src_y)
    align = mk_mmd2(fs, ft)
    loss = ce + lam * align
    opt.zero_grad(); loss.backward(); opt.step()
    return ce.item(), align.item()

DANN#

$$\min_{G_f, G_y}\, \max_{G_d}\; \mathcal{L}_y - \lambda\, \mathcal{L}_d.$$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def train_step_dann(model, disc, grl, src_x, src_y, tgt_x, opt, lam_p):
    grl.set_lambda(lam_p)
    fs = model.bottleneck(model.backbone(src_x))
    ft = model.bottleneck(model.backbone(tgt_x))
    logits = model.classifier(fs)
    ce = F.cross_entropy(logits, src_y)
    d_s = disc(grl(fs)); d_t = disc(grl(ft))
    y_s = torch.ones_like(d_s); y_t = torch.zeros_like(d_t)
    d_loss = F.binary_cross_entropy(d_s, y_s) + F.binary_cross_entropy(d_t, y_t)
    loss = ce + d_loss
    opt.zero_grad(); loss.backward(); opt.step()
    return ce.item(), d_loss.item()

实验结果#

方法目标域准确率单次运行时间超参敏感性
仅源域(无 DA)68.2%8 分钟
CORAL76.0%12 分钟
多核 MMD78.4%18 分钟中(核带宽)
DANN80.1%22 分钟高(GRL 调度)

CORAL 除了 $\lambda$ 外无需调参,而 $\lambda$ 几乎不影响结果——只要在 $[0.1, 10]$ 范围内,性能与最优相差不到 1 个百分点。MMD 对带宽敏感;多核版本缓解了这个问题,但仍建议做一次超参搜索。DANN 是三者中最敏感的——如果 $\lambda$ 调度不当,效果甚至比仅用源域还差。

收敛行为#

CORAL 和 MMD 均在约 30 轮内达到平台期并保持稳定。损失曲线和目标域准确率曲线都是单调的,你可以放心地在源域验证集上早停。

DANN 则不同。分类损失和域判别损失相互对抗,导致后期目标域准确率波动达 2–4 个百分点。你需要用 sigmoid 形式的 $\lambda$ 调度来保证早期训练稳定,即便如此,最佳策略仍是监控一个小的目标域验证集——直接取最后一轮的结果往往是错误的。

实用建议#

  1. 从 CORAL 开始: 无需超参搜索、确定性、仅需 12 行代码。如果它能缩小大部分差距,直接上线。
  2. 若 CORAL 达到瓶颈,尝试多核 MMD。 多一个超参(带宽集合),依然稳定,通常能再提升 1–3 个百分点。
  3. 仅当对抗训练可行时才考虑 DANN——你有足够预算多次运行以找到合适的 $\lambda$ 调度,并且能监控一个小的目标域验证信号。

坦白总结:复杂性只能带来几个百分点的准确率提升,而非数量级的飞跃。如果这几个点至关重要(如医疗、自动驾驶),选 DANN;否则,CORAL 节省的调试时间远超其性能差距。

桥梁:选择哪种对齐损失的前提是知道你面对的是哪种分布偏移。下一部分将提供诊断方法。


检测发生了哪种类型的偏移#

三种偏移类型——协变量偏移 $P(X)$ 、标签偏移 $P(Y)$ 、概念偏移 $P(Y \mid X)$ ——需要不同的应对策略,用错方法可能适得其反。重要性加权可修正协变量偏移,但对概念偏移无效;先验校正可处理标签偏移,但如果输入本身已偏移则毫无意义。在选用方法前,先做诊断。

算法 1 —— 通过域分类器检测协变量偏移#

训练一个二分类器区分源域输入(标签 1)和目标域输入(标签 0)。保留一部分用于验证,读取 AUC。

  • AUC ≈ 0.5:源域和目标域输入不可区分——无协变量偏移。
  • AUC ≈ 1.0:输入分布差异很大——存在显著协变量偏移,应使用重要性权重或特征对齐。
  • 中间值为渐进信号——分类器对源域样本的预测可估计密度比 $w(x) = P_T(x) / P_S(x) = (1 - p) / p$ ,其中 $p = P(\text{source} \mid x)$

算法 2 —— 通过先验比较检测标签偏移#

$$\hat P_T(y) = \frac{1}{n_T} \sum_{j=1}^{n_T} \mathbb{1}[\arg\max_y f(x_j^t) = y].$$ $$\mathrm{KL}(P_S(Y) \,\|\, \hat P_T(Y)) = \sum_y P_S(y) \log \frac{P_S(y)}{\hat P_T(y)}.$$

KL > ~0.05 值得警惕;> 0.2 是强标签偏移信号。(严格做法可用 BBSE / RLLS 通过混淆矩阵去卷积得到无偏 $P_T(Y)$ 估计;快速诊断用噪声版本已足够。)

算法 3 —— 通过每类置信度检测概念偏移#

概念偏移最隐蔽:输入“看起来”一样,类别比例也相同,但标注规则变了。其标志是在目标域上自信但错误的预测

如果你有少量带标签的目标域数据(每类 50–100 个样本),计算该子集上每类的平均预测置信度,并与对应准确率比较。高置信度但低准确率即为该类存在概念偏移的指纹。

诊断辅助工具#

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def _features(model, loader, device):
    feats, labels, probs = [], [], []
    for batch in loader:
        x = batch[0].to(device)
        y = batch[1] if len(batch) > 1 else None
        f = model.bottleneck(model.backbone(x)).cpu()
        p = F.softmax(model.classifier(f.to(device)), dim=1).cpu()
        feats.append(f); probs.append(p)
        if y is not None: labels.append(y)
    return (torch.cat(feats), torch.cat(probs),
            torch.cat(labels) if labels else None)

def diagnose_shift(source_model, source_loader, target_loader,
                   target_labelled_loader=None, device="cpu"):
    source_model.eval()
    fs, ps, ys = _features(source_model, source_loader, device)
    ft, pt, _ = _features(source_model, target_loader, device)

    # 1. 协变量偏移 —— 域分类器 AUC
    X = torch.cat([fs, ft]).numpy()
    d = torch.cat([torch.ones(len(fs)), torch.zeros(len(ft))]).numpy()
    clf = LogisticRegression(max_iter=1000).fit(X, d)
    auc = roc_auc_score(d, clf.predict_proba(X)[:, 1])
    covariate = max(0.0, 2 * (auc - 0.5))         # 0 = 无,1 = 最大

    # 2. 标签偏移 —— 源先验与预测目标先验的 KL
    C = ps.size(1)
    p_src = torch.bincount(ys, minlength=C).float() / len(ys)
    yhat_t = pt.argmax(dim=1)
    p_tgt = torch.bincount(yhat_t, minlength=C).float() / len(yhat_t)
    eps = 1e-8
    kl = (p_src * ((p_src + eps).log() - (p_tgt + eps).log())).sum().item()
    label = min(1.0, kl / 0.2)                    # 归一化到 [0,1]

    # 3. 概念偏移 —— 带标签目标域上的置信度-准确率差距
    concept = None
    if target_labelled_loader is not None:
        ftl, ptl, ytl = _features(source_model, target_labelled_loader, device)
        conf, pred = ptl.max(dim=1)
        acc = (pred == ytl).float().mean().item()
        gap = max(0.0, conf.mean().item() - acc)
        concept = min(1.0, gap / 0.3)

    return {"covariate": covariate, "label": label, "concept": concept,
            "auc": auc, "kl": kl}

数值示例#

合成 2D 数据,三种场景,分别运行相同诊断:

场景AUC(协变量)KL(标签)置信-准确率差(概念)诊断结果
干净高斯,相同标签0.510.010.02无偏移
目标域 $x_1$ 偏移 +10.940.040.03协变量偏移
目标域先验 $[0.1, 0.9]$ vs 源域 $[0.5, 0.5]$0.520.410.04标签偏移
目标域决策边界翻转0.500.020.38概念偏移

每种偏移类型只点亮一列。混合偏移会点亮多列——数值大小告诉你修复顺序。

桥梁:有了诊断结果,你就能选择合适方法。但所有依赖伪标签的方法(自训练、FixMatch、联合训练等)都假设这些标签可信。这要求置信度经过校准,下一部分将解决此问题。


域偏移下的置信度校准#

自训练和大多数半监督域自适应方法通过置信度过滤伪标签:当 $\max_y f(x)_y > \tau$ 时保留 $(x, \hat y)$ 。隐含假设是高 softmax 置信度意味着高准确率。但在域偏移下,这一假设失效。

源域训练的模型在目标域上通常过度自信——对仅 70% 正确的预测赋予 95% 的概率。Amazon 训练后在 Webcam 上的期望校准误差(ECE)通常高达 18–22%。若用 $\tau = 0.9$ 过滤伪标签,你会保留大量自信但错误的标签——这是确认偏误的经典配方。

两种互补修复方法,按复杂度递增排列。

修复 1 —— 温度缩放#

$$\hat p_y = \frac{\exp(z_y / T)}{\sum_{y'} \exp(z_{y'} / T)}.$$

$T > 1$ 可软化过度自信的预测;$T < 1$ 可锐化欠自信的预测。优化是一维凸问题——L-BFGS 几次迭代即可收敛。

修复 2 —— 伪标签上的 focal loss#

训练过滤后的伪标签时,对每个样本加权 $(1 - \hat p)^\gamma$ (focal loss 技巧)。高置信度伪标签(最可能错误的过度自信样本)权重小;中等置信度样本(模型仍有信号)获得完整梯度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn.functional as F

def calibrate_temperature(model, val_loader, device="cpu", max_iter=50):
    """在小目标验证集上学一个标量 T"""
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for x, y in val_loader:
            f = model.bottleneck(model.backbone(x.to(device)))
            logits_list.append(model.classifier(f).cpu())
            labels_list.append(y)
    logits = torch.cat(logits_list); labels = torch.cat(labels_list)

    T = torch.nn.Parameter(torch.ones(1) * 1.5)
    opt = torch.optim.LBFGS([T], lr=0.1, max_iter=max_iter)

    def closure():
        opt.zero_grad()
        loss = F.cross_entropy(logits / T.clamp(min=1e-2), labels)
        loss.backward()
        return loss
    opt.step(closure)
    return float(T.detach().clamp(min=1e-2))

def filter_and_weight(logits, T, tau=0.9, gamma=2.0):
    """校准过滤 + focal 加权用于伪标签"""
    p = F.softmax(logits / T, dim=1)
    conf, yhat = p.max(dim=1)
    keep = conf > tau
    w = (1.0 - conf[keep]) ** gamma
    return yhat[keep], w, keep

数值效果#

在前述 Amazon → Webcam 设置下,使用 100 个样本的目标验证集:

阶段目标域 ECE自训练 F1
仅源域 logits18.4%74.6
+ 温度缩放($T \approx 2.1$4.1%76.9
+ 伪标签 focal 加权3.8%77.8

仅校准就提升了 2.3 F1;在此基础上加 focal 加权再提升 0.9。ECE 从 18.4% 降至 4.1% 更关键——这意味着置信度现在名副其实,$\tau = 0.9$ 的过滤器实际选出的是 ~90% 正确的标签,而非 ~70%。

桥梁:至此,实用工具箱已完备——诊断偏移类型、选择对齐方法、在生成伪标签前校准置信度。下文总结将整个流程提炼为一份检查清单。

总结#

域适应解决的是迁移学习中最实际的问题:训练与部署数据分布不同。工具箱按实现成本从低到高排序:

  • AdaBN —— 在目标域重算批归一化统计量;零成本、无需重训、永远优先尝试。
  • CORAL —— 对齐源域与目标域的协方差矩阵;开销小、结果确定。
  • MMD(DAN) —— 匹配核均值嵌入;稳定、有理论支撑,默认用多核。
  • DANN —— 通过梯度反转层实现对抗域对齐;一次反向传播搞定。
  • CDAN / ADDA —— 更灵活的变体,适合大域间差距。
  • CycleGAN —— 特征对齐不足时,用像素级转换补足。
  • 自训练 —— 用置信度阈值筛选伪标签;榨取最后几个点的精度。

Ben-David 上界揭示了可能性边界:只要联合最优误差 $\lambda^*$ 足够小,压低源误差和域散度,目标误差自然下降。若 $\lambda^*$ 本身很大,再多对齐也无济于事——此时你需要的是标签。

接下来是 第 4 篇——Few-Shot Learning ,我们将彻底抛弃“源域数据充足”的假设,探索每类仅有 1–5 个样本时的学习方法。


参考文献#

  1. Ganin et al. (2016). Domain-Adversarial Training of Neural Networks. JMLR. arXiv:1505.07818
  2. Long et al. (2015). Learning Transferable Features with Deep Adaptation Networks. ICML. arXiv:1502.02791
  3. Sun & Saenko (2016). Deep CORAL: Correlation Alignment for Deep Domain Adaptation. ECCV. arXiv:1607.01719
  4. Zhu et al. (2017). Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (CycleGAN). ICCV. arXiv:1703.10593
  5. Tzeng et al. (2017). Adversarial Discriminative Domain Adaptation (ADDA). CVPR. arXiv:1702.05464
  6. Long et al. (2018). Conditional Adversarial Domain Adaptation (CDAN). NeurIPS. arXiv:1705.10667
  7. Ben-David et al. (2010). A Theory of Learning from Different Domains. Machine Learning.
  8. Li et al. (2017). Revisiting Batch Normalization for Practical Domain Adaptation (AdaBN). arXiv:1603.04779
  9. Gretton et al. (2012). A Kernel Two-Sample Test (MMD). JMLR. paper
  10. Lipton et al. (2018). Detecting and Correcting for Label Shift with Black Box Predictors. ICML. arXiv:1802.03916
  11. Sohn et al. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. NeurIPS. arXiv:2001.07685
本系列

迁移学习 12 篇

  1. 01 迁移学习(一):基础与核心概念
  2. 02 迁移学习(二):预训练与微调
  3. 03 迁移学习(三):域适应 当前
  4. 04 迁移学习(四):小样本学习
  5. 05 迁移学习(五):知识蒸馏
  6. 06 迁移学习(六):多任务学习
  7. 07 迁移学习(七):零样本学习
  8. 08 迁移学习(八):多模态迁移
  9. 09 迁移学习(九):参数高效微调
  10. 10 迁移学习(十):持续学习
  11. 11 迁移学习(十一):跨语言迁移
  12. 12 迁移学习(十二):工业应用与最佳实践

读有所得?

GitHub 关注我 → 新文周更

GitHub