当前位置: 首页 > 专利查询>西北大学专利>正文

一种基于残差网络的高准确率和高效通信的联邦平均算法制造技术

技术编号:39593488 阅读:16 留言:0更新日期:2023-12-03 19:48
本发明专利技术公开了一种基于残差网络的高准确率和高效通信的联邦平均算法:步骤

【技术实现步骤摘要】
一种基于残差网络的高准确率和高效通信的联邦平均算法


[0001]本专利技术涉及计算机
,具体涉及一种基于残差网络的高准确率和高效通信的联邦平均算法

技术背景
[0002]在数据经济时代,如何从大量数据中充分挖掘有用信息关系着众多企业的命脉

从传统的机器学习到如今的人工智能数据仍是核心地位

现实是,除了少数巨头企业拥有了大量的数据,绝大部分企业拥有的数据不仅量少且质量差,因此不足以支撑人工智能的实现

在传统方法中,通常是把各参与方的数据整合到服务器进行训练并得到模型

但由于数据存在的巨大的潜在价值,以及陆续出台的针对加强数据保护的相关政策,公司之间乃至公司内的部门之间往往不会提供自己的数据与其他的做聚合

因此由于用户隐私

商业机密和政策约束便引发了数据孤岛问题

那么如何联合不同组织与机构整合他们拥有的原始数据,共同训练出一个效果好能力强的模型迫在眉睫

[0003]基于上述背景,联邦学习应运而生,它是一种新的机器学习模式,各参与的客户端可借助其他方数据进行联合建模,而数据却保留在本地进行本地训练,客户端与服务器之间仅交换网络模型的参数,这样便在不暴露私有数据的前提下建立共享的机器学习模型

联邦学习通常迭代以下四个步骤:
1.
客户端向服务器索要模型参数,服务器把最新模型参数给客户端;
2.
客户端用本地数据和最新的模型参数进行本地更新;
3.
客户端把更新后的模型参数发送给服务器;
4.
服务器聚合来自各个客户端的模型参数后用其更新全局模型

但联邦学习提出后也遇到了很多瓶颈,其中就有客户端和服务器要频繁通信,因此联邦学习的通信代价很大,远大于计算代价

因此后续谷歌提出了联邦平均算法,其改进为在客户端本地更新时多迭代几轮得到多轮累积的模型参数上传到服务器聚合,因此该算法以牺牲计算量为代价降低了通信成本,用更少的通信次数便可以达到收敛,解决了这一瓶颈


技术实现思路

[0004]联邦平均算法
(FedAvg)
采用两种模型:
MLP

LeNet5。
这两种模型限制了测试集准确率,并且其测试集准确率在稳定后方差很大,体现在其准确率曲线波动大,在少数通信轮时,准确率在
95
%以下,有的通信回合甚至降至
90
%以下

本专利技术用调整后的深度残差网络替换原网络进行本地更新,在
ResNet18
的基础上更改网络层的相关参数及删除某些层来简化网络结构形成
ResNet18

E
,在准确率和通信效率上找到一个好的平衡

通过对
MNIST
数据集的独立同分布
(IID)
和独立非同分布
(Non

IID)
场景进行了一系列测试,以评估本专利技术

结果表明,所提出的方案在准确率和稳定性方面均优于原来算法,虽然改后较复杂的网络结构增加了每轮通信的时长,但却减少了达到目标准确率的通信轮数,但整体上总通信时长
(
达到目标准确率的通信轮数和每轮通信的时长的乘积
)
减少

[0005]为了实现上述目的,本专利技术采用的技术方案是:
[0006]一种基于残差网络的高准确率和高效通信的联邦平均算法,具体包括以下步骤:
[0007]步骤
1、
服务器对
MNIST
数据集进行
IID

Non

IID
划分,并分配到各客户端上;
[0008]步骤
2、
每个客户端本地构建残差网络模型,具体包括如下子步骤:
[0009]步骤
21
,选择
ResNet18
作为残差网络模型的主干网络;
[0010]步骤
22
,基于主干网络,构建
ResNet18

E
的具体网络结构;所述
ResNet18

E
包括依次连接的3×3卷积层

卷积组
Stage1、Stage2、Stage3
和线性层,其中:
[0011]所述卷积组
Stage1、Stage2、Stage3
均包含两个残差块,每个残差块包括两个相同通道数的3×3卷积层,每个所述卷积层后接一个
Batch Normalization
层和
ReLU
激活函数;并将该残差块的输入跳跃加在其第二个卷积层后的
Batch Normalization
层和
ReLU
激活函数之间;另外,在
Stage2
的第一个残差块和
Stage3
的第一个残差块分别增加一个下采样结构;
[0012]所述线性层为两个依次连接的线性层;
[0013]步骤
3、
设置服务器和客户端之间通信轮数
R
,在每一轮通信中服务器随机选择所有客户端
C
中的
K
个参与该轮通信,并向选中的客户端发送其存储的全局模型的参数;
[0014]步骤
4、
被选中的客户端下载全局模型的参数,在预设的本地迭代次数内进行本地机器学习训练;训练过程中用交叉熵损失函数计算网络损失,反向传播计算梯度,最后用随机梯度下降算法不断更新本地模型精度,得到各自对应的训练好的模型参数;
[0015]步骤
5、
当这
K
个客户端更新结束后,它们分别将本地训练好的模型参数上传给服务器;
[0016]步骤
6、
服务器节点从
K
个客户端接收它们的本地更新后的训练好的模型参数,按加权平均聚合策略更新全局模型参数,并作为聚合结果保存,将更新后的全局模型参数作为当前全局模型参数,即重新执行步骤3‑6开始新一轮通信的全局模型参数更新,直至第
R
轮通信得到全局模型参数,作为最终的全局模型参数

[0017]进一步的,步骤1具体包括如下子步骤:
[0018]步骤
11
,对
MNIST
数据集进行
IID
划分,并分配到各客户端上;
[0019]步骤
12
,对数据集进行
Non

IID
划分,并分配到各客户端上

[0020]进一步的,客户端本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.
一种基于残差网络的高准确率和高效通信的联邦平均算法,其特征在于,具体包括以下步骤:步骤
1、
服务器对
MNIST
数据集进行
IID

Non

IID
划分,并分配到各客户端上;步骤
2、
每个客户端本地构建残差网络模型,具体包括如下子步骤:步骤
21
,选择
ResNet18
作为残差网络模型的主干网络;步骤
22
,基于主干网络,构建
ResNet18

E
的具体网络结构;所述
ResNet18

E
包括依次连接的3×3卷积层

卷积组
Stage1、Stage2、Stage3
和线性层,其中:所述卷积组
Stage1、Stage2、Stage3
均包含两个残差块,每个残差块包括两个相同通道数的3×3卷积层,每个所述卷积层后接一个
Batch Normalization
层和
ReLU
激活函数;并将该残差块的输入跳跃加在其第二个卷积层后的
Batch Normalization
层和
ReLU
激活函数之间;另外,在
Stage2
的第一个残差块和
Stage3
的第一个残差块分别增加一个下采样结构;所述线性层为两个依次连接的线性层;步骤
3、
设置服务器和客户端之间通信轮数
R
,在每一轮通信中服务器随机选择所有客户端
C
中的
K
个参与该轮通信,并向选中的客户端发送其存储的全局模型的参数;步骤
4、
被选中的客户端下载全局模型的参数,在预设的本地迭代次数内进行本地机器学习训练;训练过程中用交叉熵损失函数计算网络损失,反向传播计算梯度,最后用随机梯度下降算法不断更新本地模型精度,得到各自对应的训练好的模型参数;步骤
5、
当这
K
个客户端更新结束后,它们分别将本地训练好的模型参数上传给服务...

【专利技术属性】
技术研发人员:王海李蕊郝明远张寅袁浩博马于惠
申请(专利权)人:西北大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1