一种模型训练方法和相关装置制造方法及图纸

技术编号:35057930 阅读:14 留言:0更新日期:2022-09-28 11:07
本申请实施例公开了一种模型训练方法和相关装置,至少涉及人工智能模型中的机器学习,确定待训练模型包括的m个张量与n个并行进程之间的对应关系,m个张量包括在n个张量集合中,每个张量集合包括m个张量中的部分张量,n个张量集合与n个并行进程的对应关系为一一对应关系,使得每个并行进程只维护部分张量。目标并行进程与目标张量具有对应关系,在进行迭代的过程中,目标并行进程仅基于目标张量更新待训练模型的参数,根据更新后的参数训练待训练模型。不仅降低了创建临时缓存的数量,还降低了临时缓存的频繁创建和释放产生的内存碎片。由此,通过每个并行进程至维护部分张量,降低了激活层内存、临时缓存等,进而降低了模型的显存占用。的显存占用。的显存占用。

【技术实现步骤摘要】
一种模型训练方法和相关装置


[0001]本申请涉及计算机
,特别是涉及一种模型训练方法和相关装置。

技术介绍

[0002]随着人工智能的发展,模型逐渐朝着更大量级发展,如量级越大的自然语言模型的准确率更高,例如,生成型已训练变换模型3(Generative Pre

trained Transformer 3,GPT

3)的模型参数已达到175B。
[0003]在预训练阶段,较大的模型需要占用的显存较多。

技术实现思路

[0004]为了解决上述技术问题,本申请提供了一种模型训练方法和相关装置,用于降低训练模型的显存占用。
[0005]本申请实施例公开了如下技术方案:
[0006]一方面,本申请实施例提供一种模型训练方法,所述方法包括:
[0007]确定待训练模型包括的m个张量与n个并行进程之间的对应关系;其中,所述m个张量包括在n个张量集合中,每个张量集合包括所述m个张量中的部分张量,所述n个张量集合与所述n个并行进程的对应关系为一一对应关系,所述张量为所述待训练模型包括的多层网络的输入和输出,m和n为大于1的整数;
[0008]针对所述n个并行进程中的目标并行进程,基于与所述目标并行进程具有对应关系的目标张量集合更新所述待训练模型的参数;
[0009]根据更新后的参数训练所述待训练模型。
[0010]另一方面,本申请实施例提供一种模型训练装置,所述装置包括:确定单元、更新单元和训练单元;
[0011]所述确定单元,用于待训练模型包括的m个张量与n个并行进程之间的对应关系;其中,所述m个张量包括在n个张量集合中,每个张量集合包括所述多个张量中的部分张量,所述n个张量集合与所述n个并行进程的对应关系为一一对应关系,所述张量为所述待训练模型包括的多层网络的输入和输出,m和n为大于1的整数;
[0012]所述更新单元,用于针对所述n个并行进程中的目标并行进程,基于与所述目标并行进程具有对应关系的目标张量集合更新所述待训练模型的参数;
[0013]所述训练单元,用于根据更新后的参数训练所述待训练模型。
[0014]另一方面,本申请实施例提供一种计算机设备,所述设备包括处理器以及存储器:
[0015]所述存储器用于存储程序代码,并将所述程序代码传输给所述处理器;
[0016]所述处理器用于根据所述程序代码中的指令执行上述方面所述的方法。
[0017]另一方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质用于存储计算机程序,所述计算机程序用于执行上述方面所述的方法。
[0018]另一方面,本申请实施例提供了一种计算机程序产品或计算机程序,该计算机程
序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述方面所述的方法。
[0019]由上述技术方案可以看出,确定待训练模型包括的m个张量与n个并行进程之间的对应关系,m个张量包括在n个张量集合中,每个张量集合包括m个张量中的部分张量,n个张量集合与n个并行进程的对应关系为一一对应关系,使得每个并行进程只维护部分张量。其中,张量是待训练模型包括的多层网络的输入和输出,相比于每个进程保存待训练模型产生的所有张量,每个进程仅保存部分所维护的张量,降低了激活层内存。以多个并行进程中的目标并行进程为例,目标并行进程与目标张量集合具有对应关系,在进行迭代的过程中,目标并行进程仅基于目标张量集合更新待训练模型的参数,根据更新后的参数训练待训练模型。相比于为所有张量均创建临时缓存,仅为具有对应关系的张量集合中的张量创建临时缓存,不仅降低了创建临时缓存的数量,还降低了临时缓存的频繁创建和释放产生的内存碎片。由此,在训练模型的过程中,通过每个并行进程只维护部分张量,降低了激活层内存、临时缓存等,进而降低了模型的显存占用。
附图说明
[0020]为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0021]图1为本申请实施例提供的模型训练方法的应用场景示意图;
[0022]图2为本申请实施例提供的一种模型训练方法的流程示意图;
[0023]图3为本申请实施例提供一种多个张量与多个并行进程间对应关系的示意图;
[0024]图4为本申请实施例提供又一种多个张量与多个并行进程间对应关系的示意图;
[0025]图5为本申请实施例提供的一种获取规约梯度的示意图;
[0026]图6为本申请实施例提供的一种获取待训练模型更新后的参数的示意图;
[0027]图7为本申请实施例提供的一种兼容模型并行的示意图;
[0028]图8为本申请实施例提供的一种模型训练平台的框架示意图;
[0029]图9为本申请实施例提供的一种模型训练装置的示意图;
[0030]图10为本申请实施例提供的服务器的结构示意图;
[0031]图11为本申请实施例提供的终端设备的结构示意图。
具体实施方式
[0032]下面结合附图,对本申请的实施例进行描述。
[0033]在预训练阶段,模型的显存占用主要包括两个部分,一部分为优化器的状态、梯度以及模型参数等模型状态(model states)产生的显存占用,另一部分为其他内存产生的显存占用。例如,(1)激活层内存(activation memory),用于保存的算法模型中间层的结果输出,以便后续的反向计算。(2)临时缓存(buffer),用于保存前向计算过程中每一层计算层内部计算产生的中间结果,如模型每层网络的输入和输出。(3)临时buffer的频繁创建和释
放产生的内存碎片,也会占用一些显存,导致明明还有显存但是显存申请失败。
[0034]下面以自然语言处理(Natural Language Processing,NLP)领域常用的自适应矩估计(Adaptive momentum,adam)优化器结合混合精度训练为例,分析model states的显存占用。其中,混合精度是指训练时在模型中同时使用单精度浮点数(float point)32和半精度浮点数(float point)16类型,从而加快运行速度,减少内存使用的一种训练方法。下面具体说明。
[0035]训练过程中,各层权重被保存为FP32类型,每次迭代时,制作这些权重的FP16类型副本并使用它们进行前向计算和反向计算,更新时将梯度再转换为FP32类型并用于更新FP32类型的权重。使用这种方法可以解决权重相较于其更新值过大的问题。但是,adam优化器保存了每个参数的FP32类型的主权重(mast本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:确定待训练模型包括的m个张量与n个并行进程之间的对应关系;其中,所述m个张量包括在n个张量集合中,每个张量集合包括所述m个张量中的部分张量,所述n个张量集合与所述n个并行进程的对应关系为一一对应关系,所述张量为所述待训练模型包括的多层网络的输入和输出,m和n为大于1的整数;针对所述n个并行进程中的目标并行进程,基于与所述目标并行进程具有对应关系的目标张量集合更新所述待训练模型的参数;根据更新后的参数训练所述待训练模型。2.根据权利要求1所述的方法,其特征在于,所述基于与所述目标并行进程具有对应关系的目标张量集合更新所述待训练模型的参数,包括:基于与所述目标并行进程具有对应关系的目标张量集合进行第i次前向传播和第i次反向传播;根据所述多个并行进程在所述第i次反向传播过程中得到的针对所述目标张量集合所包括的目标张量的多个梯度,确定针对所述目标张量的规约梯度;根据所述规约梯度更新所述待训练模型的参数。3.根据权利要求2所述的方法,其特征在于,所述根据所述规约梯度更新所述待训练模型的参数,包括:调用自适应矩估计优化器更新所述规约梯度对应的目标参数;根据所述目标参数获取所述待训练模型更新后的参数。4.根据权利要求2所述的方法,其特征在于,所述根据所述多个并行进程在所述第i次反向传播过程中得到的针对所述目标张量集合所包括的目标张量的多个梯度,确定针对所述目标张量的规约梯度,包括:获取所述多个并行进程在所述第i次反向传播过程中得到的针对所述目标张量集合所包括的目标张量的多个梯度;根据所述多个梯度的梯度总和与梯度数量,确定针对所述目标张量的规约梯度。5.根据权利要求1所述的方法,其特征在于,所述确定待训练模型包括的m个张量与n个并行进程之间...

【专利技术属性】
技术研发人员:弓静
申请(专利权)人:腾讯科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1