模型训练方法、装置、电子设备和存储介质制造方法及图纸

技术编号:38223998 阅读:9 留言:0更新日期:2023-07-25 17:55
本申请公开了一种模型训练方法、装置、电子设备和存储介质;本申请可以获取预训练模型和目标模型;获取无标注样本组和标注样本组,以及无标注概率和标注概率;基于无标注概率和标注概率,分别对无标注样本组和标注样本组进行线性组合,得到无标注样本组对应的无标注扩增样本和标注样本组对应的标注扩增样本;基于无标注扩增样本和预训练模型,对目标模型进行初步训练;基于标注扩增样本和预训练模型,对初步训练后的目标模型进行再次训练,得到训练好的目标模型。在本申请中,通过对无标注样本和标注样本进行扩增并结合预训练模型来训练目标模型,可以提升目标模型的泛化性和鲁棒性。由此,本方案可以提升目标模型的性能。本方案可以提升目标模型的性能。本方案可以提升目标模型的性能。

【技术实现步骤摘要】
模型训练方法、装置、电子设备和存储介质


[0001]本申请涉及人工智能
,具体涉及一种模型训练方法、装置、电子 设备和存储介质。

技术介绍

[0002]为了从数据量越来越大且信息冗余度高的数据集中提取出有效信息,需要 的深度学习模型的规模也越来越大,甚至可能需要将多个深度学习模型集成。 但是大规模的预训练模型的推断速度慢、对部署的资源要求高(例如,容量更 大的内存、性能更高的显存等),使得大规模的预训练模型不利于部署在具体 的任务中。
[0003]因此,目前提出了知识蒸馏,即一种模型压缩的方法,知识蒸馏可以将大 规模的预训练模型学习到的知识迁移到目标模型中,使得目标模型可以完成大 规模的预训练模型原本能实现的任务。
[0004]然而,目前知识迁移后的目标模型的性能较差。

技术实现思路

[0005]本申请提供一种模型训练方法、装置、电子设备和存储介质,可以提升目 标模型的性能。
[0006]本申请提供一种模型训练方法,包括:
[0007]获取预训练模型和目标模型;
[0008]获取无标注样本组和标注样本组,以及无标注概率和标注概率,无标注样 本组中包括至少两个无标注样本,标注样本组中包括至少两个标注样本;
[0009]基于无标注概率和标注概率,分别对无标注样本组和标注样本组进行线性 组合,得到无标注样本组对应的无标注扩增样本和标注样本组对应的标注扩增 样本;
[0010]基于无标注扩增样本和预训练模型,对目标模型进行初步训练,得到初步 训练后的目标模型;
[0011]基于标注扩增样本和预训练模型,对初步训练后的目标模型进行再次训 练,得到训练好的目标模型。
[0012]本申请还提供一种模型训练方法装置,包括:
[0013]第一获取单元,用于获取预训练模型和目标模型;
[0014]第二获取单元,用于获取无标注样本组和标注样本组,以及无标注概率和 标注概率,无标注样本组中包括至少两个无标注样本,标注样本组中包括至少 两个标注样本;
[0015]扩增单元,用于基于无标注概率和标注概率,分别对无标注样本组和标注 样本组进行线性组合,得到无标注样本组对应的无标注扩增样本和标注样本组 对应的标注扩增样本;
[0016]初步训练单元,用于基于无标注扩增样本和预训练模型,对目标模型进行 初步训练,得到初步训练后的目标模型;
[0017]再次训练单元,用于基于标注扩增样本和预训练模型,对初步训练后的目 标模型进行再次训练,得到训练好的目标模型。
[0018]在一些实施例中,预训练模型包括多个网络模块,目标模型包括多个网络 层,网络模块与网络层一一对应,标注扩增样本包括标签;再次训练单元具体 用于:
[0019]将标注扩增样本输入预训练模型,得到目标网络模块的第一输出结果,目 标网络模块为多个网络模块中的任一网络模块;
[0020]将标注扩增样本输入初步训练后的目标模型,得到目标网络层的第二输出 结果,目标网络层为多个网络层中与目标网络模块对应的网络层;
[0021]基于第一输出结果、第二输出结果以及标签,对目标模型的参数进行再次 更新。
[0022]在一些实施例中,再次训练单元具体用于:
[0023]根据第一输出结果和第二输出结果,对目标网络层进行损失计算,得到网 络层损失值;
[0024]基于网络层损失值,对目标模型中输入层至目标网络层的参数进行更新;
[0025]根据第二输出结果和标签,对目标模型进行损失计算,得到输出损失值;
[0026]基于输出损失值,对目标模型的参数进行再次更新。
[0027]在一些实施例中,标签包括第一标签和第二标签,再次训练单元具体用于:
[0028]根据第一标签和第二输出结果,对目标模型进行损失计算,得到第一输出 损失值;
[0029]根据第二标签和第二输出结果,对目标模型进行损失计算,得到第二输出 损失值;
[0030]对第一输出损失值和第二输出损失值进行损失融合,得到输出损失值。
[0031]在一些实施例中,标注概率包括第一标注概率和第二标注概率,再次训练 单元具体用于:
[0032]采用第一标注概率,对第一输出损失值进行加权处理,得到第一加权损失 值;
[0033]采用第二标注概率,对第二输出损失值进行加权处理,得到第二加权损失 值;
[0034]基于第一加权损失值和第二加权损失值,得到输出损失值。
[0035]在一些实施例中,网络层损失值包括特征提取损失值,第一输出结果包括 第一特征,第二输出结果包括第二特征;再次训练单元具体用于:
[0036]当第一特征与第二特征的维度不一致时,对第一特征进行空间转换,得到 转换后的第一特征;
[0037]计算第二特征的特征值与转换后的第一特征的特征值之间的特征二范数;
[0038]基于特征二范数,确定特征提取损失值。
[0039]在一些实施例中,网络层损失值包括注意力损失值,第一输出结果包括第 一注意力概率分布,第二输出结果包括第二注意力概率分布;再次训练单元具 体用于:
[0040]计算第一注意力概率分布和第二注意力概率分布之间的注意力相对熵;
[0041]基于注意力相对熵,确定注意力损失值。
[0042]在一些实施例中,网络层损失值包括分类损失值,第一输出结果包括第一 分类概率分布,第二输出结果包括第二分类概率分布;再次训练单元具体用于:
[0043]计算第一分类概率分布和第二分类概率分布之间的分类相对熵;
[0044]基于分类相对熵,确定分类损失值。
[0045]在一些实施例中,预训练模型包括多个网络模块,目标模型包括多个网络 层,网络模块与网络层一一对应;初步训练单元具体用于:
[0046]将无标注扩增样本输入预训练模型,得到目标网络模块的第三输出结果, 目标网络模块为多个网络模块中的任一网络模块;
[0047]将无标注扩增样本输入目标模型,得到目标网络层的第四输出结果,目标 网络层为多个网络层中与目标网络模块对应的网络层;
[0048]基于第三输出结果、第四输出结果,对目标模型的参数进行初步更新。
[0049]在一些实施例中,第一获取单元具体用于:
[0050]获取预训练模型,以及获取待处理的目标模型,待处理的目标模型包括多 个待处理的网络层;
[0051]对预训练模型进行分块处理,得到多个网络模块,网络模块与待处理的网 络层一一对应;
[0052]获取网络模块中第n层网络层的参数,n为正整数;
[0053]将第n层网络层的参数作为对应的待处理的网络层的初始参数,得到目标 模型。
[0054]在一些实施例中,第一获取单元具体用于:
[0055]当第n层网络层与待处理的网络层的维度不一致时,对待处理的网络层的 参数进行随机初始化处理,得到随机初始化后的网本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,包括:获取预训练模型和目标模型;获取无标注样本组和标注样本组,以及无标注概率和标注概率,所述无标注样本组中包括至少两个无标注样本,所述标注样本组中包括至少两个标注样本;基于所述无标注概率和所述标注概率,分别对所述无标注样本组和所述标注样本组进行线性组合,得到所述无标注样本组对应的无标注扩增样本和所述标注样本组对应的标注扩增样本;基于所述无标注扩增样本和所述预训练模型,对所述目标模型进行初步训练,得到初步训练后的目标模型;基于所述标注扩增样本和所述预训练模型,对所述初步训练后的目标模型进行再次训练,得到训练好的目标模型。2.如权利要求1所述的模型训练方法,其特征在于,所述预训练模型包括多个网络模块,所述目标模型包括多个网络层,所述网络模块与所述网络层一一对应,所述标注扩增样本包括标签;所述基于所述标注扩增样本和所述预训练模型,对所述初步训练后的目标模型进行再次训练,包括:将所述标注扩增样本输入所述预训练模型,得到目标网络模块的第一输出结果,所述目标网络模块为所述多个网络模块中的任一网络模块;将所述标注扩增样本输入所述初步训练后的目标模型,得到目标网络层的第二输出结果,所述目标网络层为所述多个网络层中与所述目标网络模块对应的网络层;基于所述第一输出结果、所述第二输出结果以及所述标签,对所述目标模型的参数进行再次更新。3.如权利要求2所述的模型训练方法,其特征在于,所述基于所述第一输出结果、所述第二输出结果以及所述标签,对所述目标模型的参数进行再次更新,包括:根据所述第一输出结果和所述第二输出结果,对所述目标网络层进行损失计算,得到网络层损失值;基于所述网络层损失值,对所述目标模型中输入层至所述目标网络层的参数进行更新;根据所述第二输出结果和所述标签,对所述目标模型进行损失计算,得到输出损失值;基于所述输出损失值,对所述目标模型的参数进行再次更新。4.如权利要求3所述的模型训练方法,其特征在于,所述标签包括第一标签和第二标签,所述根据所述第二输出结果和所述标签,对所述目标模型进行损失计算,得到输出损失值,包括:根据所述第一标签和所述第二输出结果,对所述目标模型进行损失计算,得到第一输出损失值;根据所述第二标签和所述第二输出结果,对所述目标模型进行损失计算,得到第二输出损失值;对所述第一输出损失值和所述第二输出损失值进行损失融合,得到输出损失值。5.如权利要求4所述的模型训练方法,其特征在于,所述标注概率包括第一标注概率和
第二标注概率,所述对所述第一输出损失值和所述第二输出损失值进行损失融合,得到输出损失值,包括:采用所述第一标注概率,对所述第一输出损失值进行加权处理,得到第一加权损失值;采用所述第二标注概率,对所述第二输出损失值进行加权处理,得到第二加权损失值;基于所述第一加权损失值和所述第二加权损失值,得到输出损失值。6.如权利要求3所述的模型训练方法,其特征在于,所述网络层损失值包括特征提取损失值,所述第一输出结果包括第一特征,所述第二输出结果包括第二特征;所述根据所述第一输出结果和所述第二输出结果,对所述目标网络层进行损失计算,得到网络层损失值,包括:当所述第一特征与所述第二特征的维度不一致时,对所述第一特征进行空间转换,得到转换后的第一特征;计算所述第二特征的特征值与所述转换后的第一特征的特征值之间的特征二范数;基于所述特征二范数,确定所述特征提取损失值。7.如权利要求3所述的模型训练方法,其特征在于,所述网络层损失值包括注意力损失值,所述第一输出结果包括第一注意力概率分布,所述第二输出结果包括第二注意力概率分布;所述根据所述第一输出结果和所述第二输出结果,对所述目标网络层进行损失计算,得到网络层损失值,包括:计算所述第一注意力概率分布和所述第二注意力概率分布之间的注意力相对熵;基于所述注意力相对熵,确定所述注意力损失值。8.如权利要求3所述的模型训练方法,其特征在于,所述网络层损失值包括分类损失值,所述第一输出结果包括...

【专利技术属性】
技术研发人员:周洁田乐周霄
申请(专利权)人:腾讯科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1