系列 · 迁移学习 · 第 5 篇

迁移学习(五):知识蒸馏

把大模型的能力压进小模型而几乎不掉点:暗知识、温度缩放、响应/特征/关系蒸馏、自蒸馏与多策略实现的完整指南。

我训练了一个 340M 参数的 BERT 模型,准确率达到 95%,但产品团队希望将其部署到手机上,而手机只能容纳约 10M 参数。如果从头训练一个 10M 的小模型,准确率只能达到 85%;这时,知识蒸馏几乎可以弥补大部分差距——通过让小模型学习大模型的输出分布,而不仅仅是硬标签,最终准确率可以提升到 92%。

Hinton 提出了一个关键洞见:教师模型那些“错误”的预测并非噪声,而是信息。例如,当教师模型对一张猫的图片预测出 0.62 的“猫”、0.14 的“老虎”、0.07 的“狗”和 0.008 的“飞机”,它实际上在告诉你:猫和老虎很像,和狗有点像,但和飞机完全不像。这种类别之间的相似性结构——即暗知识(dark knowledge)——在 one-hot 标签中是看不到的,学生模型一旦学到这些暗知识,就能以很小的参数量实现远超预期的表现。


你将学到什么#

  • 为什么软标签比硬标签包含更多信息
  • 温度缩放:一个简单参数,决定教师模型透露多少暗知识
  • 知识蒸馏的三种类型:响应特征关系
  • 自蒸馏和互学习:无需预训练教师也能有效果
  • 知识蒸馏结合量化和剪枝:压缩率轻松突破 10 倍
  • 干净的 PyTorch 实现,支持全部五种蒸馏模式

前置要求: 本系列第 1–2 篇,PyTorch 基础。


为什么蒸馏有效#

教师和学生在软标签与硬标签结合的损失下训练

部署的现实难题#

大模型在基准测试中表现优异,但在手机、汽车和云服务账单上却让人头疼。四个限制条件迫使我们转向小模型:

  • 内存:移动设备和物联网设备根本装不下几十亿参数。
  • 延迟:自动驾驶需要毫秒级响应,而不是秒级。
  • 成本:每天调用十亿次的模型,每个 FLOP 都是真金白银。
  • 能耗:边缘设备靠电池供电,不是靠发电厂。

剪枝和量化直接修改模型本身,但会损失精度。蒸馏则另辟蹊径:训练一个小模型去模仿大模型的输出分布,而不仅仅是它的 argmax。学生继承了教师的归纳偏置,但不需要继承其参数。

暗知识:软标签到底教了什么#

假设教师对一张猫的图片分类,输出如下:

类别硬标签教师输出
1.00.62
老虎0.00.14
豹子0.00.10
0.00.07
汽车0.00.012

硬标签只说“是猫”,熵为零。软标签则说“是猫,但也像老虎、有点像豹子、隐隐约约像狗、绝对不像汽车”——熵大于零。这份排序是教师通过数百万次梯度更新学到的类别相似性课程,完全免费。

one-hot 与教师 softmax:同一个预测,监督信号截然不同

不是匹配标签,而是匹配分布#

$$\mathcal{L}_{\text{hard}} \;=\; -\sum_c y_c \log \sigma(z_c^S).$$ $$\mathcal{L}_{\text{KD}} \;=\; -\sum_c \sigma(z_c^T / \tau) \, \log \sigma(z_c^S / \tau).$$

由于教师固定不变,这等价于最小化 $\mathrm{KL}\!\left(\sigma(z^T/\tau) \,\|\, \sigma(z^S/\tau)\right)$ 。学生学习的不再是一个标签,而是一整条概率分布。

温度:暗知识的调节旋钮#

$$\sigma(z_i; \tau) \;=\; \frac{\exp(z_i / \tau)}{\sum_j \exp(z_j / \tau)}.$$
温度效果
$\tau \to 0$退化为 one-hot(argmax)
$\tau = 1$标准 softmax
$\tau = 4$ — 10显式化类间相似性
$\tau \to \infty$趋近均匀分布

举个例子,logits $z = [5, 3, 1]$

  • $\tau = 1$$[0.84, 0.11, 0.04]$ ,第三类几乎消失。
  • $\tau = 3$$[0.51, 0.31, 0.18]$ ,第三类的信息得以保留。
$$\sigma(z_i; \tau) \;\approx\; \frac{1}{C} + \frac{z_i - \bar z}{C \tau},$$

这意味着学生可以直接从教师 logits 的相对大小中学习,而不受指数函数的扭曲影响。

相同 logits、三种温度:从尖锐到平滑

组合损失#

$$\mathcal{L} \;=\; \alpha \cdot \tau^2 \cdot \mathcal{L}_{\text{KD}} \;+\; (1 - \alpha) \cdot \mathcal{L}_{\text{hard}}.$$
  • $\alpha \in [0.5, 0.9]$ :表示对教师的信任程度。
  • $\tau^2$ :补偿高温下的梯度收缩。软目标项的梯度按 $1/\tau^2$ 缩放,因此将损失乘回 $\tau^2$ ,让两项量级可比。
  • $\mathcal{L}_{\text{hard}}$ :与真实标签的标准交叉熵。

经验默认值:$\tau = 4$ ;当教师远强于学生时取 $\alpha = 0.9$ ,当两者容量接近时取 $\alpha = 0.5$


响应式蒸馏:只动输出层#

经典方法很简单——只匹配教师模型的输出层,其他一概不管。

Hinton 的算法#

  1. 用完整数据集训练教师模型 $T$
  2. 对每个输入 $x$ ,计算软标签 $\sigma(z^T(x) / \tau)$ 并保存。
  3. 用组合损失函数训练学生模型 $S$
  4. 部署学生模型时,设置 $\tau = 1$

ImageNet 上的典型结果:

设置Top-1
ResNet-34 教师73.3%
ResNet-18 从头训69.8%
ResNet-18 蒸馏71.4%(+1.6)

改一下损失函数,白捡 1.6 个点。

蒸馏 vs. 标签平滑#

$$y_c' \;=\; (1 - \epsilon) \, y_c + \epsilon / C.$$

区别在于“软”的来源不同。标签平滑对所有样本都用同样的均匀分布来软化。蒸馏则是为每个样本生成一个独特的软分布,这个分布来自教师模型。比如,一张波斯猫的图片会在“老虎”“豹子”上分配权重;一辆轿车的图片会在“卡车”“旅行车”上分配权重。这就是蒸馏效果始终优于标签平滑的原因。


基于特征的蒸馏:让中间层也对齐#

响应式蒸馏只关注顶层的匹配,而基于特征的蒸馏还会对齐中间表示——信号更丰富,学生模型能从更多地方学习教师模型的几何特性。

响应蒸馏只对齐 logits;特征蒸馏还对齐中间特征图

FitNets:提示学习#

$$\mathcal{L}_{\text{hint}} \;=\; \| W_r \, F_S^l - F_T^l \|_F^2,$$

其中 $W_r$ 是一个可学习的 1×1 投影,用来调整学生模型的通道维度以匹配教师模型。Romero 等人采用两阶段训练方法:

  1. 先训练学生模型的浅层网络和投影模块,使其对齐教师模型选定的“提示层”。
  2. 冻结浅层网络后,用标准蒸馏方法继续训练其余部分。

注意力传递(Attention Transfer)#

$$A(F) \;=\; \sum_c |F_c|^p, \quad p = 2.$$ $$\mathcal{L}_{\text{AT}} \;=\; \sum_l \left\| \frac{A_S^l}{\|A_S^l\|_2} - \frac{A_T^l}{\|A_T^l\|_2} \right\|_2^2.$$

在 CIFAR-10 数据集上,将 ResNet-110 蒸馏到 ResNet-20 的结果如下:

方法准确率
ResNet-20 基线91.3%
仅响应蒸馏91.8%
注意力传递92.4%

Gram 矩阵蒸馏(NST)#

$$G \;=\; F^\top F,$$

其中 $G_{ij}$ 表示通道 $i$ 和通道 $j$ 的相关性。这是一种二阶统计量,反映的是“纹理”而非“内容”,逐点匹配无法捕捉这种信息。


关系蒸馏:传递样本间的关系#

不单独处理每个样本,而是匹配样本之间的关系

RKD:关系知识蒸馏#

两种方式:

$$\mathcal{L}_{\text{dist}} \;=\; \sum_{(i,j)} \ell_\delta\!\left(d_S(i,j),\, d_T(i,j)\right).$$ $$\mathcal{L}_{\text{angle}} \;=\; \sum_{(i,j,k)} \ell_\delta\!\left(\angle_S(i,j,k),\, \angle_T(i,j,k)\right).$$

实际效果表明,角度比距离更重要($\lambda_{\text{angle}} = 2$$\lambda_{\text{dist}} = 1$ )。因为角度不受尺度影响,能更好地捕捉相对几何结构。

CRD:对比表示蒸馏#

$$\mathcal{L}_{\text{CRD}} \;=\; -\log \frac{\exp\!\left(f_S(x)^\top f_T(x) / \tau\right)}{\sum_{x'} \exp\!\left(f_S(x)^\top f_T(x') / \tau\right)}.$$

目标是最大化学生与教师特征之间的互信息。当学生模型很小时(比如从 ResNet-32 蒸馏到 ResNet-8),CRD 的优势更加明显——在 CIFAR-100 上,性能比仅使用响应蒸馏高出 2% 以上。


自蒸馏:不需要单独的教师模型#

如果没有大模型当教师,照样可以做蒸馏。

Born-Again Networks:每一代都从前一代相同架构的模型中学习

重生网络#

一个令人意外的发现:把模型蒸馏到架构完全相同的副本上,准确率居然还能提升。

  1. 正常训练 $M_1$
  2. $M_1$ 当教师训练 $M_2$ (架构相同)。
  3. $M_2$ 当教师训练 $M_3$
  4. 准确率不再提升时停止。

CIFAR-100 的实验结果:

代次准确率
1(基线)74.3%
2(BAN)75.2%
375.4%
475.5%

背后有两个互补的解释:软标签提供更平滑的梯度,减少了过拟合的风险;每一代模型探索损失曲面的不同区域,相当于在推理时免费获得了一个隐式集成的效果。

深度互学习#

$$\mathcal{L}_i \;=\; \mathcal{L}_{\text{CE}}^i + \frac{1}{M-1} \sum_{j \neq i} \mathrm{KL}\!\left(P_j \,\|\, P_i\right).$$

不需要预训练。不同的随机种子让模型犯不同的错误,互相监督让每个模型吸收其他模型的优点。在 CIFAR-100 上,两个 ResNet-32 联合训练后,各自达到 72.1% 的准确率,而单独训练只有 70.2%。


蒸馏 + 压缩#

蒸馏和其他压缩方法配合得很好,不是互相排斥的关系。

量化感知蒸馏#

把 FP32 转成 INT8,内存能省 4 倍,支持的硬件上还能快 2 到 4 倍,但精度会掉一些。蒸馏能把这个损失补回来一大半:

ResNet-18 / ImageNetTop-1
FP32 基线69.8%
INT8,无蒸馏68.5%(-1.3)
INT8,加蒸馏69.2%(-0.6)

蒸馏直接把量化的精度损失减了一半。

剪枝感知蒸馏#

剪掉重要性最低的通道后,用教师模型的软标签微调学生模型:

VGG-16 / CIFAR-10准确率参数量
原始93.5%14.7M
剪 70%,无蒸馏92.1%4.4M
剪 70%,加蒸馏93.0%4.4M

实际操作流程是这样的:先训练教师模型,然后通过剪枝确定学生模型结构,接着用蒸馏恢复精度,最后再做量化。每一步都可以借助教师模型,最终能做到 10 到 20 倍的压缩,精度损失控制在 1% 以内。

参数量轴上,蒸馏曲线全面优于从头训练曲线

案例:DistilBERT#

Sanh 等人(2019)用三重损失(cosine + MLM + KD)把 BERT-base 蒸馏成 DistilBERT,给出了 NLP 领域蒸馏的经典结果:

DistilBERT:参数减少 40%、推理提速 60%、保留 97% 的 GLUE

BERT-baseDistilBERT变化
参数量110M66M-40%
推理延迟410 ms250 ms-39%
GLUE(平均)79.577.0-3%

这三个数字让知识蒸馏在 NLP 工程领域真正站稳了脚跟。

完整实现#

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from typing import List
import copy

# ===== 损失函数 =====

class KDLoss(nn.Module):
    """响应蒸馏:带温度的 KL 散度,混合硬标签交叉熵。"""
    def __init__(self, temperature: float = 4.0, alpha: float = 0.9):
        super().__init__()
        self.T = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        kd = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
            reduction='batchmean',
        ) * (self.T ** 2)
        ce = F.cross_entropy(student_logits, labels)
        return self.alpha * kd + (1 - self.alpha) * ce

class FeatureDistillLoss(nn.Module):
    """FitNets:1x1 卷积投影后计算特征图均方误差。"""
    def __init__(self, student_ch: int, teacher_ch: int):
        super().__init__()
        self.proj = nn.Conv2d(student_ch, teacher_ch, 1, bias=False)

    def forward(self, student_feat, teacher_feat):
        return F.mse_loss(self.proj(student_feat), teacher_feat)

class AttentionTransferLoss(nn.Module):
    """匹配通道聚合后的归一化空间注意力图。"""
    def __init__(self, p: float = 2.0):
        super().__init__()
        self.p = p

    def _attn(self, feat):
        a = torch.sum(torch.abs(feat) ** self.p, dim=1, keepdim=True)
        return a / (a.sum(dim=[2, 3], keepdim=True) + 1e-8)

    def forward(self, s_feat, t_feat):
        return F.mse_loss(self._attn(s_feat), self._attn(t_feat))

class RelationalLoss(nn.Module):
    """RKD:样本间的距离和角度关系。"""
    def __init__(self, w_dist: float = 1.0, w_angle: float = 2.0):
        super().__init__()
        self.w_dist, self.w_angle = w_dist, w_angle

    def forward(self, s_feat, t_feat):
        s = F.normalize(s_feat, p=2, dim=1)
        t = F.normalize(t_feat, p=2, dim=1)
        d_loss = F.mse_loss(torch.cdist(s, s), torch.cdist(t, t))
        a_loss = F.mse_loss(s @ s.t(), t @ t.t())
        return self.w_dist * d_loss + self.w_angle * a_loss

# ===== 模型(返回 logits 和中间特征)=====

class _ResNetBackbone(nn.Module):
    def __init__(self, ctor, num_classes: int):
        super().__init__()
        self.net = ctor(weights=None)
        self.net.fc = nn.Linear(self.net.fc.in_features, num_classes)

    def forward(self, x):
        n = self.net
        x = n.maxpool(n.relu(n.bn1(n.conv1(x))))
        feats: List[torch.Tensor] = []
        for layer in (n.layer1, n.layer2, n.layer3, n.layer4):
            x = layer(x)
            feats.append(x)
        logits = n.fc(torch.flatten(n.avgpool(x), 1))
        return logits, feats

class TeacherNet(_ResNetBackbone):
    def __init__(self, num_classes: int = 10):
        super().__init__(torchvision.models.resnet34, num_classes)

class StudentNet(_ResNetBackbone):
    def __init__(self, num_classes: int = 10):
        super().__init__(torchvision.models.resnet18, num_classes)

# ===== 训练器(支持多种蒸馏模式)=====

class DistillationTrainer:
    def __init__(self, teacher, student, device='cpu',
                 mode: str = 'response',
                 temperature: float = 4.0, alpha: float = 0.9):
        self.teacher = teacher.to(device).eval()
        self.student = student.to(device)
        self.device = device
        self.mode = mode
        for p in self.teacher.parameters():
            p.requires_grad = False

        self.kd = KDLoss(temperature, alpha)
        if mode in ('feature', 'combined'):
            ch = [64, 128, 256, 512]
            self.feat = nn.ModuleList(
                [FeatureDistillLoss(s, t).to(device)
                 for s, t in zip(ch, ch)])
        if mode in ('attention', 'combined'):
            self.attn = AttentionTransferLoss()
        if mode in ('relation', 'combined'):
            self.rel = RelationalLoss()

    def _loss(self, s_logits, s_feats, t_logits, t_feats, y):
        loss = self.kd(s_logits, t_logits, y)
        if self.mode in ('feature', 'combined'):
            fl = sum(fn(sf, tf)
                     for fn, sf, tf in zip(self.feat, s_feats, t_feats))
            loss = loss + 0.5 * fl / len(s_feats)
        if self.mode in ('attention', 'combined'):
            al = sum(self.attn(sf, tf)
                     for sf, tf in zip(s_feats, t_feats))
            loss = loss + 0.3 * al / len(s_feats)
        if self.mode in ('relation', 'combined'):
            loss = loss + 0.2 * self.rel(
                s_feats[-1].flatten(1), t_feats[-1].flatten(1))
        return loss

    def train_epoch(self, loader, optimizer):
        self.student.train()
        total, correct, n = 0.0, 0, 0
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            with torch.no_grad():
                t_logits, t_feats = self.teacher(x)
            s_logits, s_feats = self.student(x)
            loss = self._loss(s_logits, s_feats, t_logits, t_feats, y)
            optimizer.zero_grad(); loss.backward(); optimizer.step()
            total += loss.item() * y.size(0)
            correct += (s_logits.argmax(1) == y).sum().item()
            n += y.size(0)
        return total / n, 100.0 * correct / n

    @torch.no_grad()
    def evaluate(self, loader):
        self.student.eval()
        correct, n = 0, 0
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            logits, _ = self.student(x)
            correct += (logits.argmax(1) == y).sum().item()
            n += y.size(0)
        return 100.0 * correct / n

# ===== 自蒸馏:Born-Again Networks =====

def self_distill(model_class, train_loader, test_loader,
                 num_generations: int = 3, epochs_per_gen: int = 10,
                 device: str = 'cpu', temperature: float = 4.0):
    teacher = None
    for gen in range(num_generations):
        student = model_class().to(device)
        opt = optim.SGD(student.parameters(), lr=0.1,
                        momentum=0.9, weight_decay=5e-4)
        sched = optim.lr_scheduler.CosineAnnealingLR(opt, epochs_per_gen)

        for _ in range(epochs_per_gen):
            student.train()
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                loss = F.cross_entropy(logits, y)
                if teacher is not None:
                    with torch.no_grad():
                        t_logits, _ = teacher(x)
                    kd = F.kl_div(
                        F.log_softmax(logits / temperature, dim=1),
                        F.softmax(t_logits / temperature, dim=1),
                        reduction='batchmean') * temperature ** 2
                    loss = 0.1 * loss + 0.9 * kd
                opt.zero_grad(); loss.backward(); opt.step()
            sched.step()

        student.eval()
        correct, n = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                correct += (student(x)[0].argmax(1) == y).sum().item()
                n += y.size(0)
        print(f"第 {gen + 1} 代:准确率 {100.0 * correct / n:.2f}%")

        teacher = copy.deepcopy(student).eval()
        for p in teacher.parameters():
            p.requires_grad = False

# ===== 主流程 =====

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                (0.2023, 0.1994, 0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), norm])
    test_tf = transforms.Compose([transforms.ToTensor(), norm])

    train_set = torchvision.datasets.CIFAR10(
        './data', train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=test_tf)
    train_loader = DataLoader(train_set, 128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, 128, num_workers=2)

    teacher = TeacherNet(10)
    student = StudentNet(10)
    trainer = DistillationTrainer(
        teacher, student, device, mode='combined',
        temperature=4.0, alpha=0.7)
    opt = optim.SGD(student.parameters(), lr=0.1,
                    momentum=0.9, weight_decay=5e-4)
    for epoch in range(20):
        loss, acc = trainer.train_epoch(train_loader, opt)
        test_acc = trainer.evaluate(test_loader)
        print(f"第 {epoch + 1} 轮:损失 {loss:.4f} "
              f"训练准确率 {acc:.1f}% 测试准确率 {test_acc:.1f}%")

    print("\n自蒸馏(Born-Again Networks):")
    self_distill(StudentNet, train_loader, test_loader,
                 num_generations=3, epochs_per_gen=10, device=device)

if __name__ == '__main__':
    main()

各模块的职责#

模块作用
KDLoss带温度的软 KL 散度,与硬标签交叉熵加权融合。
FeatureDistillLossFitNets:通过 1×1 卷积投影后计算特征图均方误差。
AttentionTransferLoss匹配通道聚合后的归一化空间注意力图。
RelationalLossRKD:样本间的距离和角度关系。
DistillationTrainer统一支持五种蒸馏模式的训练器。
self_distillBorn-Again Networks:相同架构的迭代自蒸馏。

温度调度策略#

蒸馏方法:验证精度曲线与温度敏感度。

恒定 $\tau = 4$ 是教科书式的默认选择,确实有效。但温度控制的是学生被要求学习的内容,而这一目标会随着训练进程动态变化。训练初期,学生一无所知——此时最受益于观察教师对所有类别的完整排序;训练后期,学生已内化了特征空间的几何结构,需要做出自信的预测。单一标量无法同时满足这两个阶段的需求。

$$\tau(t) \;=\; \tau_{\text{start}} - (\tau_{\text{start}} - \tau_{\text{end}}) \cdot \frac{t}{T},$$

从高值(分布平坦,包含丰富“暗知识”)逐渐降至低值(分布尖锐,提供明确监督)。若在 $T$ 个 epoch 中设置 $\tau_{\text{start}} = 20$$\tau_{\text{end}} = 2$ ,学生将在前半段学习类别间的相对相似性,后半段则聚焦于分类决策的锐化。

从机制上看,为何这能奏效?当 $\tau$ 较高时,softmax 近似于 logits 的线性函数,KL 散度对 student logits 的梯度主要由 teacher logits 之间的差异主导——即关系结构;当 $\tau$ 较低时,softmax 近似 one-hot 分布,KL 散度退化为对 teacher argmax 的标准交叉熵——目标虽硬但无噪声。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.nn.functional as F

class TemperatureSchedule:
    """在 T 个 epoch 内从 tau_start 线性下降到 tau_end"""
    def __init__(self, tau_start: float, tau_end: float, T: int):
        self.t0, self.t1, self.T = tau_start, tau_end, T

    def __call__(self, epoch: int) -> float:
        frac = min(epoch / max(self.T - 1, 1), 1.0)
        return self.t0 - (self.t0 - self.t1) * frac

def kd_loss_scheduled(s_logits, t_logits, y, tau: float, alpha: float = 0.9):
    kd = F.kl_div(
        F.log_softmax(s_logits / tau, dim=1),
        F.softmax(t_logits / tau, dim=1),
        reduction='batchmean',
    ) * (tau ** 2)
    ce = F.cross_entropy(s_logits, y)
    return alpha * kd + (1 - alpha) * ce

# 基准实验:CIFAR-10 上 ResNet-18 教师 → ResNet-8 学生
sched = TemperatureSchedule(tau_start=20.0, tau_end=2.0, T=epochs)
for epoch in range(epochs):
    tau = sched(epoch)
    student.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            t_logits, _ = teacher(x)
        s_logits, _ = student(x)
        loss = kd_loss_scheduled(s_logits, t_logits, y, tau=tau, alpha=0.9)
        opt.zero_grad(); loss.backward(); opt.step()

CIFAR-10 实验结果(ResNet-18 教师蒸馏至 ResNet-8 学生,共 100 个 epoch):

调度策略测试准确率
恒定 $\tau = 4$88.1%
线性 $20 \to 2$89.4%
线性 $50 \to 4$89.7%

仅将标量替换为函数,即可免费提升 1.3–1.6 个百分点。调度策略的主要收益来自前半段训练——此时与恒定温度基线的差距迅速拉开;最后 20 个 epoch 使用低 $\tau$ 主要起稳定作用。

需注意一点:若 $\tau_{\text{start}}$ 远超 50,软目标将趋近均匀分布——即使有 $\tau^2$ 缩放也无法补偿,梯度相对于交叉熵项变得极小,导致早期训练停滞或震荡。若需“高温”启动,建议将 $\tau$ 上限设为 30 左右,并适当延长调度周期。

过渡:温度决定了学生如何倾听教师的 logits。下一步则是改变学生倾听的内容——即特征及其间的几何关系。


CRD:对比表示蒸馏(Contrastive Representation Distillation)#

我们此前所见的基于特征的 KD 方法(如 FitNets、注意力迁移)均采用逐点匹配:强制学生特征 $f_S(x)$ 逼近教师特征 $f_T(x)$ 。这是一种强约束,且往往浪费资源——它要求学生复现教师的确切激活值,即便下游分类任务仅依赖特征间的相对关系。

对比表示蒸馏(CRD)放松了这一约束。它不再直接匹配 $f_S(x)$$f_T(x)$ ,而是要求同一输入的师生特征对彼此更接近,而非与其他输入的师生特征对更接近。具体而言,输入 $x$ 的教师表征作为正样本,同 batch 中其他输入的教师表征则作为负样本。

$$\mathcal{L}_{\text{CRD}} \;=\; -\log \frac{\exp\!\left(s(z_s, z_t) / \tau\right)}{\sum_{j} \exp\!\left(s(z_s, z_j^-) / \tau\right)},$$

其中 $s(\cdot, \cdot)$ 为余弦相似度,$z_s = g_S(f_S(x))$$z_t = g_T(f_T(x))$ ,而 $g_S, g_T$ 为小型投影头(一层非线性层足矣)。最大化该损失可下界估计师生表征间的互信息。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    def __init__(self, in_dim: int, out_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim))

    def forward(self, x):
        return F.normalize(self.net(x), p=2, dim=1)

class CRDLoss(nn.Module):
    """基于 batch 内负样本的 (student, teacher) 对 InfoNCE 损失"""
    def __init__(self, s_dim: int, t_dim: int,
                 proj_dim: int = 128, tau: float = 0.07):
        super().__init__()
        self.g_s = ProjectionHead(s_dim, proj_dim)
        self.g_t = ProjectionHead(t_dim, proj_dim)
        self.tau = tau

    def forward(self, s_feat, t_feat):
        # 如有必要,对空间维度做池化
        if s_feat.dim() == 4:
            s_feat = F.adaptive_avg_pool2d(s_feat, 1).flatten(1)
        if t_feat.dim() == 4:
            t_feat = F.adaptive_avg_pool2d(t_feat, 1).flatten(1)
        z_s = self.g_s(s_feat)             # (B, D)
        z_t = self.g_t(t_feat).detach()    # 固定教师
        logits = z_s @ z_t.t() / self.tau  # (B, B):第 i 行第 i 列为正样本
        labels = torch.arange(z_s.size(0), device=z_s.device)
        return F.cross_entropy(logits, labels)
$$\mathcal{L} \;=\; \alpha \tau^2 \mathcal{L}_{\text{KD}} + (1 - \alpha) \mathcal{L}_{\text{hard}} + \beta \mathcal{L}_{\text{CRD}}.$$

CIFAR-100 实验(ResNet-50 教师蒸馏至 ResNet-18):

方法测试准确率
硬标签基线73.3%
原始响应 KD75.5%
+ CRD76.7% (+1.2)

当教师与学生容量差异极大时,CRD 提升最为显著——这正是逐点匹配失效的场景:学生根本无法复现教师的特征。CRD 转而要求更弱的条件(相对相似性),反而获得了更好的效果。

过渡:但并非所有蒸馏训练都能顺利收敛——有时损失停滞不前,甚至发散。


蒸馏失败的情形#

蒸馏是一种稳健技术,但并非万能。以下三种失败模式较为常见,值得明确诊断。

$$D_{\text{KL}}(p_t \,\|\, p_s) \;=\; \sum_c p_t(c) \log \frac{p_t(c)}{p_s(c)}$$

在前 5–10 个 epoch 内就 plateau 在高位,不再下降。交叉熵可能继续降低(学生学会了 argmax),但未能吸收“暗知识”。例如,一个 10 万参数的学生试图模仿 1 亿参数的教师,几乎立刻就会遇到此问题。

模态不匹配(Modality mismatch):教师与学生接收不同输入——如 RGB 教师蒸馏给灰度学生、高分辨率教师蒸馏给低分辨率学生,或跨语言场景中分词器不一致。教师的软目标编码了学生无法获取的特征,从学生视角看,$p_t$ 几乎等同于噪声。症状:KD 损失高且波动大,联合损失甚至劣于仅用硬标签训练。

在线蒸馏不稳定(Online distillation instability):在在线或互学习(mutual-learning)设置中,教师与学生同步更新。学生追逐一个移动目标,若教师更新幅度过大,KL 散度轨迹会震荡而非衰减。症状:epoch 间 KL 值上下波动达 20–30%,无明显下降趋势。

以下简单诊断工具可识别上述三种情形:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@torch.no_grad()
def kd_diagnostic(student, teacher, loader, n_epochs: int,
                  optimizer, device='cpu', tau: float = 4.0):
    """返回每 epoch 的平均 KL(p_t || p_s),并判断主要瓶颈"""
    teacher.eval()
    kl_history = []
    for epoch in range(n_epochs):
        student.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            t_logits, _ = teacher(x)
            s_logits, _ = student(x)
            loss = F.kl_div(
                F.log_softmax(s_logits / tau, dim=1),
                F.softmax(t_logits / tau, dim=1),
                reduction='batchmean') * (tau ** 2)
            optimizer.zero_grad(); loss.backward(); optimizer.step()
        # 在干净 pass 上测量 KL
        student.eval()
        kl_sum, n = 0.0, 0
        for x, _ in loader:
            x = x.to(device)
            t_logits, _ = teacher(x)
            s_logits, _ = student(x)
            p_t = F.softmax(t_logits / tau, dim=1)
            log_p_s = F.log_softmax(s_logits / tau, dim=1)
            kl_sum += F.kl_div(log_p_s, p_t, reduction='batchmean').item() * x.size(0)
            n += x.size(0)
        kl_history.append(kl_sum / n)

    init, mid, last = kl_history[0], kl_history[len(kl_history) // 2], kl_history[-1]
    if last > 0.5 * init:
        verdict = 'capacity'   # KL 在整个训练中未减半
    elif max(kl_history[-5:]) - min(kl_history[-5:]) > 0.2 * last:
        verdict = 'lr / instability'
    elif init > 5.0:
        verdict = 'data / modality mismatch'
    else:
        verdict = 'healthy'
    return kl_history, verdict

健康训练通常在约 50 个 epoch 内将初始 KL 减半,并持续缓慢下降。若不符合,诊断结果可指导修复:

  • 容量不足:增大一级学生宽度或深度,或加入中间层特征蒸馏,使单位参数获得更丰富的信号。若学生规模固定,可先蒸馏至一个中间教师(TA-KD),再由该中间教师蒸馏至目标学生。
  • 模态不匹配:蒸馏前对齐输入。将教师输入空间转换为与学生一致,或训练一个小型适配器,将学生特征映射至教师空间后再蒸馏。
  • 不稳定:定期冻结教师(每 $k$ 步更新一次,而非每步),降低学生学习率,或在前几个 epoch 内将 KD 权重 $\alpha$ 从 0 逐步 warmup。

共通规律是:KL 轨迹比测试准确率更敏感。它能提前告诉你学生是否真正在听,而不必等到最终评估才知道是否学到了有用知识。

过渡:掌握这些诊断手段后,蒸馏便从一个寄托希望的损失项,转变为可监控的训练流程。剩下的问题则更偏实践——如何选择超参、压缩极限在哪,以及蒸馏如何与模型压缩工具箱中的其他技术协同工作。

常见问题#

温度怎么选?#

$\tau = 4$ 开始。类别越多、类别之间越相似(比如细粒度的物种分类),温度就要越高,ImageNet 规模的问题可以调到 $\tau = 20$ 。在验证集上对 $\{2, 4, 8, 12, 20\}$ 做网格搜索就行。

学生模型能压缩到多小?#

4–10 倍压缩基本不掉点;超过 50 倍的话,蒸馏也救不了——通常会掉 5–10 个点。经验规律是:当学生模型刚好有足够的容量表示教师模型的知识,但没能力直接从标签学出来时,蒸馏的效果最好。

自蒸馏为什么有效?学生和教师容量不是一样吗?

两个原因。第一,软目标比 one-hot 标签正则化更强,尤其在数据量小时更明显。第二,每一代训练都会落在损失曲面的不同区域,多次迭代相当于零成本的隐式集成。CIFAR-100 上一般能提升 1–2%,3 代之后收益递减。

可以用多个教师吗?#

可以。把它们的软输出平均一下(均匀加权或者按验证准确率加权)。这样通常能提升鲁棒性,而不是单纯的头部准确率,代价是需要训练多个教师模型。

先剪枝还是先蒸馏?#

两步都做,顺序是这样的:先训练教师模型 → 剪枝确定学生模型结构 → 蒸馏恢复精度 → 量化。每一步都有教师模型作为参考,比单独操作更能保留知识。

总结#

知识蒸馏就是教小模型学会大模型的思维方式:

  • 软标签保留了 one-hot 标签丢失的类间结构,这部分信息就是所谓的暗知识。
  • 温度是一个标量,用来调节学生模型能学到多少暗知识。
  • 特征蒸馏关系蒸馏不仅关注 logits,还对齐中间表示和样本间的几何关系。
  • 自蒸馏不需要额外的教师模型,同时还能免费获得一个集成模型。
  • 配合剪枝量化,蒸馏可以实现 10 到 20 倍的压缩率,精度损失控制在个位数。

下一篇:第六部分 —— 多任务学习 ,多个任务共享参数,提升泛化能力和效率。


参考文献#

  1. Hinton, Vinyals, Dean (2015). Distilling the Knowledge in a Neural Network. arXiv:1503.02531
  2. Romero et al. (2015). FitNets: Hints for Thin Deep Nets. ICLR. arXiv:1412.6550
  3. Zagoruyko & Komodakis (2017). Paying More Attention to Attention. ICLR. arXiv:1612.03928
  4. Park et al. (2019). Relational Knowledge Distillation. CVPR. arXiv:1904.05068
  5. Tian et al. (2020). Contrastive Representation Distillation. ICLR. arXiv:1910.10699
  6. Furlanello et al. (2018). Born-Again Neural Networks. ICML. arXiv:1805.04770
  7. Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv:1910.01108
  8. Zhang et al. (2018). Deep Mutual Learning. CVPR. arXiv:1706.00384
本系列

迁移学习 12 篇

  1. 01 迁移学习(一):基础与核心概念
  2. 02 迁移学习(二):预训练与微调
  3. 03 迁移学习(三):域适应
  4. 04 迁移学习(四):小样本学习
  5. 05 迁移学习(五):知识蒸馏 当前
  6. 06 迁移学习(六):多任务学习
  7. 07 迁移学习(七):零样本学习
  8. 08 迁移学习(八):多模态迁移
  9. 09 迁移学习(九):参数高效微调
  10. 10 迁移学习(十):持续学习
  11. 11 迁移学习(十一):跨语言迁移
  12. 12 迁移学习(十二):工业应用与最佳实践

读有所得?

GitHub 关注我 → 新文周更

GitHub