Distilling the Knowledge in a Neural Network

蒸馏的定义

将预先训练好的复杂模型的 “知识” 转移到较为轻量的模型上的操作称为蒸馏。(这里的 “知识” 在以前被定义为网络的参数,而在文中作者将其定义为从输入到输出的映射。

Hard targets&Soft targets

Hard targets 是指在分类问题中预测的最终目标通常是一个 one-hot 形式的向量,而 Soft targets 是指保留预测结果中属于每一类别的概率(集成模型中可以是几何或者算数平均)

为什么 Soft targets 有效?

模型给出的当前输入属于其他错误类别的概率,即使很小,也隐藏了一定的信息。以 MNIST 分类为例,数字 “2” 的写法不同会导致其像数字 3(后面多出一点)或者数字 7(少画了一点),一个训练好的模型,即使最终的预测结果是 2,其在 3 或者 7 的类别中也会有一定的概率,这个概率表示模型认为这两者有一定的相似性。直接使用 one-hot 形式会丢失掉这部分的信息。

带温度的 Softmax

Snipaste_2020-07-07_16-35-37.png—————————–Snipaste_2020-07-07_16-36-18.png

在新的公式中,参数 T 就是温度值,当 T 设置为 1 时,公式与 Softmax 一致。T 值升高会避免原先的 sofrmax 对最大值的凸显,这样就避免了 teacher model 在正确类别上给出的置信度过高。

Student 的损失函数

$$
\huge L =\lambdaL_{soft}+(1-\lambda)L_{hard}
$$

这里如果 $\lambda$ 设置为 1,可以理解为这就是传统的模型压缩,用一个小模型去学习了大模型的预测结果。而后面的一项则是 Student 在此基础上做进一步的自我学习,这样 Student 的能力是完全有可能超过 Teacher 的。

为什么蒸馏会有效

虽然 Teacher 模型更加的复杂,但是其网络中有很多参数是冗余的,也就是说其学得的映射关系并不是看上去这么复杂,因此我们可以用更小的 Student 模型来更加精炼的学习这种映射关系。

算法流程

image.png

  1. 在训练集上训练好一个大模型 Teacher model
  2. 使用 Teacher model 生成训练集的 soft target
  3. 利用 Teacher model 的 soft target 以及 ground truth 的 hard target 训练 student model
  4. 保留 student model 用作线上预测

实验

Preliminary experiments on MNIST

net layers hidden units activiation regularization error num(test)
net1 2 1600 relu dropout 67
net2 2 800 relu no 146

以 net1 为 teacher,net2 为 Student,测试记过如下:

net teacher student error num(test)
distilled net net1 net2 74

Experiments on speech recognition

system Test Frame Accuracy Word Error Rate on dev set
baseline 58.9% 10.9%
10XEnsemble 61.1% 10.7%
Distilled model 60.8% 10.7%

其中 basline 的配置为

  • 8 层,每层 2560 个 relu 单元
  • softmax 层的单元数为 14000
  • 训练样本大小约为 700M,2000 个小时的语音文本数据

蒸馏模型的配置为

  • 使用的候选温度为 {1, 2, 5, 10}, 其中 T 为 2 时表现最好
  • hard target 的目标函数给予 0.5 的相对权重
丨fengche丨 wechat
来找我聊天吧~
-------------本文结束感谢您的阅读-------------