System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 面向非独立同分布数据的联邦知识蒸馏方法及装置制造方法及图纸_技高网

面向非独立同分布数据的联邦知识蒸馏方法及装置制造方法及图纸

技术编号:40146786 阅读:27 留言:0更新日期:2024-01-24 00:28
本申请涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置,其包括根据公共数据集进行随机采样,获取辅助数据集;基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据;控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据;控制客户端根据预设的局部模型蒸馏算法以及融合数据对深度学习模型进行优化训练,得到全局模型,本申请通过生成网络模型和局部模型蒸馏算法对客户端的深度学习模型进行优化,减少深度学习模型的优化目标与全局优化目标的偏差。

【技术实现步骤摘要】

本申请涉及数据安全,尤其是涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置


技术介绍

1、随着互联网、物联网、云计算和大数据等各种技术的快速发展,企业面临海量的数据处理与分析,数据的搜集、共享、发布和分析过程中可能导致用户隐私信息的泄露,给用户带来巨大损失。同时,全球数据保护法规越来越严格,企业在使用数据过程中面临隐私泄露和数据违规风险。因此,隐私计算技术变得越发重要。

2、联邦学习是一种新兴的人工智能技术,最初由谷歌在2016年提出,旨在解决个人数据在安卓手机端的隐私问题。该技术的设计动机是保护手机或平板计算机中用户的隐私数据,因此提出了一种数据不动模型动的新型分布式机器学习范式。联邦学习可以看成是一种分布式机器学习框架,与传统的分布式机器学习框架不同,其使用了加密技术,并且各方数据保存在本地。在联邦学习中,各个参与方(例如手机、平板计算机等设备)将本地数据进行计算和更新,然后将结果发送回中央服务器进行聚合。联邦学习体现了集中数据收集和最小化的原则,可以减轻传统集中式机器学习和数据挖掘方法带来的系统和统计层面上的隐私风险和通信效率开销。

3、针对上述中的相关技术,由于联邦学习系统中各个客户端通过不同的硬件或软件设备收集并处理数据,因此客户端之间的数据分布往往是差异极其大的,并进一步导致各客户端深度学习模型的参数不一致。各客户端深度学习模型的优化目标与全局优化目标存在偏差,在模型训练时会远离最优点,从而导致模型在效率、效果、隐私保护层面上都不能达到一个很好的效果。


技术实现思路

1、为了改善各客户端深度学习模型的优化目标与全局优化目标存在偏差,在模型训练时会远离最优点,从而导致模型在效率、效果、隐私保护层面上都不能达到一个很好的效果的问题,本申请提供一种面向非独立同分布数据的联邦知识蒸馏方法及装置。

2、第一方面,本申请提供的一种面向非独立同分布数据的联邦知识蒸馏方法,采用如下的技术方案:包括:

3、根据预设的公共数据集进行随机采样,获取辅助数据集;

4、基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;

5、将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;

6、控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;

7、控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户的深度学习模型进行优化训练,得到全局模型。

8、可选的,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。

9、可选的,所述对抗目标损失函数的计算公式为:

10、;

11、其中,为所述辅助数据集中的数据样本,为所述噪声向量,为所述生成网络,和则分别代表所述生成网络和所述鉴别网络的模型参数。

12、可选的,所述互信息平滑损失函数的计算公式为:

13、;

14、其中,代表一次批处理过程中所述噪声向量的数量。

15、可选的,所述相似度惩罚损失函数的计算公式为:

16、;

17、其中,和代表重复采样过程中不同的噪声向量。

18、可选的,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:

19、基于所述生成网络模型生成的所述生成网络数据和客户端的所述本地数据通过所述数据融合算法进行融合,得到所述融合数据;

20、其中,所述数据融合算法的计算公式为:

21、;

22、;

23、;

24、其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,为样本的伪标签,和为合成后的数据样本和标签。

25、可选的,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型,包括:

26、计算所述生成网络数据与所述本地数据之间的数量比例;

27、控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对生成网络进行优化训练,得到所述全局模型;

28、其中,所述局部模型蒸馏算法的计算公式为:

29、;

30、其中,其中为所述本地数据的样本数量,为所述生成网络数据的样本数量,是代表客户端本地的深度学习模型在所述生成网络数据和所述融合数据之间kullback-leibler距离,为用于调整知识蒸馏强度的参数,为所述生成网络数据中标签为的样本数量,则代表归一化指数函数。

31、可选的,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型之后,还包括:

32、若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述全局模型进行迭代优化,获取全部客户端的优化模型;

33、接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。

34、可选的,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型之后,还包括:

35、接收全体客户端深度学习模型的模型参数;

36、基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;

37、基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;

38、基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;

39、将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;

40、其中,所述集成模型的计算公式为:

41、;

42、其中,是一个可学习参数并处于0到1之间,则是用于控制权重参数正则化的程度,代表客户端上的所述模型参数;

43、所述全局聚合蒸馏算法的定义如下:

44、;

45、其中代表所述全局模型,代表所述集成模型,为所述虚拟数据集中的数据样本。

46、第二方面,本申请还提供一种面向非独立同分布数据的联邦知识蒸馏装置,采用如下技术方案,包括:

47、数据采样模块,用于根据预设的公共数据集进行随机采样,获取辅助数据集;

48、生成网络模块,用于基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;

49、数据生成模块,用于将所述生本文档来自技高网...

【技术保护点】

1.一种面向非独立同分布数据的联邦知识蒸馏方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。

3.根据权利要求2所述的方法,其特征在于,所述对抗目标损失函数的计算公式为:

4.根据权利要求3所述的方法,其特征在于,所述互信息平滑损失函数的计算公式为:

5.根据权利要求3所述的方法,其特征在于,所述相似度惩罚损失函数的计算公式为:

6.根据权利要求4所述的方法,其特征在于,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:

7.根据权利要求6所述的方法,其特征在于,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:

8.根据权利要求7所述的方法,其特征在于,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:

9.根据权利要求1所述的方法,其特征在于,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型之后,还包括:

10.一种面向非独立同分布数据的联邦知识蒸馏装置,其特征在于,所述装置包括:

...

【技术特征摘要】

1.一种面向非独立同分布数据的联邦知识蒸馏方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。

3.根据权利要求2所述的方法,其特征在于,所述对抗目标损失函数的计算公式为:

4.根据权利要求3所述的方法,其特征在于,所述互信息平滑损失函数的计算公式为:

5.根据权利要求3所述的方法,其特征在于,所述相似度惩罚损失函数的计算公式为:

6.根据权利要求4所述的方法,其特征在于,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数...

【专利技术属性】
技术研发人员:田辉王欢郭玉刚张志翔
申请(专利权)人:合肥高维数据技术有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1