FD

FD

- 1 min

参考: 联邦知识蒸馏概述与思考(续)-腾讯云开发者社区-腾讯云 (tencent.com)

知识蒸馏

知识蒸馏是一种模型压缩方法, 通过利用复杂模型(Teacher Model)强大的表征学习能力帮助简单模型(Student Model)进行训练, 主要分为两个步骤:

​ 提取复杂模型的知识, 在这里知识的定义有很多种, 可以使预测的logits, 中间层的输出feature map, 也可以是中间层的Attention map, 主要就是反映了Teacher Model的学习能力, 是一种表征的体现;

​ 将知识迁移/蒸馏到学生模型上去, 迁移的方式也有很多种, 主要是各种loss function的实现, 有L1loss, L2loss以及KL loss等手段.

知识蒸馏可以在保证模型的性能前提下,大幅度的降低模型训练过程中的通信开销和参数数量,知识蒸馏的目的是通过将知识从深度网络转移到一个小网络来压缩和改进模型。

这很适用于联邦学习, 因为联邦学习就是基于服务器-客户端的架构, 本身就需要通信.

FL-FD 数据增强的联邦蒸馏算法

在联邦学习中, 每个设备端执行训练过程都需要与模型大小成比例的通信开销, 因此禁止使用大型模型.

作者提出的联邦蒸馏(FD)算法, 使一种分布式在线知识蒸馏方法, 其通信有效成本的大小不取决于模型大小, 而是取决于输出尺寸.

在进行联邦蒸馏之前, 要先通过联邦增强(FAug)来纠正niid训练数据集: 这是一种使用生成对抗网络(GAN)进行的数据增强方案, 该数据增强方案在隐私泄露和通信开销之间可以进行权衡取舍. 经过训练的GAN可以使每个设备在本地生成所有设备的数据样本, 从而使训练数据集成为IID分布.

联邦蒸馏(FD): 在FD中, 每台设备都将自己视为学生, 并将其他所有设备的平均模型输出视为Teacher Model的输出. 每个模型输出是一组通过softmax函数归一化后的logit值, 其大小由标签数给出.

使用交叉熵来周期性地测量师生的输出差异,交叉熵成为学生的损失调整器,称为蒸馏调整器,从而在培训过程中获得其他设备的知识,具体损失是:KDLoss(Local_Logit,Global_Logit)+CELoss(Local_Logit,Local_Lable)。FD中的每个设备都存储着本地每个标签的平均logit向量,并定期将这些本地平均logit向量上载到服务器。

服务器将从所有设备上载的本地平均Logit向量平均化,从而得出每个标签的全局平均Logit向量。所有标签的全局平均logit向量被下载到每个设备。然后,当每台设备进行蒸馏的时候,其教师的输出为与当前训练样本的标签具有相同标签的全局平均logit向量。

事实上,模型的输出精度会随着训练的进行而增加,因此,在局部logit平均过程中,最好采用加权平均值随着局部计算时间的增加而增加,即当模型采用整体损失函数:a * KDLoss(Local_Logit,Global_Logit)+CELoss(Local_Logit,Local_Lable) * (1-a),随着迭代次数的增加,a应该逐渐减小(模型的输出精度会随着训练的进行而增加,所以本地模型比重应该增大)。

img

总结一下FL-FD算法的过程

1)每个设备都把自己当作一个学生,并将所有其他设备的平均模型输出视为其老师的输出;

2)FD中的每个设备存储每个标签的平均logit向量,并定期将这些本地平均logit向量上传到服务器;

3)对于每个标签,对所有设备上传的本地平均logit向量进行平均,从而得到每个标签的全局平均logit向量;

4)所有标签的全局平均logit向量都被下载到每个设备上,进行蒸馏损失计算,其教师的输出被选择为与当前训练样本的标签相同的全局平均logit向量。

联邦增强

联邦增强(FAvg):因为蒸馏最好在具有相同数据集的效果下进行,由于不同设备之间具有异质性所以在蒸馏前进行数据增强可以提升蒸馏效果。FAug中每个设备都可以识别数据样本中缺少的标签,称为目标标签,并通过无线链路将这些目标标签的少量种子数据样本上载到服务器。

服务器则会通过例如Google视觉数据图像搜索等方法对上传的种子数据样本进行超采样,并使用这些数据来训练一个GAN。

最后,下载经过训练的GAN生成器使每个设备补充目标标签,直到达到IID训练数据集为止。FAug的操作需要确保用户生成的数据的私密性。

实际上,每台设备的数据生成偏差(即目标标签)都可以轻松地显示其隐私敏感信息,为了使这些目标标签对服务器不公开,每个设备还将从目标标签以外的其他标签进行上载(冗余数据样本),由此减少了从每个设备到服务器的隐私泄漏。

comments powered by Disqus