这一次咱们要看的论文是一篇华为的 AAAI2019 的 GAN 和蒸馏相结合的论文。知识蒸馏的基础概念大家可以看这个介绍:知识蒸馏是什么?一份入门随笔。这里,我们先简单的回顾一下知识蒸馏。
什么是知识蒸馏
我们知道在深度学习的大部分网络中,有很多神经元是冗余的,但是在很多移动端,比如手机上,是跑不动这么大的网络的。所以知识蒸馏的一开始的目标是做模型压缩,它的目标就是让一个更小的网络去拟合甚至是超越教师网络的性能。在通常情况下,学生网络的在蒸馏阶段的目标可以用这样的一个函数来表示,这里的损失函数 L 根据算法对知识的定义不同也会有不同的函数表示。
比如 Hinton 在最初的知识蒸馏文章中将教师网络的知识定义为网络输出的每个类别的概率,也就是软标签。所以再这个算法中,损失函数就是教师网络和学生网络输出软标签的 KL 散度。这里 f~T~ 就是教师的输出,f~S~ 是学生的输出,下面的 T 是温度参数,我们通过让学生网络去最小化这个函数,来学习教师网络的知识,从而达到蒸馏的效果。
还有之前提到过的 Hint,通过拟合教师网络和学生网络中间层的输出来达到蒸馏的效果。这里的损失函数 L 就是教师网络和学生网络中间层输出的 L2 距离,这里的 β 是为了让输出大小相同做的一个变换。我们会发现,通过改变损失函数 L,可以得到多种多样的蒸馏方法,但是这些蒸馏方法都有个特点,就是需要教师网络和学生网络在训练的时候有着同样的输入。
Why GAN?
但是在实际应用环境下,很有可能教师端训练时用的是一些隐私性的数据,比如医疗时所用的一些病人的数据,学生网络没有办法去获取这些数据,这样传统的一些蒸馏方法就不能使用了。为了解决这个问题,我们很自然的会想到,是不是可以生成一个假的数据集,然后用这个假数据集作为教师和学生公用的数据集,并在上面做蒸馏操作。在生成假数据方面,近几年最流行的就是生成式对抗网络 GAN 了。这就是为什么我们想把 GAN 和蒸馏去做一个结合。
生成式对抗网络
这里简单的介绍一下生成式对抗网络,GAN 是由两个网络组成的,一个是左边的生成器网络,另一个是右边的判别器网络。生成器接收一个随机向量,并生成一个假样本。以 MNIST 数据集为例,这里的真实数据就是 MNIST 数据集中的手写数字图像,而生成器就需要去生成类似于手写数字的图片。判别器就是一个二分类器,他需要尽可能的去分辨输入的图像是真实的图片还是生成器生成的。当输入的图像是真实图片的时候,判别器需要给他一个很高的分数,当输入的图像是生成器生成的图片的时候,判别器需要给他一个很低的分数。而生成器的目标,就是让自己生成的图像尽可能的被打出一个高分。GAN 就是通过这样一种对抗的形式,来获得生成器和判别器彼此性能的提升。
训练判别器
我们可以通过 GAN 的训练过程来对它做一个更加深入的理解。还是以 MNIST 数据集为例,首先我们需要更新判别器的参数。需要注意的是,在更新判别器参数的时候,生成器的参数是需要完全固定的。还是以 MNIST 数据集为例,我们看上面的绿色虚线框起来的部分,这里输入判别器的是 MNIST 数据集中的一张真实图像,经过判别器之后,会输出其属于真假两个类别的概率,我们假设输出真的概率是 0.6 假的概率是 0.4。对于输入的真实图像,我们希望判别器很确定的将其判断为真,也就是真的概率是 1,假的概率是 0。这样,我们就可以计算两者的一个交叉熵作为判别器的优化目标。这里因为判别器所作的是一个二分类任务,我们可以对交叉熵做进一步的优化,只需要其输出为真的概率,也就是上面的 0.6 尽可能的接近 1 就好了。后面的这个式子中的 D (x) 就是这里的 0.6,我们希望它越大越好。同样的,我们看到下面这里黄色虚线框起来的部分,这里输入生成器的随机向量,在这个例子中我们可以简单的理解为 100 维的服从高斯分布的向量,这个向量提供了一些跟更高维度的信息,生成器通过网络添加更多的细节,从而生成了一张假的图像。所以在这个阶段输入判别器的是生成器生成的加图像,同样的我们也会得到它属于真假两个类别的概率,判别器希望其被判定为真的概率越低越好,也就是这里的 D(G(z))越低越好。通过这里两个优化目标,我们可以对判别器计算损失并反向传播做第一次的参数更新。
训练生成器
在上个阶段,判别器更新完成后,我们把判别器的参数固定,开始训练生成器。同样的,生成器生成一张假图片输入判别器,得到真假两个类别的概率。与前面不同的是,生成器希望自己生成的图片被尽可能的判定为真的,所以它需要让这里的 D (G (Z)) 越大越好。这样,整个生成式对抗网路的优化目标就很好理解了。这个阶段我们更新了生成器的参数,与前面判别器的参数更新放在一起就是 GAN 训练完整的一轮过程。不断的重复这个过程直到网络达到一个收敛的状态。
蒸馏 meets GAN
在了解了生成式对抗网络之后,我们去想它怎么和蒸馏去结合起来。有、有一种很简单很直接的方法就是,直接把教师网络的数据集丢给一个随机初始化的 GAN,让他从头开始训练,然后生成假图片用来蒸馏。这种方法确实没有问题,但是从头开始训练一个 GAN, 是非常非常慢的。我们可以想一下现在所拥有的工具,除了教师端私有的数据集之外,还有一个已经训练好的性能很棒的教师网络。如果可以把这个教师网络利用起来,就可帮助生成模型更快的达到收敛状态,从而降低计算消耗。怎么去用这个教师网络呢,我们知道教师网络是在真实数据上训练出来的,比如说我们用的是 mnist 数据集,那这个预训练好的教师网络就可以很好的提取真实图像的特征,然后用这些特征去分类输入图像。这一点和 GAN 中判别器所作的任务是非常的相似的。所以我们可以把教师网络看作是一个判别器,但是这个判别器和传统的 GAN 的判别器是不一样,他的参数都是预训练好的,而不是随机初始化的,如果还是用原来的训练方法的话,那么随机初始化的生成器无论生成什么样的图片,都会被判别器以很高的置信度判定为假,这样生成器就不知道自己优化的方向了。还有一点就是,教师网络作为判别器做的也不是真假的二分类任务,如果是 mnist 数据集,那他做的就是一个 10 分类,而不是之前我们说过的输出图片属于真假类别的概率。所以对于生成器的优化目标,我们需要做一定的变化。
生成器损失 1
既然生成器是用来生成图片的,我们可以想一想怎么衡量生成器生成图片的好坏呢?这其实是一个非常困难的任务,我们以生成一只猫的图片为例子,可能生成器生成了一只躺着的猫,而真实图像是一只站着的猫,所以不是说生成器生成的图像和数据集中某一张真实的猫的图像一模一样他就是好的,因此也不能的直接对两者去计算 L1L2 损失。我们可以从真实图像的某些特征上出发,因为教师网络是在真实图像上训练出来,所以一个比较好的教师网络,在输入是真实图片的时候,必定会在某一个类别的概率特别的大,在其他类别的概率特别的小。所以我们希望生成器生成的图像也具有这样的性质。这里的 yt 就是对于生成图像,教师网络得到的其属于每个类别的概率,小 t 就是取了概率最大的类别并做成 one-hot 的向量,用交叉熵来衡量他们的相似度。通过最小化这个函数来对生成图像做一个限制。
生成器损失 2
同样的,我们知道卷积神经网络的卷积核就是一个特征提取器,相比于一个随机的向量,真实的输入图片会有更多与之相符合的特征,也就是它被激活的神经元会更多。所以,如果生成的图像和真实图像相似,那么他在中间层被激活的神经元也应该更多,这里用 L1 范数来衡量激活神经元的个数。
生成器损失 3
最后第三项,我们从生成图片本身出发。我们的生成器是随机初始化的,一开始他生成的完全是没有意义的图像,假设生成器经过优化先生成了一张类似于手写数字 0 的图像,只是后判别器给他的分数相对是比较高,这种情况下生成器会觉的 0 这个数字是比较好的,之后无论输入是什么,他都朝着生成一个更好的数字 0 去做。这明显和我们想要的生成器是不一样的,所以对于他在每一个 batch 中生成的图像,我们需要对其做一个限制,最好的情况就是生成的图片在每一个类别上的分布都是均衡的,此时信息量也是最大的,所以我们用这样的一个函数来对它生成图片的多样性做一个限制。加上之前的两项就是生成器总的目标函数,通过这个目标函数,我们可以对生成器进行优化从而得到很多的假图像。有了假图像之后就十分简单了,只需要把之前的蒸馏方法套进来就行了,这里作者选用的是 Hinton 的软标签。
实验
我们看论文的实验结果,首先看红色框框里面的三行,这里三行是说,在教师网络的数据集可以被获取到的情况下,学生网络所能达到的性能。这里教师网络是 ResNet34,学生网络是 ResNet18,数据集用的是 CIFAR10,可以看到第三行 Hinton 的知识蒸馏方法在网络参数量减少一倍的情况下,仍然保持了一个较高的分类准确率。但是到了第四行,在不能得到教师网络数据的情况下,他只能达到 14.89 的准确率。而用作者的方法,则可以达到 92.22 的准确率,这说明作者提出的算法是确实有效的。
这是作者另外作的一个剥离实验,就是去验证之前所说的生成器三项损失的有效性。我们首先看红色的框框,这个是说生成器不进行优化,随机的去生成数据,这种情况下学生网络也能达到 88 的准确率。这个和之前说的 14 的准确率不一样是因为他用的是更加简单的 MNIST 数据集。但是我们看到绿色框框中,单独的使用 one-hot 或者是激活损失,得到结果反而会比随机的要差很多。这是因为缺乏了信息熵损失函数,生成器生成的图片是非常不均衡的,这样学生就无法充分的学习到教师网络的知识。最后我们看到在综合使用三项的情况下的结果是最好的。
此外作者还做了两个额外的实验。左边这张表是说教师网络和学生网络结构一样的情况下,在不同的数据集上用作者的方法所能达到的性能。而右边这个图,是把教师网络和学生网络的卷积核进行可视化后的结果,我们可以看到,在 13456 列学生网络和教师网络卷积核的可视化结果非常的相似,这就说明学生网络确实对教师网络进行了有效的学习。