这是蒸馏论文的第四篇,将注意力机制引入知识蒸馏中,总体来说还算比较好懂。作者提出了基于 Activation 和基于 Gradient 两种注意力迁移的方法。
ACTIVATION-BASED ATTENTION TRANSFER
正如图中所示,作者将 HxWxC 的特征图通过计算后得到 attention map,这中计算有一个潜在假设,就是隐藏层神经元的绝对值可以用作表示这个神经元的重要性。这样,我们就可以计算通道维度的统计量,对此作者提出了三种计算方法:绝对值求和、绝对值指数求和以及绝对值指数求最大值
通过计算得出 attention map 后,我们很容易就能理解以下 Student 网络的损失函数
这里前面的 $L (W_s,x)$ 表示标准的交叉熵,即 hard target。后面的 $Q_S^j$ 和 $Q_T^j$ 分别表示 Student 和 Teacher 的 attention map,将他们做归一化并取 p 范数就是 soft target (我的表述不够准确,这一部分就是 Student 在拟合 Teacher 所学到的知识)
从上面的公式我们知道,Student 和 Teacher 的 attention map 必须是等大的,作者给出了 Resnet 网络下 Attention loss 所在的位置
GRADIENT-BASED ATTENTION TRANSFER
这一部分看的还是有些迷迷糊糊,以下讲一下我自己的理解,大家可以做个参考
作者认为网络输出对于输入的梯度,代表了网络对于每一个像素的敏感度。首先定义出两者对于输入的梯度:
那么 Student 学习的目标就是让自己的 $J_S$ 与 $J_T$ 尽可能的相似,所以定义 Student 的损失函数如下:
这个公式中 $W_T$ 和 $x$ 都是固定的,我们计算其对 $W_S$ 的梯度:
实际上在训练过程中,Teacher 网络的 $J_T$ 是可以事先计算好的,而对于 Student 网络,我们需要先对其进行前向传播,根据交叉熵损失进行反向传播计算出 $J_S$,然后根据 $L_{AT}$ 的计算公式进行计算并再次进行反向传播得到梯度,这个梯度就是 Student 本轮需要优化的方向。
此外,为了强制梯度 Attention 的水平翻转不变性,作者把原图和水平翻转的图像都进行传播,最后的损失函数如下(这个没太懂是翻转后的图像再一次计算,还和合并了两次操作):
完结撒花~