当前位置: 首页 > 专利查询>中南大学专利>正文

一种联邦学习实现方法、系统、终端设备及可读存储介质技术方案

技术编号:29097814 阅读:24 留言:0更新日期:2021-06-30 10:08
本发明专利技术公开了一种联邦学习实现方法、系统、终端设备及可读存储介质,该方法包括:客户端使用本地神经网络模型进行第一轮本地迭代得到模型权重和损失值,并上传至服务器;服务器将客户端的模型权重进行加权平均计算得到平均权重,以及根据客户端的损失值对客户端进行分组;客户端利用平均权重更新本地神经网络模型,并基于中位数损失自适应调节本地迭代次数,再基于本地数据进行迭代训练更新模型权值并得到新的损失值;客户端将更新后的模型权重和损失值上传至服务器进行循环更新。本发明专利技术的客户端以所在组的中位数损失值为标准调整本地迭代次数,有效降低本地计算复杂度,提升联邦学习效率,进一步利用公共数据集进行知识蒸馏,提升模型性能。提升模型性能。提升模型性能。

【技术实现步骤摘要】
一种联邦学习实现方法、系统、终端设备及可读存储介质


[0001]本专利技术属于联邦学习
,具体涉及一种联邦学习实现方法、系统、终端设备及可读存储介质。

技术介绍

[0002]在现实世界中,由于行业竞争、隐私安全等问题,数据大都以孤岛形式存在,即使在同一个公司的不同部门之间,实现数据整合也面临着重重阻力,面对难以桥接的数据孤岛,如何安全合法的使用多方数据进行联合建模始终是业界的一个难点。
[0003]为了解决这种数据孤岛问题,谷歌提出了针对移动设备的联邦学习方法。联邦学习可以使各个参与方的数据不出本地而使用各个参与方的数据共同协作训练出一个全局模型,可以解决数据孤岛这一痛点问题。谷歌提出的联邦学习方法步骤如下:首先,服务器选择可以参与本次全局迭代的设备。其次,服务器将上一轮全局模型参数发送给这些被选中的设备。然后,被选中的这些设备使用本地的私有数据和模型进行本地迭代计算,更新模型参数。最后,这些被选中的设备将此次更新的模型参数发送给服务器,服务器对接收到的模型参数进行加权平均,更新全局模型参数。
[0004]然而,谷歌提出的联邦学习方法是针对移动设备的方法,所以每一轮全局迭代中都要选择此次参与的设备。然而,在其它应用场景中,例如不同的医院之间通过联邦学习构建全局模型时,通常是不需要进行设备选择的。而且,谷歌提出的联邦学习方法中参与方的数量(一般超过10000)往往远远大于设备中的数据的数量,而跨机构的联邦学习方法中参与方的数量(一般不超过50)远远小于参与方数据的数量。与此同时,如何提高模型训练的效率也是联邦学习的关注点之一。
[0005]因此,针对跨数据孤岛问题的联邦学习模型,如何实现不限制于移动设备以及适用于跨机构的联邦学习模型,并如何提升联邦学习的通信效率是本专利技术亟需研究的。

技术实现思路

[0006]本专利技术的目的是为了克服现有技术中存在的不足,提供一种联邦学习实现方法、系统、终端设备及可读存储介质,其利用中位数损失将客户端分成不同组来自适应调整客户端的本地计算复杂度,有效提升了联邦学习的通信效率,且所述方法的训练过程无需选择客户端且对客户端的数量并无要求,可以有效应用于跨机构的联邦学习模型中,譬如,不同医院之间的联邦模型。
[0007]一方面,本专利技术提供一种联邦学习实现方法,包括如下步骤:
[0008]步骤1:每个客户端使用本地神经网络模型并利用本地数据进行第一轮本地迭代计算,得到模型权重和损失值,并上传至服务器;
[0009]其中,各个客户端与服务器通讯连接,每个客户端使用同一类本地神经网络;
[0010]步骤2:所述服务器将所有客户端的模型权重进行加权平均计算得到平均权重,以及根据每个客户端的损失值对客户端进行分组,并将平均权值以及客户端所在组的中位数
损失值发送给对应客户端;
[0011]步骤3:客户端利用平均权重更新所述本地神经网络模型,并基于中位数损失自适应调节本地迭代次数,再基于本地数据进行迭代训练更新模型权值并得到新的损失值;
[0012]其中,客户端将更新后的模型权重和损失值上传至服务器进行循环更新,直至客户端的模型满足预设标准。
[0013]本专利技术以中位数损失作为一个标准,让损失小的客户端少训练,损失高的客户端多训练,进行了均衡,总体上所有客户端的本地计算复杂度是小于常规的平均算法,并通过实验进行了有效验证。
[0014]可选地,步骤2中根据每个客户端的损失值对客户端进行分组的分组依据如下:
[0015]以每个客户端的损失值与所在组的中位数损失值的差的绝对值之和最小。
[0016]可选地,依据所述分组依据对客户端进行分组的过程为:将N个客户端上传的损失按照从小到大排列,并按照如下迭代过程将N个客户端划分为g组得到g个曼哈顿距离,且所述g个曼哈顿距离之和最小,其中,所述迭代过程如下:
[0017]A:设定参数i表示损失个数,对应取值范围为1

N;其中,在其取值范围内依次遍历取值;
[0018]B:设定参数j表示分组组数,对应取值范围为1

G;其中,在其取值范围内依次遍历取值;
[0019]C:在i值与j值确定下,设定参数k在[1,i]的范围内依次遍历取值,并按照如下公式计算得到
[0020][0021]式中,表示前i个损失分成了j组后,得到的j个曼哈顿距离之和的最小值,表示前k

1个损失分成了j

1组后,得到的j

1个曼哈顿距离之和的最小值;cost
ki
为从下标为k~i的客户端作为第j组时损失的曼哈顿距离;
[0022]其中,参数k遍历计算完成后,返回步骤B,在参数j的取值范围更新参数j,再执行步骤C;待参数j遍历完成后,再返回步骤A,在参数i的取值范围更新参数i,再执行步骤B与步骤C,实现循环迭代,直至将N个损失化为g组得到的g个曼哈顿距离之和最小,g小于或等于G。
[0023]其中,迭代算法一共是三重for循环,第一个枚举i,范围从1~N,第二个枚举j,范围从1~G,第三个枚举k,范围从1~i,在最后一个for循环中按照上述公式不断更新迭代算法的时间复杂度为O(GN2)。观察上述公式可以发现只要记录每组客户端的最后一次划分就能知道客户端的划分情况。每次全局迭代中都进行上述的调整,然后,每个客户端可以根据自身损失与组中的中位数损失比较来自适应调整本地计算的迭代次数,以此来减小联邦学习的本地计算复杂度。
[0024]可选地,步骤3中基于中位数损失自适应调节本地迭代次数的过程如下:
[0025]首先,客户端利用平均权重更新所述本地神经网络模型后,利用本地数据迭代训练次,E为预设的联邦平均算法中的本地训练迭代次数;
[0026]然后,判断训练后的损失值是否小于中位数损失值,若小于,则停止本地迭代训
练;否则,根据当前轮数确定当前的本地迭代次数为:r为当前轮数;
[0027]其中,若迭代训练过程中,客户端的损失值小于中位数损失值或客户端的本地迭代次数达到停止迭代训练。
[0028]可选地,步骤3之后,还执行:
[0029]步骤4:每个客户端利用公共数据集进行知识蒸馏完成模型权重更新,再将客户端的模型权重和损失值上传至服务器,返回步骤2进行循环更新;
[0030]其中,若客户端的模型满足预设标准,则停止循环更新。
[0031]第二方面,本专利技术提供一种联邦学习实现方法,应用于客户端时,包括如下步骤:
[0032]S1:使用本地神经网络模型并利用本地数据进行第一轮本地迭代计算,得到模型权重和损失值,并上传至服务器;
[0033]S2:接收服务器传送的平均权值以及客户端所在组的中位数损失值;
[0034]其中,所述服务器将所有客户端的模型权重进行加权平均计算得到平均权重,以及根据每个客户端的损失值对客户端进行分组得到客户端所在组的中位数损失值;
本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种联邦学习实现方法,其特征在于:包括如下步骤:步骤1:每个客户端使用本地神经网络模型并利用本地数据进行第一轮本地迭代计算,得到模型权重和损失值,并上传至服务器;其中,各个客户端与服务器通讯连接,每个客户端使用同一类本地神经网络;步骤2:所述服务器将所有客户端的模型权重进行加权平均计算得到平均权重,以及根据每个客户端的损失值对客户端进行分组,并将平均权值以及客户端所在组的中位数损失值发送给对应客户端;步骤3:客户端利用平均权重更新所述本地神经网络模型,并基于中位数损失自适应调节本地迭代次数,再基于本地数据进行迭代训练更新模型权值并得到新的损失值;其中,客户端将更新后的模型权重和损失值上传至服务器进行循环更新,直至客户端的模型满足预设标准。2.根据权利要求1所述的方法,其特征在于:步骤2中根据每个客户端的损失值对客户端进行分组的分组依据如下:以每个客户端的损失值与所在组的中位数损失值的差的绝对值之和最小。3.根据权利要求2所述的方法,其特征在于:依据所述分组依据对客户端进行分组的过程为:将N个客户端上传的损失按照从小到大排列,并按照如下迭代过程将N个客户端划分为g组得到g个曼哈顿距离,且所述g个曼哈顿距离之和最小,其中,所述迭代过程如下:A:设定参数i表示损失个数,对应取值范围为1

N;其中,在其取值范围内依次遍历取值;B:设定参数j表示分组组数,对应取值范围为1

G;其中,在其取值范围内依次遍历取值;C:在i值与j值确定下,设定参数k在[1,i]的范围内依次遍历取值,并按照如下公式计算得到算得到式中,表示前i个损失分成了j组后,得到的j个曼哈顿距离之和的最小值,表示前k

1个损失分成了j

1组后,得到的j

1个曼哈顿距离之和的最小值;cost
ki
为从下标为k~i的客户端作为第j组时损失的曼哈顿距离;其中,参数k遍历计算完成后,返回步骤B,在参数j的取值范围更新参数j,再执行步骤C;待参数j遍历完成后,再返回步骤A,在参数i的取值范围更新参数i,再执行步骤B与步骤C,实现循环迭代,直至将N个损失化为g组得到的g个曼哈顿距离之和最小,g小于或等于G。4.根据权利要求1所述的方法,其特征在于:步骤3中基于中位数损失自适应调节本地迭代次数的过程如下:首先,客户端利用平均权重更新所述本地神经网络模型后,利用本地数据迭代训练次,E为预设的联邦平均算法中的本地训练迭代次数;然后,判断训练后的损失值是否小于中位数损失值,若小于,则停止本地迭代训练;否则,根据当前轮数确定当前的本地迭代次数为:r为当前轮数;
其中,若迭代训练过程中,客户端的损失值小于中位数损失值或客户端的本地迭代次数达到停止迭代训练。5.根据权利要求1所述的方法,其特征在于:步骤3之后,还执行:步骤4:每个客户端利用公共数据集进行知识蒸馏完成模型权重更新,再将客户端的模型权重和损失值上传至服务器,...

【专利技术属性】
技术研发人员:王建新吴帆刘渊安莹胡建中黄伟红
申请(专利权)人:中南大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1