Series · Transfer Learning · Chapter 5

迁移学习(五):知识蒸馏

把大模型的能力压进小模型而几乎不掉点:暗知识、温度缩放、响应/特征/关系蒸馏、自蒸馏与多策略实现的完整指南。

你训练了一个 340M 参数的 BERT,准确率 95%。产品需求是把它塞进一台手机,而手机最多能跑 10M 参数。你从头训一个 10M 的小模型,只能到 85%。这时候,知识蒸馏几乎能补上全部差距:让小模型学习大模型的输出分布,而不只是学习硬标签,最终能跑到 92%。

Hinton 提出的核心观察是:教师模型那些"错误"的预测并不是噪声,而是信息。当教师把一张猫的图片打成 0.62 的"猫"、0.14 的"老虎"、0.07 的"狗"、0.008 的"飞机",它其实在告诉你:猫和老虎长得很像,和狗有点像,和飞机毫无关系。这种类别之间的相似性结构——也就是暗知识(dark knowledge)——在 one-hot 标签里完全看不到,而学生模型一旦学到,就能用很小的容量做出远超预期的事。

你将学到什么

  • 为什么软标签携带的信息严格多于硬标签。
  • 温度缩放:一个标量,控制教师暴露多少暗知识。
  • 蒸馏的三大流派:响应式特征式关系式
  • 自蒸馏与互学习:完全不依赖预训练教师也能涨点。
  • 蒸馏与量化、剪枝叠加:把压缩比推到 10 倍以上。
  • 一份干净的 PyTorch 实现,覆盖五种蒸馏模式。

前置: 本系列第 1-2 篇,PyTorch 基础。


蒸馏为什么有效

教师 / 学生在组合软硬目标损失下联合训练

部署的现实约束

大模型在榜单上漂亮,但在手机、车端和云账单上很难看。四个约束把我们往小模型推:

  • 内存: 移动端 / IoT 设备根本装不下几十亿参数。
  • 延迟: 自动驾驶要的是毫秒级响应,不是秒级。
  • 成本: 一个每天被调用十亿次的模型,每个 FLOP 都是真金白银。
  • 能耗: 边缘设备靠电池,不是发电厂。

剪枝和量化直接动模型本体,掉点几乎是必然。蒸馏走的是另一条路:让小学生模仿大教师的输出分布,而不只是模仿它的 argmax。学生继承了教师的归纳偏置,但不必继承它的参数。

暗知识:软标签到底教了什么

把一张猫的图片喂给教师,看一眼输出:

类别硬标签教师输出
1.00.62
老虎0.00.14
豹子0.00.10
0.00.07
汽车0.00.012

硬标签只说"是猫",熵为零。软标签说"是猫,也有点像老虎、有点像豹子、隐隐约约像狗、绝对不是汽车"——熵大于零。这份排序是教师在数百万次梯度更新里学到的、关于类别相似度的免费课程。

one-hot 与教师 softmax:同一个预测,监督信号截然不同

不是匹配标签,是匹配分布

标准训练最小化对硬标签的交叉熵:

$$ \mathcal{L}_{\text{hard}} \;=\; -\sum_c y_c \log \sigma(z_c^S). $$

蒸馏把 $y_c$ 换成教师的 softmax:

$$ \mathcal{L}_{\text{KD}} \;=\; -\sum_c \sigma(z_c^T / \tau) \, \log \sigma(z_c^S / \tau). $$

由于教师是固定的,这个目标等价于最小化 $\mathrm{KL}\!\left(\sigma(z^T/\tau) \,\|\, \sigma(z^S/\tau)\right)$。学生学习的不再是一个标签,而是一整条概率分布。

温度:暗知识的旋钮

直接拿原始 softmax 输出会有个问题:分布太尖,最高类一上来就是 0.99,其他类被埋进了噪声里,暗知识全没了。温度参数 $\tau$ 用来把分布拉平:

$$ \sigma(z_i; \tau) \;=\; \frac{\exp(z_i / \tau)}{\sum_j \exp(z_j / \tau)}. $$
温度效果
$\tau \to 0$退化为 one-hot(argmax)
$\tau = 1$标准 softmax
$\tau = 4$ – 10类间相似度被显式化
$\tau \to \infty$趋近均匀分布

举个具体例子,logits $z = [5, 3, 1]$:

  • $\tau = 1$:$[0.84, 0.11, 0.04]$,第三类几乎消失。
  • $\tau = 3$:$[0.51, 0.31, 0.18]$,第三类的信息被保留。

在高温下,softmax 大致是 logits 的线性函数:

$$ \sigma(z_i; \tau) \;\approx\; \frac{1}{C} + \frac{z_i - \bar z}{C \tau}, $$

也就是说,学生可以直接从教师 logits 的相对大小学知识,不用被指数函数扭曲。

相同 logits、三种温度:从尖锐到平滑

组合损失

实战里通常要两条腿走路:既向教师蒸馏,也锚定真实标签。

$$ \mathcal{L} \;=\; \alpha \cdot \tau^2 \cdot \mathcal{L}_{\text{KD}} \;+\; (1 - \alpha) \cdot \mathcal{L}_{\text{hard}}. $$
  • $\alpha \in [0.5, 0.9]$:你有多信教师。
  • $\tau^2$:补偿高温下的梯度收缩。软目标项的梯度按 $1/\tau^2$ 缩放,因此把损失乘回 $\tau^2$,让两个分量量级可比。
  • $\mathcal{L}_{\text{hard}}$:与真实标签的标准交叉熵。

经验默认值:$\tau = 4$;当教师远强于学生时取 $\alpha = 0.9$,当两者容量接近时取 $\alpha = 0.5$。


响应式蒸馏:只动输出层

经典做法——只匹配教师的输出层,其它什么都不管。

Hinton 的算法

  1. 在完整数据集上训练教师 $T$。
  2. 对每个样本 $x$,算出软标签 $\sigma(z^T(x) / \tau)$。
  3. 用上面的组合损失训练学生 $S$。
  4. 部署学生,推理时用 $\tau = 1$。

ImageNet 上的代表性数字:

设置Top-1
ResNet-34 教师73.3%
ResNet-18 从头训69.8%
ResNet-18 蒸馏71.4%(+1.6)

只是改了个损失函数,免费拿 1.6 个点。

蒸馏 vs. 标签平滑

标签平滑也在"软化"目标:

$$ y_c' \;=\; (1 - \epsilon) \, y_c + \epsilon / C. $$

差别在于软在哪里。标签平滑对所有样本用同一个均匀软化分布。蒸馏对每个样本用各自的、来自教师的软分布——一张波斯猫的图会在"老虎"“豹子"上有权重;一辆轿车的图会在"卡车"“旅行车"上有权重。这正是蒸馏稳定优于标签平滑的原因。


特征蒸馏:让中间层也对齐

响应式蒸馏只匹配最顶层。特征蒸馏顺手把中间表示也匹配上——监督更密集,学生能从更多位置吸收教师的几何结构。

响应蒸馏只对齐 logits;特征蒸馏还对齐中间特征图

FitNets:提示学习

让学生的中间特征匹配教师的中间特征:

$$ \mathcal{L}_{\text{hint}} \;=\; \| W_r \, F_S^l - F_T^l \|_F^2, $$

其中 $W_r$ 是一个可��习的 1x1 投影,把学生的通道维度对齐到教师。Romero 等人用两阶段训练:

  1. 先训学生的浅层 + 投影,使其匹配教师指定的"提示层”;
  2. 冻结浅层,再用标准蒸馏训上层。

注意力传递(Attention Transfer)

不再匹配特征值本身,而是匹配教师关注哪里。沿通道维度聚合,得到空间注意力图:

$$ A(F) \;=\; \sum_c |F_c|^p, \quad p = 2. $$

然后最小化

$$ \mathcal{L}_{\text{AT}} \;=\; \sum_l \left\| \frac{A_S^l}{\|A_S^l\|_2} - \frac{A_T^l}{\|A_T^l\|_2} \right\|_2^2. $$

CIFAR-10 上把 ResNet-110 蒸馏到 ResNet-20:

方法准确率
ResNet-20 基线91.3%
仅响应蒸馏91.8%
注意力传递92.4%

Gram 矩阵蒸馏(NST)

借鉴神经风格迁移,匹配 Gram 矩阵

$$ G \;=\; F^\top F, $$

其中 $G_{ij}$ 表示通道 $i$ 与通道 $j$ 的相关性。这是二阶统计量——相当于在匹配"风格 / 纹理”,而不是"内容"——逐点匹配捕捉不到的部分。


关系蒸馏:把样本之间的关系也传过去

不再单独看每个样本,而是匹配样本之间的关系

RKD:关系知识蒸馏

两种关系:

距离关系。 对样本对 $(x_i, x_j)$,保持嵌入空间里的成对距离:

$$ \mathcal{L}_{\text{dist}} \;=\; \sum_{(i,j)} \ell_\delta\!\left(d_S(i,j),\, d_T(i,j)\right). $$

角度关系。 对三元组 $(x_i, x_j, x_k)$,保持 $x_j$ 处的夹角:

$$ \mathcal{L}_{\text{angle}} \;=\; \sum_{(i,j,k)} \ell_\delta\!\left(\angle_S(i,j,k),\, \angle_T(i,j,k)\right). $$

经验上,角度比距离更重要($\lambda_{\text{angle}} = 2$,$\lambda_{\text{dist}} = 1$),因为角度对尺度不敏感,描述的是相对几何结构。

CRD:对比表示蒸馏

把教师的表示当作"正样本"视图,其他样本当作负样本:

$$ \mathcal{L}_{\text{CRD}} \;=\; -\log \frac{\exp\!\left(f_S(x)^\top f_T(x) / \tau\right)}{\sum_{x'} \exp\!\left(f_S(x)^\top f_T(x') / \tau\right)}. $$

本质是最大化学生与教师特征之间的互信息。学生越小(比如把 ResNet-32 蒸到 ResNet-8),CRD 相对响应蒸馏的优势越明显——CIFAR-100 上能多 2 个点以上。


自蒸馏:连教师都不需要

要是手头根本没有大教师呢?还能蒸。

Born-Again Networks:每一代都用上一代相同架构的模型作教师

Born-Again Networks

一个反直觉的发现:把模型蒸馏到架构完全相同的副本上,居然也能涨点。

  1. 正常训练 $M_1$。
  2. 用 $M_1$ 当教师训 $M_2$(同架构)。
  3. 用 $M_2$ 当教师训 $M_3$。
  4. 准确率饱和后停止。

CIFAR-100:

代次准确率
1(基线)74.3%
2(BAN)75.2%
375.4%
475.5%

两条互补的解释:(1)软目标比 one-hot 提供更平滑的梯度,正则效应明显,特别是在数据量不大时;(2)每一代落入损失曲面的不同盆地,多代迭代等价于一次零推理代价的隐式集成。

Deep Mutual Learning:互学习

让 $M$ 个学生同时训练,每个都把别人当老师:

$$ \mathcal{L}_i \;=\; \mathcal{L}_{\text{CE}}^i + \frac{1}{M-1} \sum_{j \neq i} \mathrm{KL}\!\left(P_j \,\|\, P_i\right). $$

不需要预训练。不同的初始化会犯不同的错;相互监督让每个模型吸收彼此的强项。CIFAR-100 上,两个 ResNet-32 联训各达 72.1%,单训只有 70.2%。


蒸馏 + 压缩

蒸馏和其他压缩手段是好搭档,不是对手。

量化感知蒸馏

把 FP32 量化到 INT8,能省 4 倍内存、获得 2-4 倍硬件加速,但代价是掉点。蒸馏能把这个代价砍一半:

ResNet-18 / ImageNetTop-1
FP32 基线69.8%
INT8,无蒸馏68.5%(-1.3)
INT8,加蒸馏69.2%(-0.6)

剪枝感知蒸馏

按通道重要性剪掉最弱的若干通道,再用教师软标签微调:

VGG-16 / CIFAR-10准确率参数量
原始93.5%14.7M
剪 70%,无蒸馏92.1%4.4M
剪 70%,加蒸馏93.0%4.4M

实操管线:训教师 → 剪枝定义学生结构 → 蒸馏���复精度 → 量化。每一步都还能依赖教师,因此累积压缩比可以推到 10-20 倍而精度损失控制在 1% 以内。

参数量轴上,蒸馏曲线严格压制从头训练曲线

案例:DistilBERT

Sanh 等人(2019)用三重损失(cosine + MLM + KD)把 BERT-base 蒸馏到 DistilBERT,定下了 NLP 蒸馏的工业级基准:

DistilBERT:参数 -40%、推理 -60%、保留 97% 的 GLUE

BERT-baseDistilBERT变化
参数量110M66M-40%
推理延迟410 ms250 ms-39%
GLUE(平均)79.577.0-3%

这三个数字让"知识蒸馏"在 NLP 工程界第一次真正普及。


完整实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from typing import List
import copy


# ===== 损失函数 =====

class KDLoss(nn.Module):
    """响应蒸馏:带温度的 KL 散度 + 硬标签 CE。"""
    def __init__(self, temperature: float = 4.0, alpha: float = 0.9):
        super().__init__()
        self.T = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        kd = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
            reduction='batchmean',
        ) * (self.T ** 2)
        ce = F.cross_entropy(student_logits, labels)
        return self.alpha * kd + (1 - self.alpha) * ce


class FeatureDistillLoss(nn.Module):
    """FitNets:1x1 投影后做特征图 MSE。"""
    def __init__(self, student_ch: int, teacher_ch: int):
        super().__init__()
        self.proj = nn.Conv2d(student_ch, teacher_ch, 1, bias=False)

    def forward(self, student_feat, teacher_feat):
        return F.mse_loss(self.proj(student_feat), teacher_feat)


class AttentionTransferLoss(nn.Module):
    """对通道做 L^p 聚合并归一化的空间注意力图匹配。"""
    def __init__(self, p: float = 2.0):
        super().__init__()
        self.p = p

    def _attn(self, feat):
        a = torch.sum(torch.abs(feat) ** self.p, dim=1, keepdim=True)
        return a / (a.sum(dim=[2, 3], keepdim=True) + 1e-8)

    def forward(self, s_feat, t_feat):
        return F.mse_loss(self._attn(s_feat), self._attn(t_feat))


class RelationalLoss(nn.Module):
    """RKD:成对距离 + 余弦(角度)关系。"""
    def __init__(self, w_dist: float = 1.0, w_angle: float = 2.0):
        super().__init__()
        self.w_dist, self.w_angle = w_dist, w_angle

    def forward(self, s_feat, t_feat):
        s = F.normalize(s_feat, p=2, dim=1)
        t = F.normalize(t_feat, p=2, dim=1)
        d_loss = F.mse_loss(torch.cdist(s, s), torch.cdist(t, t))
        a_loss = F.mse_loss(s @ s.t(), t @ t.t())
        return self.w_dist * d_loss + self.w_angle * a_loss


# ===== 模型(同时返回 logits 与中间特征)=====

class _ResNetBackbone(nn.Module):
    def __init__(self, ctor, num_classes: int):
        super().__init__()
        self.net = ctor(weights=None)
        self.net.fc = nn.Linear(self.net.fc.in_features, num_classes)

    def forward(self, x):
        n = self.net
        x = n.maxpool(n.relu(n.bn1(n.conv1(x))))
        feats: List[torch.Tensor] = []
        for layer in (n.layer1, n.layer2, n.layer3, n.layer4):
            x = layer(x)
            feats.append(x)
        logits = n.fc(torch.flatten(n.avgpool(x), 1))
        return logits, feats


class TeacherNet(_ResNetBackbone):
    def __init__(self, num_classes: int = 10):
        super().__init__(torchvision.models.resnet34, num_classes)


class StudentNet(_ResNetBackbone):
    def __init__(self, num_classes: int = 10):
        super().__init__(torchvision.models.resnet18, num_classes)


# ===== 训练器(response / feature / attention / relation / combined)=====

class DistillationTrainer:
    def __init__(self, teacher, student, device='cpu',
                 mode: str = 'response',
                 temperature: float = 4.0, alpha: float = 0.9):
        self.teacher = teacher.to(device).eval()
        self.student = student.to(device)
        self.device = device
        self.mode = mode
        for p in self.teacher.parameters():
            p.requires_grad = False

        self.kd = KDLoss(temperature, alpha)
        if mode in ('feature', 'combined'):
            ch = [64, 128, 256, 512]
            self.feat = nn.ModuleList(
                [FeatureDistillLoss(s, t).to(device)
                 for s, t in zip(ch, ch)])
        if mode in ('attention', 'combined'):
            self.attn = AttentionTransferLoss()
        if mode in ('relation', 'combined'):
            self.rel = RelationalLoss()

    def _loss(self, s_logits, s_feats, t_logits, t_feats, y):
        loss = self.kd(s_logits, t_logits, y)
        if self.mode in ('feature', 'combined'):
            fl = sum(fn(sf, tf)
                     for fn, sf, tf in zip(self.feat, s_feats, t_feats))
            loss = loss + 0.5 * fl / len(s_feats)
        if self.mode in ('attention', 'combined'):
            al = sum(self.attn(sf, tf)
                     for sf, tf in zip(s_feats, t_feats))
            loss = loss + 0.3 * al / len(s_feats)
        if self.mode in ('relation', 'combined'):
            loss = loss + 0.2 * self.rel(
                s_feats[-1].flatten(1), t_feats[-1].flatten(1))
        return loss

    def train_epoch(self, loader, optimizer):
        self.student.train()
        total, correct, n = 0.0, 0, 0
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            with torch.no_grad():
                t_logits, t_feats = self.teacher(x)
            s_logits, s_feats = self.student(x)
            loss = self._loss(s_logits, s_feats, t_logits, t_feats, y)
            optimizer.zero_grad(); loss.backward(); optimizer.step()
            total += loss.item() * y.size(0)
            correct += (s_logits.argmax(1) == y).sum().item()
            n += y.size(0)
        return total / n, 100.0 * correct / n

    @torch.no_grad()
    def evaluate(self, loader):
        self.student.eval()
        correct, n = 0, 0
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            logits, _ = self.student(x)
            correct += (logits.argmax(1) == y).sum().item()
            n += y.size(0)
        return 100.0 * correct / n


# ===== 自蒸馏:Born-Again Networks =====

def self_distill(model_class, train_loader, test_loader,
                 num_generations: int = 3, epochs_per_gen: int = 10,
                 device: str = 'cpu', temperature: float = 4.0):
    teacher = None
    for gen in range(num_generations):
        student = model_class().to(device)
        opt = optim.SGD(student.parameters(), lr=0.1,
                        momentum=0.9, weight_decay=5e-4)
        sched = optim.lr_scheduler.CosineAnnealingLR(opt, epochs_per_gen)

        for _ in range(epochs_per_gen):
            student.train()
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                loss = F.cross_entropy(logits, y)
                if teacher is not None:
                    with torch.no_grad():
                        t_logits, _ = teacher(x)
                    kd = F.kl_div(
                        F.log_softmax(logits / temperature, dim=1),
                        F.softmax(t_logits / temperature, dim=1),
                        reduction='batchmean') * temperature ** 2
                    loss = 0.1 * loss + 0.9 * kd
                opt.zero_grad(); loss.backward(); opt.step()
            sched.step()

        student.eval()
        correct, n = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                correct += (student(x)[0].argmax(1) == y).sum().item()
                n += y.size(0)
        print(f"Generation {gen + 1}: {100.0 * correct / n:.2f}%")

        teacher = copy.deepcopy(student).eval()
        for p in teacher.parameters():
            p.requires_grad = False


# ===== 主流程 =====

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                (0.2023, 0.1994, 0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), norm])
    test_tf = transforms.Compose([transforms.ToTensor(), norm])

    train_set = torchvision.datasets.CIFAR10(
        './data', train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=test_tf)
    train_loader = DataLoader(train_set, 128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, 128, num_workers=2)

    teacher = TeacherNet(10)
    student = StudentNet(10)
    trainer = DistillationTrainer(
        teacher, student, device, mode='combined',
        temperature=4.0, alpha=0.7)
    opt = optim.SGD(student.parameters(), lr=0.1,
                    momentum=0.9, weight_decay=5e-4)
    for epoch in range(20):
        loss, acc = trainer.train_epoch(train_loader, opt)
        test_acc = trainer.evaluate(test_loader)
        print(f"Epoch {epoch + 1}: loss={loss:.4f} "
              f"train={acc:.1f}% test={test_acc:.1f}%")

    print("\n自蒸馏(Born-Again Networks):")
    self_distill(StudentNet, train_loader, test_loader,
                 num_generations=3, epochs_per_gen=10, device=device)


if __name__ == '__main__':
    main()

各模块的职责

模块作用
KDLoss带温度的软 KL,与硬标签 CE 加权融合。
FeatureDistillLossFitNets:1x1 投影后做特征图 MSE。
AttentionTransferLoss通道聚合 + 归一化的空间注意力匹配。
RelationalLossRKD:样本间距离与角度关系。
DistillationTrainer一个训练器统一支持五种模式。
self_distillBorn-Again Networks:相同架构的迭代自蒸馏。

常见问题

温度怎么选?

默认 $\tau = 4$。类别数越多、类别越相似(比如细粒度物种分类),温度就要越高,ImageNet 量级问题里可以推到 $\tau = 20$。在验证集上对 $\{2, 4, 8, 12, 20\}$ 做网格搜索就够。

学生最多能压多小?

4-10 倍压缩几乎不掉点;超过 50 倍蒸馏也救不回来——大概率掉 5-10 个点。一个经验法则:当学生有足够容量去表示教师学到的东西、但没有足够容量自己从标签学到时,蒸馏的收益最大。

自蒸馏为什么有效?学生跟教师容量一样啊。 两个原因。一是软目标比 one-hot 是更强的正则,尤其在数据不大时;二是每一代落入损失曲面的不同盆地,多代叠加相当于免推理代价的隐式集成。CIFAR-100 上典型涨点 1-2%,3 代之后基本饱和。

多个教师可以用吗?

可以——把它们的软输出平均(均匀或按各自验证准确率加权)。这通常买到的是鲁棒性,而不是头部准确率,代价是要训练多个教师。

先剪枝还是先蒸馏?

两者都做,按这个顺序:训教师 → 剪枝定义学生结构 → 蒸馏恢复精度 → 量化。每一步后面都还有教师可以靠,比独立做每一步保留更多知识。


小结

知识蒸馏是教小模型像大模型那样思考的艺术:

  • 软标签包含 one-hot 标签丢掉的类间结构,这就是暗知识。
  • 温度是一个标量,控制学生能看到多少暗知识。
  • 特征关系蒸馏越过 logits,对齐中间表示和成对几何。
  • 自蒸馏不需要外部教师,附赠一次免费集成。
  • 叠加剪枝、量化之后,蒸馏能把压缩比推到 10-20 倍而精度损失控制在个位数。

下一篇:第六部分 —— 多任务学习 ,多个任务共享参数,提升泛化和效率。


参考文献

  1. Hinton, Vinyals, Dean (2015). Distilling the Knowledge in a Neural Network. arXiv:1503.02531
  2. Romero et al. (2015). FitNets: Hints for Thin Deep Nets. ICLR. arXiv:1412.6550
  3. Zagoruyko & Komodakis (2017). Paying More Attention to Attention. ICLR. arXiv:1612.03928
  4. Park et al. (2019). Relational Knowledge Distillation. CVPR. arXiv:1904.05068
  5. Tian et al. (2020). Contrastive Representation Distillation. ICLR. arXiv:1910.10699
  6. Furlanello et al. (2018). Born-Again Neural Networks. ICML. arXiv:1805.04770
  7. Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv:1910.01108
  8. Zhang et al. (2018). Deep Mutual Learning. CVPR. arXiv:1706.00384

系列导航

部分主题
1基础与核心概念
2预训练与微调
3域适应
4小样本学习
5知识蒸馏(本文)
6多任务学习
7零样本学习
8多模态迁移
9参数高效微调
10持续学习
11跨语言迁移
12工业应用与最佳实践

Liked this piece?

Follow on GitHub for the next one — usually one a week.

GitHub