模型压缩方法、系统、电子设备及存储介质技术方案

技术编号:35330998 阅读:33 留言:0更新日期:2022-10-26 11:48
本发明专利技术实施例涉及机器学习技术领域,公开了一种模型压缩方法、系统、电子设备及存储介质。模型压缩方法包括:提供训练好的N种类型的复杂模型;对N种类型的复杂模型进行融合,得到训练好的教师模型;基于训练样本、教师模型以及学生模型的损失函数,对学生模型进行训练;学生模型的损失函数由第一损失函数和第二损失函数融合得到,第一损失函数用于计算学生模型的预测值与真实值的损失,第二损失函数用于计算学生模型的logit值与教师模型的logit值的损失。本申请实施例提供的技术方案,可以提高训练得到的学生模型的预测精度。高训练得到的学生模型的预测精度。高训练得到的学生模型的预测精度。

【技术实现步骤摘要】
模型压缩方法、系统、电子设备及存储介质


[0001]本专利技术实施例涉及机器学习
,特别涉及模型压缩方法、系统、电子设备及存储介质。

技术介绍

[0002]文本相似度匹配应用广泛,比如在信息检索中,信息检索系统为了能召回更多与检索词语相似的结果,可以用相似度来识别相似的词语,以此提高召回率。另外,在自动问答中,可以使用自然语言交互,相似度在这里可以用来计算用户以自然语言的提问问句与语料库中问题的匹配程度,那么匹配度最高的那个问题对应的答案将作为响应。
[0003]近年来BERT模型的出现,刷新了文本分类、文本相似度、机器翻译等多个自然语言处理任务的指标,很多人工智能公司也在逐渐将BERT模型应用到实际的工程项目中,虽然BERT的效果较好,但是由于模型太大,不仅对硬件设备的性能要求较高,而且对数据的处理时间会较长。进而,又出现了基于知识蒸馏方式得到一个轻量级模型,以克服模型太大导致的对硬件设备的性能要求较高且对数据的处理时间会较长的问题。现有的知识蒸馏方式中,是将一个训练好的复杂模型作为教师模型,并用该教师模型来指导轻量级的学生模型的学习,从而将教师模型中的暗知识迁移到学生模型中。

技术实现思路

[0004]本专利技术实施例的目的在于提供一种模型压缩方法、电子设备及存储介质,可以提高训练得到的学生模型的预测精度。
[0005]为解决上述技术问题,本专利技术的实施例提供了一种模型压缩方法,包括:提供训练好的N种类型的复杂模型;N为大于或等于2的整数;对N种类型的复杂模型进行融合,得到训练好的教师模型;基于训练样本、所述教师模型以及学生模型的损失函数,对所述学生模型进行训练;所述学生模型的损失函数由第一损失函数和第二损失函数融合得到,所述第一损失函数用于计算所述学生模型的预测值与真实值的损失,所述第二损失函数用于计算所述学生模型的logit值与所述教师模型的logit值的损失;其中,所述训练样本包含样本输入和样本输出,所述学生模型在接收所述样本输入后输出所述预测值且所述学生模型中的logit层输出所述logit值,所述真实值为所述样本输出;所述教师模型在接收所述样本输入后,所述教师模型中的logit层输出所述logit值。
[0006]本专利技术的实施例还提供了一种模型压缩系统,包括:复杂模型训练单元,用于提供训练好的N种类型的复杂模型;N为大于或等于2的整数;教师模型获取单元,用于对所述N种类型的复杂模型进行融合,得到训练好的教师模型;学生模型训练单元,用于基于训练样本、所述教师模型以及学生模型的损失函数,对所述学生模型进行训练;所述学生模型的损失函数由第一损失函数和第二损失函数融合得到,所述第一损失函数用于计算所述学生模型的预测值与真实值的损失,所述第二损失函数用于计算所述学生模型的logit值与所述教师模型的logit值的损失;
[0007]其中,所述训练样本包含样本输入和样本输出,所述学生模型在接收所述样本输入后输出所述预测值且所述学生模型中的logit层输出所述logit值,所述真实值为所述样本输出;所述教师模型在接收所述样本输入后,所述教师模型中的logit层输出所述logit值。
[0008]本专利技术的实施例还提供了一种电子设备,包括:至少一个处理器;以及,与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够上述模型压缩方法。
[0009]本专利技术的实施例还提供了一种计算机可读存储介质,存储有计算机程序,所述计算机程序被处理器执行时实现上述模型压缩方法。
[0010]本专利技术实施例相对于现有技术而言,在基于知识蒸馏的方式的模型压缩过程中:教师模型由N种类型的复杂模型融合得到,可以汲取多种类型的复杂模型的优点,使得教师模型更加全面;在学生模型的损失函数也是由第一损失函数和第二损失函数融合得到,第一损失函数用于计算所述学生模型的预测值与真实值的损失,实现基于硬目标的训练;第二损失函数用于计算所述学生模型的logit值与所述教师模型的logit值的损失,实现基于软目标的训练;使得学生模型的损失函数是融合了基于硬目标和基于软目标的训练,训练精度会更好。因此,本申请实施例的模型压缩方法,可以提高训练得到的学生模型的预测精度。
附图说明
[0011]一个或多个实施例通过与之对应的附图中的图片进行示例性说明,这些示例性说明并不构成对实施例的限定,附图中具有相同参考数字标号的元件表示为类似的元件,除非有特别申明,附图中的图不构成比例限制。
[0012]图1是根据本申请一个实施例的模型压缩方法的流程图;
[0013]图2是根据本申请另一个实施例的模型压缩方法的流程图;
[0014]图3是根据本申请一个实施例的模型压缩系统的方框图;
[0015]图4是根据本申请一个实施例的电子设备的方框图。
具体实施例
[0016]为使本专利技术实施例的目的、技术方案和优点更加清楚,下面将结合附图对本专利技术的各实施例进行详细的阐述。然而,本领域的普通技术人员可以理解,在本专利技术各实施例中,为了使读者更好地理解本申请而提出了许多技术细节。但是,即使没有这些技术细节和基于以下各实施例的种种变化和修改,也可以实现本申请所要求保护的技术方案。以下各个实施例的划分是为了描述方便,不应对本专利技术的具体实现方式构成任何限定,各个实施例在不矛盾的前提下可以相互结合相互引用。
[0017]本专利技术的一个实施例涉及一种模型压缩方法,具体流程如图1所示。
[0018]步骤101,提供训练好的N种类型的复杂模型;N为大于或等于2的整数。
[0019]步骤102,对N种类型的复杂模型进行融合,得到训练好的教师模型。
[0020]步骤103,基于训练样本、教师模型以及学生模型的损失函数,对学生模型进行训
练。学生模型的损失函数由第一损失函数和第二损失函数融合得到,第一损失函数用于计算学生模型的预测值与真实值的损失,第二损失函数用于计算学生模型的logit值与教师模型的logit值的损失。其中,训练样本包含样本输入和样本输出,学生模型在接收样本输入后输出预测值且学生模型中的logit层输出logit值,真实值为样本输出;教师模型在接收样本输入后,教师模型中的logit层输出logit值。其中,学生模型中的logit层是学生模型中的全连接层,教师模型中的logit层是教师模型中的全连接层。
[0021]本专利技术实施例中,在基于知识蒸馏的方式的模型压缩过程中:教师模型由N种类型的复杂模型融合得到,可以汲取多种类型的复杂模型的优点,使得教师模型更加全面;在学生模型的损失函数也是由第一损失函数和第二损失函数融合得到,第一损失函数用于计算所述学生模型的预测值与真实值的损失,实现基于硬目标的训练;第二损失函数用于计算所述学生模型的logit值与所述教师模型的logit值的损失,实现基于软目标的训练;使得学生模型的损失函数是融合了基于硬目标和基于软目标的训练,训练精度本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型压缩方法,其特征在于,包括:提供训练好的N种类型的复杂模型;N为大于或等于2的整数;对所述N种类型的复杂模型进行融合,得到训练好的教师模型;基于训练样本、所述教师模型以及学生模型的损失函数,对所述学生模型进行训练;所述学生模型的损失函数由第一损失函数和第二损失函数融合得到,所述第一损失函数用于计算所述学生模型的预测值与真实值的损失,所述第二损失函数用于计算所述学生模型的logit值与所述教师模型的logit值的损失;其中,所述训练样本包含样本输入和样本输出,所述学生模型在接收所述样本输入后输出所述预测值且所述学生模型中的logit层输出所述logit值,所述真实值为所述样本输出;所述教师模型在接收所述样本输入后,所述教师模型中的logit层输出所述logit值。2.根据权利要求1所述的模型压缩方法,其特征在于,每种类型的复杂模型基于K折交叉验证训练得到,训练好的所述每种类型的复杂模型包括训练好的属于该种类型的K个复杂模型;K为大于或等于2的整数;所述对N种类型的复杂模型进行融合,得到训练好的教师模型,包括:对于所述每种类型的复杂模型,将所述K个复杂模型的K个logit层进行融合,得到所述每种类型的复杂模型的logit层;将所述N种类型的复杂模型的N个logit层进行融合,作为所述教师模型的logit层。3.根据权利要求2所述的模型压缩方法,其特征在于,所述对于所述每种类型的复杂模型,将所述K个复杂模型的K个logit层进行融合,得到所述每种类型的复杂模型的logit层,包括:对于所述每种类型的复杂模型,将所述K个复杂模型的K个logit层输出的K个logit值相加后取平均,作为所述每种类型的复杂模型的logit层输出的logit值;所述将所述N种类型的复杂模型的N个logit层进行融合,作为所述教师模型的logit层,包括:将所述N种类型的复杂模型的N个logit层输出的N个logit值相加后取平均,作为所述教师模型的logit层输出的logit值。4.根据权利要求1所述的模型压缩方法,其特征在于,在对所述学生模型的训练中,若所述训练样本首次被选择,向所述教师模型输入所述样本输入后,得到所述教师模型的logit值,并将所述教师模型的logit值保存在预设的存储单元;若所述训练样本...

【专利技术属性】
技术研发人员:陈贝
申请(专利权)人:达闼机器人股份有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1