当前位置: 首页 > 专利查询>清华大学专利>正文

基于深度强化学习的联邦学习客户端智能选取方法及系统技术方案

技术编号:29491417 阅读:22 留言:0更新日期:2021-07-30 19:03
本发明专利技术公开了一种基于深度强化学习的联邦学习客户端智能选取方法及系统,该方法包括:联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及客户端选择方案从多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给客户端选择智能体,以奖励用于优化更新策略网络;策略网络通过强化学习方法离线训练得到。本发明专利技术可从候选移动边缘设备中选择高质量的设备参与联邦学习,以处理分布式客户端低质量数据问题,以显著提高联邦学习质量。

【技术实现步骤摘要】
基于深度强化学习的联邦学习客户端智能选取方法及系统
本专利技术涉及大规模分布式边缘智能学习系统的性能优化
,尤其涉及一种基于深度强化学习的联邦学习客户端智能选取方法及系统。
技术介绍
移动边缘设备的普及使得边缘产生的数据快速增长,同时也促进了现代人工智能应用的繁荣发展。然而,由于隐私问题和高昂的数据传输成本,传统的在云端收集大量数据进行集中式模型训练的机制变得不太可取。为了在不泄露隐私的前提下充分利用数据资源,一种新的学习范式应运而生,即联邦学习(FederatedLearning,FL),它可以让移动边缘设备在不共享其原始数据的情况下协同训练全局模型。在联邦学习中,分布式设备使用自己的数据在本地训练全局模型,然后将模型更新提交给服务器进行模型聚合,聚合后的模型更新用于更新全局模型,然后返回给每个设备以进行下一轮的迭代。全局模型的训练过程便可以通过这种方式以分布式和隐私保护的方式迭代完成。联邦学习尽管在隐私保护方面具有巨大的潜力,但在实现高性能学习质量方面仍然面临着技术挑战。与在数据中心进行训练时数据充足且资源不受限制不同,参与联邦学习的分布式设备通常在硬件条件和数据资源上都受到限制,且存在异质性,这会极大地影响学习性能。例如,由于传感器的缺陷和功率的限制,移动设备难免会收集一些错误标注的低质量数据,导致设备本地学习质量参差不齐。然而,不加区分地聚合低质量的模型更新会反向恶化全局模型的质量。因此,客户端选择,尤其是从候选客户端中选择合适的移动设备参与分布式学习,成为高质量联邦学习的关键。最近,现有的一些工作提出了一些联邦学习的客户端选择方案。例如,Nishio等人提出了一种资源感知的选择方案,根据客户端的计算和通信资源选择客户端,使得能够在有限的资源约束下最大限度地增加参与者的数量,加速联邦学习性能的提升。Mohammed等人通过选择模型测试精度较高的候选客户端参与联邦学习的训练过程,提高了联邦学习的学习精度。Huang等人提出了一种有公平性保证的客户端选择方案,可以在联邦学习的训练效率和公平性之间取得良好的权衡。为了减少联邦学习训练的延迟,Xia等人提出了一种基于多臂老虎机的在线客户端调度方案,可以显著缩短模型训练的时间开销。Wang等人提出利用强化学习智能选择联邦学习的参与客户端,以克服客户端非独立同分布的数据对学习性能的负面影响,加快模型训练过程。但是,现有的客户端选择方案并没有充分考虑客户端的数据质量对联邦学习性能的影响,如何综合考虑客户端的数据数量、数据质量、计算资源等因素对模型训练质量的影响,为联邦学习智能地选取高质量的参与节点,仍需进一步探索和研究。
技术实现思路
本专利技术提供了一种基于深度强化学习的联邦学习客户端智能选取方法(以下简称AUCTION)及系统,用以解决现有的客户端选择方案并没有充分考虑客户端的数据数量、数据质量、计算资源等因素对联邦学习性能的影响的技术问题。为解决上述技术问题,本专利技术提出的技术方案为:一种基于深度强化学习的联邦学习客户端智能选取方法,应用于联邦服务市场框架,联邦服务市场框架包括一个以一定的预算招募客户端完成联邦学习任务的联邦平台和多个愿意参与联邦学习并向联邦平台提交联邦学习任务的候选客户端;包括以下步骤:联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及客户端选择方案从多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给客户端选择智能体,以奖励用于优化更新策略网络;策略网络通过强化学习方法离线训练得到。作为本专利技术的方法的进一步改进:客户端选择智能体,为基于编码器-解码器结构的策略网络,编码器将客户端状态映射为中间向量表示,解码器根据中间向量表示生成客户端选择方案;客户端状态包括数据大小、数据质量和价格。优选地,策略网络的强化学习模型,包括状态、动作、奖励和策略:状态:状态s={x1,x2,…,xn}包含给定联邦学习任务所有候选客户端的特征,每个客户端Ci的特征xi是一个三维向量,用xi={qi,di,bi}表示,其中qi和di分别是客户端Ci的数据质量和用于训练的样本数量,bi是客户端Ci完成该学习任务的价格;动作:采用顺序动作,即客户端选择代理通过采取一系列的动作一一做出客户端选择决策;一个单独的动作只从一组最多N个候选客户端中选出一个客户端;奖励:将执行客户端选择操作后从联邦服务市场观察到的奖励r作为训练后损失函数值的减少率,即:其中,F(w)是学习任务测试数据集上的初始全局损失函数值,F(w*)是经过选定客户端的多轮协同训练后达到的测试损失函数值;策略:将客户端选择的一个可行动作a={a1,…,ai,…}定义为候选客户端的一个子集,其中ai∈{C1,C2,…,Cn}且策略网络为一个随机的客户端选择策略π(a|s,B)用于在给定状态s和学习预算B的情况下选择一个可行动作a;训练策略网络的目标是最大化累计奖励。优选地,最大化累计奖励,表示为:其中r(a|s)是在状态s执行动作a后的奖励;使用REINFORCE算法来优化J,使用梯度下降来不断优化参数θ:其中b(s)代表一个独立于a的基准函数用于加速训练过程;参数θ是编码器和解码器可学习参数的并集。优选地,编码器包括:客户端嵌入层首先通过线性投影把三维输入特征xi转化为初始的dh维嵌入向量其中Wx和bx为可学习参数;然后,嵌入向量会经过L个注意力层更新,其中,每一个注意力层l∈{1,2,…,L}输出嵌入向量每个注意力层包含一个MHA层和一个FF层,每层后面都添加了一个跳跃连接和批归一化。优选地,解码器包括:基于编码器输出的嵌入向量和解码器在时间t′<t时间输出的客户端选择结果,解码器在每个时间点t输出一个选中的客户端at直到学习预算用尽;解码器的网络包含一个多头注意力层和一个单头注意力层。本专利技术还提供一种计算机系统,包括存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述任一方法的步骤。本专利技术具有以下有益效果:1、本专利技术的基于深度强化学习的联邦学习客户端智能选取方法及系统,可以利用客户端当前的学习质量相关的监测信息和历史的模型训练记录,自动学习客户端选择策略,以在联邦学习服务市场中实时地做出客户端选择决策。2、在优选方案中,本专利技术利用深度强化学习技术,将客户端选择策略编码为神经网络,将每个客户端的数据大小、数据质量和学习价格作为输入,并输出在学习预算内选择的客户端子集,然后策略网络观察所选客户端的联邦学习性能,再利用策略梯度算法逐步改进其客户端选择策略。3、本专利技术的基于深度强化学习的联邦学习客户端智能选取方法及系统,为了能够适应联邦服务市场中客户端数量的动态变化并减小强化学习算法的搜索空间,本专利技术设计了基于编码器-解码器本文档来自技高网
...

【技术保护点】
1.一种基于深度强化学习的联邦学习客户端智能选取方法,应用于联邦服务市场框架,所述联邦服务市场框架包括一个以一定的预算招募客户端完成联邦学习任务的联邦平台和多个愿意参与联邦学习并向联邦平台提交联邦学习任务的候选客户端;其特征在于,包括以下步骤:/n联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及所述客户端选择方案从所述多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给所述客户端选择智能体,以奖励用于优化更新策略网络;所述策略网络通过强化学习方法离线训练得到。/n

【技术特征摘要】
1.一种基于深度强化学习的联邦学习客户端智能选取方法,应用于联邦服务市场框架,所述联邦服务市场框架包括一个以一定的预算招募客户端完成联邦学习任务的联邦平台和多个愿意参与联邦学习并向联邦平台提交联邦学习任务的候选客户端;其特征在于,包括以下步骤:
联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及所述客户端选择方案从所述多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给所述客户端选择智能体,以奖励用于优化更新策略网络;所述策略网络通过强化学习方法离线训练得到。


2.根据权利要求1所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述客户端选择智能体,为基于编码器-解码器结构的策略网络,编码器将客户端状态映射为中间向量表示,解码器根据所述中间向量表示生成客户端选择方案;所述客户端状态包括数据大小、数据质量和价格。


3.根据权利要求2所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述策略网络的强化学习模型,包括状态、动作、奖励和策略:
状态:状态s={x1,x2,…,xn}包含给定联邦学习任务所有候选客户端的特征,每个客户端Ci的特征xi是一个三维向量,用xi={qi,di,bi}表示,其中qi和di分别是客户端Ci的数据质量和用于训练的样本数量,bi是客户端Ci完成该学习任务的价格;
动作:采用顺序动作,即客户端选择代理通过采取一系列的动作一一做出客户端选择决策;一个单独的动作只从一组最多N个候选客户端中选出一个客户端;
奖励:将执行客户端选择操作后从联邦服务市场观察到的奖励r作为训练后损失函数值的减少率,即:



其中,F(w)是学习任务测试数据集上的初始全局损失函数值,F(w*)是经...

【专利技术属性】
技术研发人员:张尧学邓永恒吕丰任炬
申请(专利权)人:清华大学中南大学
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1