
迁移学习(五):知识蒸馏
把大模型的能力压进小模型而几乎不掉点:暗知识、温度缩放、响应/特征/关系蒸馏、自蒸馏与多策略实现的完整指南。
我训练了一个 340M 参数的 BERT 模型,准确率达到 95%,但产品团队希望将其部署到手机上,而手机只能容纳约 10M 参数。如果从头训练一个 10M 的小模型,准确率只能达到 85%;这时,知识蒸馏几乎可以弥补大部分差距——通过让小模型学习大模型的输出分布,而不仅仅是硬标签,最终准确率可以提升到 92%。
Hinton 提出了一个关键洞见:教师模型那些“错误”的预测并非噪声,而是信息。例如,当教师模型对一张猫的图片预测出 0.62 的“猫”、0.14 的“老虎”、0.07 的“狗”和 0.008 的“飞机”,它实际上在告诉你:猫和老虎很像,和狗有点像,但和飞机完全不像。这种类别之间的相似性结构——即暗知识(dark knowledge)——在 one-hot 标签中是看不到的,学生模型一旦学到这些暗知识,就能以很小的参数量实现远超预期的表现。
你将学到什么#
- 为什么软标签比硬标签包含更多信息
- 温度缩放:一个简单参数,决定教师模型透露多少暗知识
- 知识蒸馏的三种类型:响应、特征、关系
- 自蒸馏和互学习:无需预训练教师也能有效果
- 知识蒸馏结合量化和剪枝:压缩率轻松突破 10 倍
- 干净的 PyTorch 实现,支持全部五种蒸馏模式
前置要求: 本系列第 1–2 篇,PyTorch 基础。
为什么蒸馏有效#

部署的现实难题#
大模型在基准测试中表现优异,但在手机、汽车和云服务账单上却让人头疼。四个限制条件迫使我们转向小模型:
- 内存:移动设备和物联网设备根本装不下几十亿参数。
- 延迟:自动驾驶需要毫秒级响应,而不是秒级。
- 成本:每天调用十亿次的模型,每个 FLOP 都是真金白银。
- 能耗:边缘设备靠电池供电,不是靠发电厂。
剪枝和量化直接修改模型本身,但会损失精度。蒸馏则另辟蹊径:训练一个小模型去模仿大模型的输出分布,而不仅仅是它的 argmax。学生继承了教师的归纳偏置,但不需要继承其参数。
暗知识:软标签到底教了什么#
假设教师对一张猫的图片分类,输出如下:
| 类别 | 硬标签 | 教师输出 |
|---|---|---|
| 猫 | 1.0 | 0.62 |
| 老虎 | 0.0 | 0.14 |
| 豹子 | 0.0 | 0.10 |
| 狗 | 0.0 | 0.07 |
| 汽车 | 0.0 | 0.012 |
硬标签只说“是猫”,熵为零。软标签则说“是猫,但也像老虎、有点像豹子、隐隐约约像狗、绝对不像汽车”——熵大于零。这份排序是教师通过数百万次梯度更新学到的类别相似性课程,完全免费。

不是匹配标签,而是匹配分布#
$$\mathcal{L}_{\text{hard}} \;=\; -\sum_c y_c \log \sigma(z_c^S).$$ $$\mathcal{L}_{\text{KD}} \;=\; -\sum_c \sigma(z_c^T / \tau) \, \log \sigma(z_c^S / \tau).$$由于教师固定不变,这等价于最小化 $\mathrm{KL}\!\left(\sigma(z^T/\tau) \,\|\, \sigma(z^S/\tau)\right)$ 。学生学习的不再是一个标签,而是一整条概率分布。
温度:暗知识的调节旋钮#
$$\sigma(z_i; \tau) \;=\; \frac{\exp(z_i / \tau)}{\sum_j \exp(z_j / \tau)}.$$| 温度 | 效果 |
|---|---|
| $\tau \to 0$ | 退化为 one-hot(argmax) |
| $\tau = 1$ | 标准 softmax |
| $\tau = 4$ — 10 | 显式化类间相似性 |
| $\tau \to \infty$ | 趋近均匀分布 |
举个例子,logits $z = [5, 3, 1]$ :
- $\tau = 1$ :$[0.84, 0.11, 0.04]$ ,第三类几乎消失。
- $\tau = 3$ :$[0.51, 0.31, 0.18]$ ,第三类的信息得以保留。
这意味着学生可以直接从教师 logits 的相对大小中学习,而不受指数函数的扭曲影响。

组合损失#
$$\mathcal{L} \;=\; \alpha \cdot \tau^2 \cdot \mathcal{L}_{\text{KD}} \;+\; (1 - \alpha) \cdot \mathcal{L}_{\text{hard}}.$$- $\alpha \in [0.5, 0.9]$ :表示对教师的信任程度。
- $\tau^2$ :补偿高温下的梯度收缩。软目标项的梯度按 $1/\tau^2$ 缩放,因此将损失乘回 $\tau^2$ ,让两项量级可比。
- $\mathcal{L}_{\text{hard}}$ :与真实标签的标准交叉熵。
经验默认值:$\tau = 4$ ;当教师远强于学生时取 $\alpha = 0.9$ ,当两者容量接近时取 $\alpha = 0.5$ 。
响应式蒸馏:只动输出层#
经典方法很简单——只匹配教师模型的输出层,其他一概不管。
Hinton 的算法#
- 用完整数据集训练教师模型 $T$ 。
- 对每个输入 $x$ ,计算软标签 $\sigma(z^T(x) / \tau)$ 并保存。
- 用组合损失函数训练学生模型 $S$ 。
- 部署学生模型时,设置 $\tau = 1$ 。
ImageNet 上的典型结果:
| 设置 | Top-1 |
|---|---|
| ResNet-34 教师 | 73.3% |
| ResNet-18 从头训 | 69.8% |
| ResNet-18 蒸馏 | 71.4%(+1.6) |
改一下损失函数,白捡 1.6 个点。
蒸馏 vs. 标签平滑#
$$y_c' \;=\; (1 - \epsilon) \, y_c + \epsilon / C.$$区别在于“软”的来源不同。标签平滑对所有样本都用同样的均匀分布来软化。蒸馏则是为每个样本生成一个独特的软分布,这个分布来自教师模型。比如,一张波斯猫的图片会在“老虎”“豹子”上分配权重;一辆轿车的图片会在“卡车”“旅行车”上分配权重。这就是蒸馏效果始终优于标签平滑的原因。
基于特征的蒸馏:让中间层也对齐#
响应式蒸馏只关注顶层的匹配,而基于特征的蒸馏还会对齐中间表示——信号更丰富,学生模型能从更多地方学习教师模型的几何特性。

FitNets:提示学习#
$$\mathcal{L}_{\text{hint}} \;=\; \| W_r \, F_S^l - F_T^l \|_F^2,$$其中 $W_r$ 是一个可学习的 1×1 投影,用来调整学生模型的通道维度以匹配教师模型。Romero 等人采用两阶段训练方法:
- 先训练学生模型的浅层网络和投影模块,使其对齐教师模型选定的“提示层”。
- 冻结浅层网络后,用标准蒸馏方法继续训练其余部分。
注意力传递(Attention Transfer)#
$$A(F) \;=\; \sum_c |F_c|^p, \quad p = 2.$$ $$\mathcal{L}_{\text{AT}} \;=\; \sum_l \left\| \frac{A_S^l}{\|A_S^l\|_2} - \frac{A_T^l}{\|A_T^l\|_2} \right\|_2^2.$$在 CIFAR-10 数据集上,将 ResNet-110 蒸馏到 ResNet-20 的结果如下:
| 方法 | 准确率 |
|---|---|
| ResNet-20 基线 | 91.3% |
| 仅响应蒸馏 | 91.8% |
| 注意力传递 | 92.4% |
Gram 矩阵蒸馏(NST)#
$$G \;=\; F^\top F,$$其中 $G_{ij}$ 表示通道 $i$ 和通道 $j$ 的相关性。这是一种二阶统计量,反映的是“纹理”而非“内容”,逐点匹配无法捕捉这种信息。
关系蒸馏:传递样本间的关系#
不单独处理每个样本,而是匹配样本之间的关系。
RKD:关系知识蒸馏#
两种方式:
$$\mathcal{L}_{\text{dist}} \;=\; \sum_{(i,j)} \ell_\delta\!\left(d_S(i,j),\, d_T(i,j)\right).$$ $$\mathcal{L}_{\text{angle}} \;=\; \sum_{(i,j,k)} \ell_\delta\!\left(\angle_S(i,j,k),\, \angle_T(i,j,k)\right).$$实际效果表明,角度比距离更重要($\lambda_{\text{angle}} = 2$ ,$\lambda_{\text{dist}} = 1$ )。因为角度不受尺度影响,能更好地捕捉相对几何结构。
CRD:对比表示蒸馏#
$$\mathcal{L}_{\text{CRD}} \;=\; -\log \frac{\exp\!\left(f_S(x)^\top f_T(x) / \tau\right)}{\sum_{x'} \exp\!\left(f_S(x)^\top f_T(x') / \tau\right)}.$$目标是最大化学生与教师特征之间的互信息。当学生模型很小时(比如从 ResNet-32 蒸馏到 ResNet-8),CRD 的优势更加明显——在 CIFAR-100 上,性能比仅使用响应蒸馏高出 2% 以上。
自蒸馏:不需要单独的教师模型#
如果没有大模型当教师,照样可以做蒸馏。

重生网络#
一个令人意外的发现:把模型蒸馏到架构完全相同的副本上,准确率居然还能提升。
- 正常训练 $M_1$ 。
- 用 $M_1$ 当教师训练 $M_2$ (架构相同)。
- 用 $M_2$ 当教师训练 $M_3$ 。
- 准确率不再提升时停止。
CIFAR-100 的实验结果:
| 代次 | 准确率 |
|---|---|
| 1(基线) | 74.3% |
| 2(BAN) | 75.2% |
| 3 | 75.4% |
| 4 | 75.5% |
背后有两个互补的解释:软标签提供更平滑的梯度,减少了过拟合的风险;每一代模型探索损失曲面的不同区域,相当于在推理时免费获得了一个隐式集成的效果。
深度互学习#
$$\mathcal{L}_i \;=\; \mathcal{L}_{\text{CE}}^i + \frac{1}{M-1} \sum_{j \neq i} \mathrm{KL}\!\left(P_j \,\|\, P_i\right).$$不需要预训练。不同的随机种子让模型犯不同的错误,互相监督让每个模型吸收其他模型的优点。在 CIFAR-100 上,两个 ResNet-32 联合训练后,各自达到 72.1% 的准确率,而单独训练只有 70.2%。
蒸馏 + 压缩#
蒸馏和其他压缩方法配合得很好,不是互相排斥的关系。
量化感知蒸馏#
把 FP32 转成 INT8,内存能省 4 倍,支持的硬件上还能快 2 到 4 倍,但精度会掉一些。蒸馏能把这个损失补回来一大半:
| ResNet-18 / ImageNet | Top-1 |
|---|---|
| FP32 基线 | 69.8% |
| INT8,无蒸馏 | 68.5%(-1.3) |
| INT8,加蒸馏 | 69.2%(-0.6) |
蒸馏直接把量化的精度损失减了一半。
剪枝感知蒸馏#
剪掉重要性最低的通道后,用教师模型的软标签微调学生模型:
| VGG-16 / CIFAR-10 | 准确率 | 参数量 |
|---|---|---|
| 原始 | 93.5% | 14.7M |
| 剪 70%,无蒸馏 | 92.1% | 4.4M |
| 剪 70%,加蒸馏 | 93.0% | 4.4M |
实际操作流程是这样的:先训练教师模型,然后通过剪枝确定学生模型结构,接着用蒸馏恢复精度,最后再做量化。每一步都可以借助教师模型,最终能做到 10 到 20 倍的压缩,精度损失控制在 1% 以内。

案例:DistilBERT#
Sanh 等人(2019)用三重损失(cosine + MLM + KD)把 BERT-base 蒸馏成 DistilBERT,给出了 NLP 领域蒸馏的经典结果:

| BERT-base | DistilBERT | 变化 | |
|---|---|---|---|
| 参数量 | 110M | 66M | -40% |
| 推理延迟 | 410 ms | 250 ms | -39% |
| GLUE(平均) | 79.5 | 77.0 | -3% |
这三个数字让知识蒸馏在 NLP 工程领域真正站稳了脚跟。
完整实现#
| |
各模块的职责#
| 模块 | 作用 |
|---|---|
KDLoss | 带温度的软 KL 散度,与硬标签交叉熵加权融合。 |
FeatureDistillLoss | FitNets:通过 1×1 卷积投影后计算特征图均方误差。 |
AttentionTransferLoss | 匹配通道聚合后的归一化空间注意力图。 |
RelationalLoss | RKD:样本间的距离和角度关系。 |
DistillationTrainer | 统一支持五种蒸馏模式的训练器。 |
self_distill | Born-Again Networks:相同架构的迭代自蒸馏。 |
温度调度策略#

恒定 $\tau = 4$ 是教科书式的默认选择,确实有效。但温度控制的是学生被要求学习的内容,而这一目标会随着训练进程动态变化。训练初期,学生一无所知——此时最受益于观察教师对所有类别的完整排序;训练后期,学生已内化了特征空间的几何结构,需要做出自信的预测。单一标量无法同时满足这两个阶段的需求。
$$\tau(t) \;=\; \tau_{\text{start}} - (\tau_{\text{start}} - \tau_{\text{end}}) \cdot \frac{t}{T},$$从高值(分布平坦,包含丰富“暗知识”)逐渐降至低值(分布尖锐,提供明确监督)。若在 $T$ 个 epoch 中设置 $\tau_{\text{start}} = 20$ 、$\tau_{\text{end}} = 2$ ,学生将在前半段学习类别间的相对相似性,后半段则聚焦于分类决策的锐化。
从机制上看,为何这能奏效?当 $\tau$ 较高时,softmax 近似于 logits 的线性函数,KL 散度对 student logits 的梯度主要由 teacher logits 之间的差异主导——即关系结构;当 $\tau$ 较低时,softmax 近似 one-hot 分布,KL 散度退化为对 teacher argmax 的标准交叉熵——目标虽硬但无噪声。
| |
CIFAR-10 实验结果(ResNet-18 教师蒸馏至 ResNet-8 学生,共 100 个 epoch):
| 调度策略 | 测试准确率 |
|---|---|
| 恒定 $\tau = 4$ | 88.1% |
| 线性 $20 \to 2$ | 89.4% |
| 线性 $50 \to 4$ | 89.7% |
仅将标量替换为函数,即可免费提升 1.3–1.6 个百分点。调度策略的主要收益来自前半段训练——此时与恒定温度基线的差距迅速拉开;最后 20 个 epoch 使用低 $\tau$ 主要起稳定作用。
需注意一点:若 $\tau_{\text{start}}$ 远超 50,软目标将趋近均匀分布——即使有 $\tau^2$ 缩放也无法补偿,梯度相对于交叉熵项变得极小,导致早期训练停滞或震荡。若需“高温”启动,建议将 $\tau$ 上限设为 30 左右,并适当延长调度周期。
过渡:温度决定了学生如何倾听教师的 logits。下一步则是改变学生倾听的内容——即特征及其间的几何关系。
CRD:对比表示蒸馏(Contrastive Representation Distillation)#
我们此前所见的基于特征的 KD 方法(如 FitNets、注意力迁移)均采用逐点匹配:强制学生特征 $f_S(x)$ 逼近教师特征 $f_T(x)$ 。这是一种强约束,且往往浪费资源——它要求学生复现教师的确切激活值,即便下游分类任务仅依赖特征间的相对关系。
对比表示蒸馏(CRD)放松了这一约束。它不再直接匹配 $f_S(x)$ 与 $f_T(x)$ ,而是要求同一输入的师生特征对彼此更接近,而非与其他输入的师生特征对更接近。具体而言,输入 $x$ 的教师表征作为正样本,同 batch 中其他输入的教师表征则作为负样本。
$$\mathcal{L}_{\text{CRD}} \;=\; -\log \frac{\exp\!\left(s(z_s, z_t) / \tau\right)}{\sum_{j} \exp\!\left(s(z_s, z_j^-) / \tau\right)},$$其中 $s(\cdot, \cdot)$ 为余弦相似度,$z_s = g_S(f_S(x))$ ,$z_t = g_T(f_T(x))$ ,而 $g_S, g_T$ 为小型投影头(一层非线性层足矣)。最大化该损失可下界估计师生表征间的互信息。
| |
CIFAR-100 实验(ResNet-50 教师蒸馏至 ResNet-18):
| 方法 | 测试准确率 |
|---|---|
| 硬标签基线 | 73.3% |
| 原始响应 KD | 75.5% |
| + CRD | 76.7% (+1.2) |
当教师与学生容量差异极大时,CRD 提升最为显著——这正是逐点匹配失效的场景:学生根本无法复现教师的特征。CRD 转而要求更弱的条件(相对相似性),反而获得了更好的效果。
过渡:但并非所有蒸馏训练都能顺利收敛——有时损失停滞不前,甚至发散。
蒸馏失败的情形#
蒸馏是一种稳健技术,但并非万能。以下三种失败模式较为常见,值得明确诊断。
$$D_{\text{KL}}(p_t \,\|\, p_s) \;=\; \sum_c p_t(c) \log \frac{p_t(c)}{p_s(c)}$$在前 5–10 个 epoch 内就 plateau 在高位,不再下降。交叉熵可能继续降低(学生学会了 argmax),但未能吸收“暗知识”。例如,一个 10 万参数的学生试图模仿 1 亿参数的教师,几乎立刻就会遇到此问题。
模态不匹配(Modality mismatch):教师与学生接收不同输入——如 RGB 教师蒸馏给灰度学生、高分辨率教师蒸馏给低分辨率学生,或跨语言场景中分词器不一致。教师的软目标编码了学生无法获取的特征,从学生视角看,$p_t$ 几乎等同于噪声。症状:KD 损失高且波动大,联合损失甚至劣于仅用硬标签训练。
在线蒸馏不稳定(Online distillation instability):在在线或互学习(mutual-learning)设置中,教师与学生同步更新。学生追逐一个移动目标,若教师更新幅度过大,KL 散度轨迹会震荡而非衰减。症状:epoch 间 KL 值上下波动达 20–30%,无明显下降趋势。
以下简单诊断工具可识别上述三种情形:
| |
健康训练通常在约 50 个 epoch 内将初始 KL 减半,并持续缓慢下降。若不符合,诊断结果可指导修复:
- 容量不足:增大一级学生宽度或深度,或加入中间层特征蒸馏,使单位参数获得更丰富的信号。若学生规模固定,可先蒸馏至一个中间教师(TA-KD),再由该中间教师蒸馏至目标学生。
- 模态不匹配:蒸馏前对齐输入。将教师输入空间转换为与学生一致,或训练一个小型适配器,将学生特征映射至教师空间后再蒸馏。
- 不稳定:定期冻结教师(每 $k$ 步更新一次,而非每步),降低学生学习率,或在前几个 epoch 内将 KD 权重 $\alpha$ 从 0 逐步 warmup。
共通规律是:KL 轨迹比测试准确率更敏感。它能提前告诉你学生是否真正在听,而不必等到最终评估才知道是否学到了有用知识。
过渡:掌握这些诊断手段后,蒸馏便从一个寄托希望的损失项,转变为可监控的训练流程。剩下的问题则更偏实践——如何选择超参、压缩极限在哪,以及蒸馏如何与模型压缩工具箱中的其他技术协同工作。
常见问题#
温度怎么选?#
从 $\tau = 4$ 开始。类别越多、类别之间越相似(比如细粒度的物种分类),温度就要越高,ImageNet 规模的问题可以调到 $\tau = 20$ 。在验证集上对 $\{2, 4, 8, 12, 20\}$ 做网格搜索就行。
学生模型能压缩到多小?#
4–10 倍压缩基本不掉点;超过 50 倍的话,蒸馏也救不了——通常会掉 5–10 个点。经验规律是:当学生模型刚好有足够的容量表示教师模型的知识,但没能力直接从标签学出来时,蒸馏的效果最好。
自蒸馏为什么有效?学生和教师容量不是一样吗?
两个原因。第一,软目标比 one-hot 标签正则化更强,尤其在数据量小时更明显。第二,每一代训练都会落在损失曲面的不同区域,多次迭代相当于零成本的隐式集成。CIFAR-100 上一般能提升 1–2%,3 代之后收益递减。
可以用多个教师吗?#
可以。把它们的软输出平均一下(均匀加权或者按验证准确率加权)。这样通常能提升鲁棒性,而不是单纯的头部准确率,代价是需要训练多个教师模型。
先剪枝还是先蒸馏?#
两步都做,顺序是这样的:先训练教师模型 → 剪枝确定学生模型结构 → 蒸馏恢复精度 → 量化。每一步都有教师模型作为参考,比单独操作更能保留知识。
总结#
知识蒸馏就是教小模型学会大模型的思维方式:
- 软标签保留了 one-hot 标签丢失的类间结构,这部分信息就是所谓的暗知识。
- 温度是一个标量,用来调节学生模型能学到多少暗知识。
- 特征蒸馏和关系蒸馏不仅关注 logits,还对齐中间表示和样本间的几何关系。
- 自蒸馏不需要额外的教师模型,同时还能免费获得一个集成模型。
- 配合剪枝和量化,蒸馏可以实现 10 到 20 倍的压缩率,精度损失控制在个位数。
下一篇:第六部分 —— 多任务学习 ,多个任务共享参数,提升泛化能力和效率。
参考文献#
- Hinton, Vinyals, Dean (2015). Distilling the Knowledge in a Neural Network. arXiv:1503.02531
- Romero et al. (2015). FitNets: Hints for Thin Deep Nets. ICLR. arXiv:1412.6550
- Zagoruyko & Komodakis (2017). Paying More Attention to Attention. ICLR. arXiv:1612.03928
- Park et al. (2019). Relational Knowledge Distillation. CVPR. arXiv:1904.05068
- Tian et al. (2020). Contrastive Representation Distillation. ICLR. arXiv:1910.10699
- Furlanello et al. (2018). Born-Again Neural Networks. ICML. arXiv:1805.04770
- Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv:1910.01108
- Zhang et al. (2018). Deep Mutual Learning. CVPR. arXiv:1706.00384