分布式模型训练的负载均衡方法和装置制造方法及图纸

技术编号:36430223 阅读:25 留言:0更新日期:2023-01-20 22:42
本公开提供了分布式模型训练的负载均衡方法和装置,涉及人工智能领域,尤其涉及深度学习领域。具体实现方案为:统计分布式模型的各个计算节点上的负载量;根据各个计算节点上的负载量之间的比值确定所述分布式模型是否负载均衡;若负载不均衡且有空闲的计算节点,则为所述分布式模型增加与负载量最大的目标计算节点的模型参数相同的同类计算节点;在各个计算节点上进行梯度反向计算后,将所述同类计算节点上的网络参数的梯度与所述目标计算节点上的网络参数的梯度进行同步。该实施方式实现了通过增加或减少计算节点的数量,有效地均衡各个计算节点的负载,充分利用计算和存储资源。资源。资源。

【技术实现步骤摘要】
分布式模型训练的负载均衡方法和装置


[0001]本公开涉及人工智能领域,尤其涉及深度学习领域,具体为一种分布式模型训练的负载均衡方法和装置。

技术介绍

[0002]在近年来的深度学习模型训练中,使用更多的训练数据和更大的模型趋势未改。更大的模型和数据量意味着更多的计算量和存储需求,也意味着更久的训练时间。那么如何将计算和存储需求分布到多个训练设备来提升训练速度,是关键问题。
[0003]数据并行(data parallelism)是解决上述问题的一种并行策略,在数据并行的模型训练中,训练任务被切分到多个进程(设备)上,每个进程维护相同的模型参数和相同的计算任务,但是处理不同的数据(batch data)。通过这种方式,同一全局数据(global batch)下的数据和计算被切分到了不同的进程,从而减轻了单个设备上的计算和存储压力。
[0004]分布式模型训练(例如,MoE(Mixure

of

Experts,混合专家模型))是实现超大规模模型训练的技术路径之一。该模型的思想是训练多个神经网络(分布在多个计算节点中),每个计算节点训练数据集的不同部分。由于每个计算节点的输入数据量不同,计算时间不均匀,造成严重的负载不平衡:一方面,因为单个计算节点可能处理过量的数据,导致内存超出限制;在另一方面,同步通信必须等待最慢计算节点,导致计算利用率下降,类似于“木桶效应”。

技术实现思路

[0005]本公开提供了一种分布式模型训练的负载均衡方法、装置、设备、存储介质以及计算机程序产品。
[0006]根据本公开的第一方面,提供了一种分布式模型训练的负载均衡方法,包括:统计分布式模型的各个计算节点上的负载量;根据各个计算节点上的负载量之间的比值确定所述分布式模型是否负载均衡;若负载不均衡且有空闲的计算节点,则为所述分布式模型增加与负载量最大的目标计算节点的模型参数相同的同类计算节点;在各个计算节点上进行梯度反向计算后,将所述同类计算节点上的网络参数的梯度与所述目标计算节点上的网络参数的梯度进行同步。
[0007]根据本公开的第二方面,提供了一种分布式模型训练的负载均衡装置,包括:统计单元,被配置成统计分布式模型的各个计算节点上的负载量;确定单元,被配置成根据各个计算节点上的负载量之间的比值确定所述分布式模型是否负载均衡;增加单元,被配置成若负载不均衡且有空闲的计算节点,则为所述分布式模型增加与负载量最大的目标计算节点的模型参数相同的同类计算节点;同步单元,被配置成在各个计算节点上进行梯度反向计算后,将所述同类计算节点上的网络参数的梯度与所述目标计算节点上的网络参数的梯度进行同步。
[0008]根据本公开的第三方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行第一方面所述的方法。
[0009]根据本公开的第四方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行第一方面所述的方法。
[0010]根据本公开的第五方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现第一方面所述的方法。
[0011]本公开的实施例提供的分布式模型训练的负载均衡方法和装置,通过增加或减少计算节点的数量,有效地均衡各个计算节点的负载,充分利用计算和存储资源。
[0012]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0013]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0014]图1a

1d是本公开的一个实施例可以应用于其中的示例性系统架构的示意图;
[0015]图2是根据本公开的分布式模型训练的负载均衡方法的一个实施例的流程图;
[0016]图3是根据本公开的分布式模型训练的负载均衡方法的一个应用场景的示意图;
[0017]图4是根据本公开的分布式模型训练的负载均衡方法的又一个实施例的流程图;
[0018]图5是根据本公开的分布式模型训练的负载均衡方法的又一个应用场景的示意图;
[0019]图6是根据本公开的分布式模型训练的负载均衡装置的一个实施例的结构示意图;
[0020]图7是适于用来实现本公开的实施例的电子设备的计算机系统的结构示意图。
具体实施方式
[0021]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0022]图1示出了可以应用本公开的分布式模型训练的负载均衡方法或分布式模型训练的负载均衡装置的实施例的示例性系统架构100。
[0023]如图1所示,系统架构100可以包括服务器(Server)和多个客户端(Worker,即计算节点)。
[0024]系统架构100是分布式训练领域普遍采用的编程架构,主要解决以下两类问题:
[0025]1、模型参数过大:单机内存空间不足,需要采用分布式存储。
[0026]2、训练数据过多:单机训练太慢,需要加大训练节点,来提高并发训练速度。
[0027]如图1所示,系统架构100主要包含Server和Worker两个部分,其中Server负责参数的存储和更新,而Worker负责训练。简单来说,基于该系统架构进行模型训练的基本思
路:当训练数据过多,一个Worker训练太慢时,可以引入多个Worker同时训练,这时Worker之间需要同步模型参数。直观想法是,引入一个Server,Server充当Worker间参数交换的媒介。当模型参数过大以至于单机存储空间不足时或Worker过多导致一个Server是瓶颈时,就需要引入多个Server。
[0028]模型训练的具体流程如下:
[0029]1、将训练数据(样本集)均匀的分配给不同的Worker。
[0030]2、将模型参数分片,存储在不同的Server上。
[0031]3、Worker端:读取一个minibatch训练数据,从Server端拉取最新的参数,计算梯度,并根据分片上传给不同的Server。
[0032]4、Server端:接收Worker端上传的梯度,根据优化算法更新参数。根据Server端每次参数更新是否需要等待所有Worker端的梯度,分为同步训练和异步训练两种机制。
[0033]MoE模型可采用系统架构100来训练。在MoE模型中,数据经过骨干网络本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种分布式模型训练的负载均衡方法,包括:统计分布式模型的各个计算节点上的负载量;根据各个计算节点上的负载量之间的比值确定所述分布式模型是否负载均衡;若负载不均衡且有空闲的计算节点,则为所述分布式模型增加与负载量最大的目标计算节点的模型参数相同的同类计算节点;在各个计算节点上进行梯度反向计算后,将所述同类计算节点上的网络参数的梯度与所述目标计算节点上的网络参数的梯度进行同步。2.根据权利要求1所述的方法,其中,所述方法还包括:若负载不均衡且没有空闲的计算节点,则将负载量最小的至少2个计算节点合并成1个计算节点。3.根据权利要求1所述的方法,其中,所述分布式模型为混合专家模型,每个计算节点包括:骨干网络、门控网络和专家网络。4.根据权利要求2所述的方法,其中,分布式模型为混合专家模型,每个计算节点包括:骨干网络、门控网络和专家网络;以及所述将负载量最小的至少2个计算节点合并成1个计算节点,包括:将负载量最小的至少2个计算节点的骨干网络和门控网络的参数分别合并后作为公共骨干网络和公共门控网络;将所述公共门控网络的输出结果分别作为所述负载量最小的至少2个计算节点的专家网络的输入。5.根据权利要求3所述的方法,其中,所述统计分布式模型的各个计算节点上的负载量,包括:统计分布式模型的各个计算节点上的专家网络的负载量。6.根据权利要求3所述的方法,其中,所述将所述同类计算节点上的网络参数的梯度与所述目标计算节点上的网络参数的梯度进行同步,包括:将所述同类计算节点上的专家网络的参数的梯度与所述目标计算节点上的专家网络的参数的梯度进行同步。7.根据权利要求2所述的方法,其中,所述方法还包括:在将负载量最小的至少2个计算节点合并成1个计算节点后空出的计算节点中,加载与负载量最大的目标计算节点的模型参数相同的模型。8.一种分布式模型训练的负载均衡装置,包括:统计单元,被配置成统计分布式模型的各个计算节点上的负载量;确定单元,被配置成根据各个计算节点上的负载量之间的比值确定所述分布式模型是否负载均衡;增加单元,被配置成若负载不均衡且有空闲的计算节点,则为所述分布式模型增加与负载量最大的目标计算节点的模型参数...

【专利技术属性】
技术研发人员:沈亮吴志华于佃海
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1