在做工作总结的时候,老师问我们看了这么多联邦学习和蒸馏的论文,有没有什么自己的想法呢?具体的比如怎么将蒸馏应用到联邦学习中呢?确实,在之前的工作中大部分都是带任务性的 “被动” 的阅读,以致没有注意到这样基础的问题,这一点以后需要多加注意。这一篇文章就是将蒸馏与联邦学习相结合的一种方法,思路也非常的简单。让我们一起来看一下。
FedAvg
首先回顾一下 FedAvg 算法的流程:
- 在每一轮更新开始之前,随机的选取部分客户端,比例为 C (C≤1)
- 服务端将当前的全局算法状态(模型参数)发送给这些客户
- 每个客户端以步骤 2 的参数做初始化,在本地数据集上执行若干个 E (epoch) 的更新
- 客户端将更新后的参数上传给服务端,服务端计算参数的加权平均并更新
- 重复步骤 1-4
我们注意到在 FedAvg 算法中,服务端和客户端之间传递是整个模型的参数,这样,当我的模型很大的时候,需要传递的参数也非常多,需要很大的通信成本。同时这种简单的取参数平均的方法,必须要求服务端以及各个客户端的模型是完全一样的。这在实际的操作中是非常困难的。
使用蒸馏 & 迁移学习
在学习了蒸馏相关的知识之后,我们很自然的会想到:对于客户端所学习到的东西,我们可不可以不在局限于模型的参数,而是之前所学习过的软标签、中间层输出 Hint 或者是层与层之间的关系 FSP 矩阵呢?这里我们以软标签为例,对应到联邦学习的话就是,服务端得到了各个客户端的软标签,然后计算了一个加权平均值,并把这个值告诉客户端让他们自己去更新。这个过程的其实存在有一个很大的问题,我们知道在蒸馏当中,为了去拟合学生的输出,有一个必要的前提就是,我的教师和学生所用的数据集必须是一样的。而在联邦学习中,每个客户端的数据集通常是 non-IID 的,也就是说我如果想用软标签来更新,就必须有一个公共的数据集,这个数据集是各个客户端数据集的集合,这样的话,就和联邦学习的初衷相违背了,那么怎么解决这个问题呢?既然我需要一个公共的数据集,这个数据集又不能和本地的数据集一样,那我是不是可以使用一些公开的数据集,比如,ImageNet。这个数据集和我客户端需要处理的图像是完全不一样的。所以,要使用这个数据集,就必须用到迁移学习的方法。知道了这些,我们就可以理解他的整个算法了。
FedMD
下面我们看一下他的整个算法流程。这里客户端我画了不同的形状表示他们可以使用不同的模型。首先需要对客户端做一个初始化,每一个客户端都需要在公共数据集上进行充分的训练到达收敛的状态,然后用迁移学习的方法在自己的私有数据集上进行训练。初始化之后就可以开始联邦学习了,每个客户端需要给出公共数据集属于各个类别的概率,就是软标签,这里为了减少通信量可以选取部分公共数据集。然后服务端对这些软标签做一个平均传回客户端,将这个平均值作为客户端需要学习的知识。客户端需要在公共数据集上去拟合这个平均。然后继续在自己的私有数据集上训练几个 epoch。这样一个过程作为一轮,重复这个过程直到各个客户端的模型达到收敛状态。
实验结果
最后是实验部分,作者在两种数据集上做了实验,结果差不多,这里以其中一种为例。公共数据集是手写数字 MNIST,客户端私有的数据集是 FEMNIST,这个和手写数字差不多,是手写的字母。每个客户端采用的是 2-4 层的 CNN,他们使用不同的通道数以及 dropout。图中的折线表示的就是用了 FedMD 之后每个客户端达到的测试准确率。左下方的虚线是指不使用联邦学习,客户端只在自己的私有数据以及公有数据集上训练的结果,这也是 baseline。右上角的虚线表示其他的客户端将数据发送给当前客户端,也就是每个客户端拥有全部的数据所能达到的实验结果,这也是我们需要取接近或者超越的目标。可以看到 FedMD 算法达到了比 baseline 更好的性能但是和拥有全部数据集的情况还是有差距。