基于联邦学习的模型训练方法、装置、设备和介质制造方法及图纸

技术编号:35359068 阅读:32 留言:0更新日期:2022-10-26 12:41
本公开提供了一种基于联邦学习的模型训练方法、装置、设备和介质,涉及人工智能领域,尤其涉及深度学习领域。具体实现方案为:获取联邦学习的各参与方通过自身训练样本进行相同模型训练所得的局部网络参数;根据各所述局部网络参数,确定各所述参与方的训练评价数据;根据各所述训练评价数据,对各所述参与方进行聚类;根据相同类别的参与方的局部网络参数,确定相应参与方的全局网络参数;将各所述参与方的全局网络参数反馈至相应参与方,以供模型训练。根据本公开的技术,提高了所训练模型的精度和鲁棒性。型的精度和鲁棒性。型的精度和鲁棒性。

【技术实现步骤摘要】
基于联邦学习的模型训练方法、装置、设备和介质


[0001]本公开涉及人工智能领域,具体为深度学习领域,尤其涉及一种基于联邦学习的模型训练方法、装置、设备和介质。

技术介绍

[0002]联邦学习(Federated Learning)是一种新兴的人工智能基础技术。
[0003]目前,联邦学习的研究主要集中在横向联邦和纵向联邦,在横向联邦学习场景中,各节点(Client)在本地训练数据,将训练的模型参数信息上传到中央服务器,由中央服务器(ParamServer)聚合参数信息以达到共同训练的目的。

技术实现思路

[0004]本公开提供了一种基于联邦学习的模型训练方法、装置、设备和介质。
[0005]根据本公开的一方面,提供了一种基于联邦学习的模型训练方法,该方法包括:
[0006]获取联邦学习的各参与方通过自身训练样本进行相同模型训练所得的局部网络参数;
[0007]根据各所述局部网络参数,确定各所述参与方的训练评价数据;
[0008]根据各所述训练评价数据,对各所述参与方进行聚类;
[0009]根据相同类别的参与方的局部网络参数,确定相应参与方的全局网络参数;
[0010]将各所述参与方的全局网络参数反馈至相应参与方,以供模型训练。
[0011]根据本公开的另一方面,提供了一种电子设备,该电子设备包括:
[0012]至少一个处理器;以及
[0013]与所述至少一个处理器通信连接的存储器;其中,
[0014]所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开任一实施例提供的基于联邦学习的模型训练方法。
[0015]根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,计算机指令用于使计算机执行根据本公开任一实施例提供的基于联邦学习的模型训练方法。
[0016]本公开实施例提高了所训练模型的精度和鲁棒性。
[0017]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0018]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0019]图1是本公开实施例提供的一种基于联邦学习的模型训练方法的示意图;
[0020]图2是本公开实施例提供的另一种基于联邦学习的模型训练方法的示意图;
[0021]图3是本公开实施例提供的又一种基于联邦学习的模型训练方法的示意图;
[0022]图4是本公开实施例提供的再一种基于联邦学习的模型训练方法的示意图;
[0023]图5是本公开实施例提供的一种基于联邦学习的模型训练装置的示意图;
[0024]图6是用来实现本公开实施例的基于联邦学习的模型训练方法的电子设备的框图。
具体实施方式
[0025]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0026]本公开实施例提供的基于联邦学习的模型训练方法和基于联邦学习的模型训练装置,适用于利用联邦学习方法,对至少一个参与方的本地模型进行联合训练的场景中。本公开实施例所提供的各基于联邦学习的模型训练方法,可以由基于联邦学习的模型训练装置执行,该装置可以采用软件和/或硬件实现,并具体配置于电子设备中,该电子设备可以是具有一定数据运算能力的电子设备,例如可以是服务器。为了便于表述,以下将该电子设备统称为中央服务器,不应理解为对电子设备的具体呈现形式的限定。
[0027]为了便于理解,首先对基于联邦学习的模型训练方法进行详细说明。
[0028]参见图1所示基于联邦学习的模型训练方法,包括:
[0029]S101、获取联邦学习的各参与方通过自身训练样本进行相同模型训练所得的局部网络参数。
[0030]联邦学习是一种分布式机器学习技术,其核心思想是通过在多个拥有训练样本的数据源之间进行分布式模型训练,在不需要交换训练样本的前提下,仅通过交换模型参数或中间结果的方式,构建模型。参与方用于表征联邦学习中拥有训练样本的数据源,可以作为训练节点参与基于联邦学习的模型训练。参与方的数量为至少一个,不同参与方可以同时进行模型训练。一个参与方用于训练一个模型,各参与方训练的模型结构相同。自身训练样本是指参与方自身持有或存储的样本数据,无需向其他参与方进行训练样本的共享。各参与方的自身训练样本可以相同,也可以至少部分不同。局部网络参数用于表征各参与方对模型进行训练后得到的网络参数。
[0031]具体的,参与联邦学习的各参与方,利用自身训练样本对相同模型进行训练,将训练后的模型参数作为局部网络参数发送至中央服务器;中央服务器接收各参与方的局部网络参数。
[0032]S102、根据各所述局部网络参数,确定各所述参与方的训练评价数据。
[0033]训练评价数据用于表征参与方对模型进行训练的训练情况。具体的,各参与方之间由于自身训练样本不同或节点训练能力不同,导致模型训练后得到的局部网络参数存在差异。因此,通过训练评价数据可以量化不同参与方的模型训练情况,用以作为后续对参与方进行聚类的重要参照。参与方与训练评价数据相对应。
[0034]其中,训练评价数据可以包括局部密度数据和跟随距离数据等。
[0035]本公开实施例对训练评价数据的确定方式不作任何限定。在一个可选实施例中,
可以根据各参与方得到的局部网络参数,可以通过DPC(clustering by fast search and find of density peaks,密度峰值聚类)算法,确定各参与方的局部密度数据和跟随距离数据,作为训练评价数据,以供后续使用。
[0036]S103、根据各所述训练评价数据,对各所述参与方进行聚类。
[0037]具体的,根据各参与方的训练评价数据,对参与方进行聚类,得到至少一个类别的参与方。相同类别的参与方的训练评价数据相同或近似,且相同类别的参与方的数量为至少一个。
[0038]S104、根据相同类别的参与方的局部网络参数,确定相应参与方的全局网络参数。
[0039]全局网络参数是指属于相同类别的参与方的网络参数的联合训练结果。
[0040]具体的,可以对相同类别的参与方的局部网络参数求平均值,将平均值作为相应参与方的全局网络参数;也可以基于预设联合平均(FedAvg)算法,对相同类别的参与方的局部网参数进行加权平均,将加权平均结果作为相应参与方的全局网络参数。其中,可以根据相同类别的参与方的自身训练样本数量,确定局部网络参数的权重。
[本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.基于联邦学习的模型训练方法,包括:获取联邦学习的各参与方通过自身训练样本进行相同模型训练所得的局部网络参数;根据各所述局部网络参数,确定各所述参与方的训练评价数据;根据各所述训练评价数据,对各所述参与方进行聚类;根据相同类别的参与方的局部网络参数,确定相应参与方的全局网络参数;将各所述参与方的全局网络参数反馈至相应参与方,以供模型训练。2.根据权利要求1所述的方法,其中,所述训练评价数据包括局部密度数据和跟随距离数据。3.根据权利要求2所述的方法,其中,根据各所述局部网络参数,确定各所述参与方的局部密度数据,包括:根据不同参与方的局部网络参数,确定不同参与方之间的参数差异数据;针对各参与方,根据该参与方对应的参数差异数据,确定该参与方的局部密度数据。4.根据权利要求3所述的方法,其中,所述根据不同参与方的局部网络参数,确定不同参与方之间的参数差异数据,包括:根据不同参与方的局部网络参数,确定不同参与方之间的参数距离数据;将邻域截断距离与不同参与方之间的参数距离数据的比值,作为不同参与方之间的参数差异数据。5.根据权利要求4所述的方法,其中,所述方法还包括:根据不同参与方之间的参数距离数据,确定所述邻域截断距离。6.根据权利要求3所述的方法,其中,所述针对各参与方,根据该参与方对应的参数差异数据,确定该参与方的局部密度数据,包括:针对各参与方,将该参与方对应的各参数差异数据映射至指数空间,得到相应密度关联数据;将该参与方的各密度关联数据进行叠加,得到该参与方的所述局部密度数据。7.根据权利要求2所述的方法,其中,根据各所述局部网络参数,确定各所述参与方的跟随距离数据,包括:根据不同参与方的局部网络参数,确定不同参与方之间的参数距离数据;针对所述局部密度数据较小的参与方,将该参与方与其他参与方之间的较小参数距离数据作为该参与方的跟随距离数据;针对所述局部密度数据较大的参与方,将该参与方与其他参与方之间的较大参数距离数据作为该参与方的跟随距离数据。8.根据权利要求2

7任一项所述的方法,其中,所述根据各所述训练评价数据,对各所述参与方进行聚类,包括:根据各参与方的局部密度数据与局部密度阈值的相对大小,以及各参与方的跟随距离数据与跟随距离阈值的相对大小,确定各聚类中心;根据各所述聚类中心,对各所述参与方进行聚类。9.根据权利要求8所述的方法,其特征在于,所述方法还包括:根据不同参与方的局部密度数据,确定所述局部密度阈值;以及,根据不同参与方的跟随距离数据,确定所述跟随距离阈值。
10.基于联邦学习的模型训练装置,包括:局部网络参数获取模块,用于获取联邦学习的各参与方通过自身训练样本进行相同模型训练所得的局部网络参数;训练评价数据获取模块,用于根据各所述局部网络参数,确定各所述参与方的训练评价数据;聚类模块,用于根据各所述训练评价数据,对各所述参与方进行聚类;全局网络参数确定模块,用于根据相同类别的参与方的局部网络参数,确定相应参与方的全局网络参数;全局网络参数反馈模块,用于将各所述参与方的全局网络参数反馈至相应参与方,以供模型训练。11.根据权利要求10所述的装置...

【专利技术属性】
技术研发人员:彭胜波周吉文
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1