蒸馏的定义
将预先训练好的复杂模型的 “知识” 转移到较为轻量的模型上的操作称为蒸馏。(这里的 “知识” 在以前被定义为网络的参数,而在文中作者将其定义为从输入到输出的映射。
Hard targets&Soft targets
Hard targets 是指在分类问题中预测的最终目标通常是一个 one-hot 形式的向量,而 Soft targets 是指保留预测结果中属于每一类别的概率(集成模型中可以是几何或者算数平均)
为什么 Soft targets 有效?
模型给出的当前输入属于其他错误类别的概率,即使很小,也隐藏了一定的信息。以 MNIST 分类为例,数字 “2” 的写法不同会导致其像数字 3(后面多出一点)或者数字 7(少画了一点),一个训练好的模型,即使最终的预测结果是 2,其在 3 或者 7 的类别中也会有一定的概率,这个概率表示模型认为这两者有一定的相似性。直接使用 one-hot 形式会丢失掉这部分的信息。
带温度的 Softmax
—————————–
在新的公式中,参数 T 就是温度值,当 T 设置为 1 时,公式与 Softmax 一致。T 值升高会避免原先的 sofrmax 对最大值的凸显,这样就避免了 teacher model 在正确类别上给出的置信度过高。
Student 的损失函数
$$
\huge L =\lambdaL_{soft}+(1-\lambda)L_{hard}
$$
这里如果 $\lambda$ 设置为 1,可以理解为这就是传统的模型压缩,用一个小模型去学习了大模型的预测结果。而后面的一项则是 Student 在此基础上做进一步的自我学习,这样 Student 的能力是完全有可能超过 Teacher 的。
为什么蒸馏会有效
虽然 Teacher 模型更加的复杂,但是其网络中有很多参数是冗余的,也就是说其学得的映射关系并不是看上去这么复杂,因此我们可以用更小的 Student 模型来更加精炼的学习这种映射关系。
算法流程
- 在训练集上训练好一个大模型 Teacher model
- 使用 Teacher model 生成训练集的 soft target
- 利用 Teacher model 的 soft target 以及 ground truth 的 hard target 训练 student model
- 保留 student model 用作线上预测
实验
Preliminary experiments on MNIST
net | layers | hidden units | activiation | regularization | error num(test) |
---|---|---|---|---|---|
net1 | 2 | 1600 | relu | dropout | 67 |
net2 | 2 | 800 | relu | no | 146 |
以 net1 为 teacher,net2 为 Student,测试记过如下:
net | teacher | student | error num(test) |
---|---|---|---|
distilled net | net1 | net2 | 74 |
Experiments on speech recognition
system | Test Frame Accuracy | Word Error Rate on dev set |
---|---|---|
baseline | 58.9% | 10.9% |
10XEnsemble | 61.1% | 10.7% |
Distilled model | 60.8% | 10.7% |
其中 basline 的配置为
- 8 层,每层 2560 个 relu 单元
- softmax 层的单元数为 14000
- 训练样本大小约为 700M,2000 个小时的语音文本数据
蒸馏模型的配置为
- 使用的候选温度为 {1, 2, 5, 10}, 其中 T 为 2 时表现最好
- hard target 的目标函数给予 0.5 的相对权重