当前位置: 首页 > 专利查询>中南大学专利>正文

一种高泛化性的个性化联邦学习实现方法技术

技术编号:36085583 阅读:15 留言:0更新日期:2022-12-24 11:01
本发明专利技术公开了一种高泛化性的个性化联邦学习实现方法,包括,服务端随机初始化全局双分支模型并发送初始化参数至客户端;客户端初始化本地双分支模型并利用本地数据进行本地迭代训练得到更新的客户端本地模型;将更新后的客户端本地模型训练的统计参数和全局任务分支的模型参数上传至服务端;服务端聚合所有客户端的全局任务分支的模型参数并更新发送给多个客户端;客户端根据服务端更新的全局任务分支模型参数并结合本地迭代训练得到的个性化任务分支模型参数,构成更新的客户端本地双分支模型;客户端使用本地双分支模型基于本地数据迭代训练并循环参与联邦更新直至满足预设标准。可在保证个性化联邦学习有效性的同时提升模型的泛化性。时提升模型的泛化性。时提升模型的泛化性。

【技术实现步骤摘要】
一种高泛化性的个性化联邦学习实现方法


[0001]本专利技术涉及联邦学习
,尤其涉及一种高泛化性的个性化联邦学习实现方法。

技术介绍

[0002]联邦学习是指多个相互隔离的孤岛数据集上训练模型的任务,在愈加严格的隐私政策的要求下,传统中心式汇聚多个数据孤岛的数据来进行数据挖掘的方式变得不可行,而单个数据孤岛的有效数据不足,数据驱动的建模和数据挖掘受到限制,此时联邦学习便能发挥作用。通用联邦学习是指,所有客户端在不共享数据的情况下,共同训练一个共识模型,以尽可能地学到来自多个客户端数据的知识。通用联邦学习步骤主要包括:客户端选择、模型分发、模型训练和模型聚合,通过迭代直至收敛得到一个聚合的共识模型。
[0003]由于联邦学习数据隔离的固有属性,客户端的数据分布不可知,不同客户端模型的学习存在很强的异质性,如通过来自不同地理环境的客户端解决不同的任务客户端,但是此时聚合的共识模型偏向某些客户端从而整体表现不佳。为了处理客户端之间的这种异质性,个性化联邦学习允许每个客户端保留并优化独立的个性化模型,而不是使用全局的共识模型。旨在客户端从联邦学习中获得收益的同时,在本地可见的数据上有更好的表现,即个性化模型的表现优于客户端孤岛式独自训练产生的模型,同时优于联邦共识模型。
[0004]虽然个性化联邦学习方法为联邦客户端的异质性困境提供了解决方案,但是主流的个性化联邦学习实现方法侧重于在可见数据的性能提升。由于对可见数据的进一步优化,大多数主流方法生成的个性化模型容易过拟合,最终导致较强模型偏向性和模型泛化性降低。然而,模型泛化性是现实场景中需要关注的问题,例如,医院客户端接收来自未知医院的转诊患者的数据,不仅能关注联邦模型在本地可见数据的表现,还可以关注其在未知分布数据上的性能。
[0005]因此,亟需一种可侧重于模型的泛化性的个性化联邦学习实现方法,在保证个性化联邦学习有效性的同时,还可以提升模型的泛化性。

技术实现思路

[0006]针对
技术介绍
中的问题,本专利技术提供了一种高泛化性的个性化联邦学习实现方法,利用任务独立的个性化批归一化和全局批归一化特征,通过双分支结构同时学习模型的个性化能力和泛化能力,即不仅能有效地提升客户端本地模型面对未知数据的泛化能力,还能保证客户端本地模型在客户端本地数据分布下的个性化能力。
[0007]第一方面,本专利技术提供了一种高泛化性的个性化联邦学习实现方法,包括,
[0008]步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
[0009]步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将
更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端;
[0010]步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端;
[0011]步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;
[0012]步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。
[0013]进一步地,服务端使用的全局模型和客服端使用的客户端本地模型结构相同,即模型的特征提取层后添加批归一化层;其中,特征提取层为任务共享层,批归一化层为任务特定层;任务特定层包括全局批归一化层和个性化批归一化层。
[0014]进一步地,全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成。
[0015]进一步地,所述统计参数包括客户端参与训练的数据量。
[0016]优选地,步骤2中本地迭代训练得到更新后的客户端本地模型的过程具体为:
[0017]将本地数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出y
g
和个性化任务输出y
l
,通过计算交叉熵损失分别得到全局任务损失loss
g
和个性化任务损失loss
l

[0018]交叉熵损失的表达式如下:
[0019][0020]其中,a取g或l;y
j
为预测目标,是实际预测结果;m表示参与训练的客户端的数量;
[0021]利用全局任务损失loss
g
和个性化任务损失loss
l
得到总体损失loss
overall
,表示式为:
[0022]loss
overall
=αloss
g
+(1

α)loss
l
[0023]其中,α为损失比例系数;
[0024]结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数,客户端本地模型的模型参数更新表达式如下:
[0025][0026]其中,g
l
表示个性化任务子模型优化得到的一次迭代的总体梯度;g
g
分别表示全局任务子模型优化得到的一次迭代的总体梯度;w
g
表示全局任务子模型的模型参数;w
l
表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;i表示第i个客户端。
[0027]优选地,步骤3中通过加权平均计算得到聚合后新的全局任务子模型的模型参数具体为:
[0028]计算客户端参与训练的数据量占所有客户端参与训练数据总量的比重;
[0029]全局任务子模型的模型参数w
g
的更新公式如下:
[0030][0031]其中,K表示参与训练的客户端总数;k表示第k个客户端;n表示所有客户端参与训练的数据总量;n
k
表示第k个客户端训练的数据量;表示第k个客户端在第t轮联邦训练中的全局任务子模型的模型参数;w
g,t+1
表示第k个客户端在第t+1轮联邦训练中的全局任务子模型的模型参数。
[0032]优选地,步骤5中预设标准具体为:
[0033]根据损失曲线对数据和客户端分布进行判断:
[0034]若为稳定收敛的数据和客户端分布时,客户端本地模型经过预设轮次的联邦训练后,将最后一个轮次的模型参数作为训练结果;
[0035]若为不能稳定收敛的数据本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种高泛化性的个性化联邦学习实现方法,其特征在于,包括,步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端;步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端;步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。2.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,服务端使用的全局模型和客服端使用的客户端本地模型结构相同,包括模型的特征提取层和相应的批归一化层;其中,特征提取层为任务共享层,批归一化层为任务特定层;任务特定层包括全局批归一化层和个性化批归一化层。3.根据权利要求2所述的高泛化性的个性化联邦学习实现方法,其特征在于,全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成。4.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,所述统计参数包括客户端参与训练的数据量。5.根据权利要求1所述的高泛化性的个性化联邦学习实现方法,其特征在于,S2中本地迭代训练得到更新后的客户端本地模型的过程具体为:将本地数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出y
g
和个性化任务输出y
l
,通过计算交叉熵损失分别得到全局任务损失loss
g
和个性化任务损失loss
l
;交叉熵损失的表达式如下:其中,a取g或l;y
j
为预测目标,是实际预测结果;m表示参与训练的客户端的数量;利用全局任务损失loss
g
和个性化任务损失loss
l
得到总体损失loss
overall
,表示式为:loss
overall
=αloss
g
+(1

α)loss
l
其中,α为损失比例系数;结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数,客户端本地模型的模型参数更新表达式如下:
其中,g
l
表示个性化任务子模型优化得到的一次迭代的总体梯度;g
g
分别表示全局任务子模型优化得到的一次迭代的总体梯度;w
g
表示全局任务子模型的模型参数;w
l
表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;...

【专利技术属性】
技术研发人员:王建新刘渊沈成超盛韬王殊段桂华
申请(专利权)人:中南大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1