基于知识蒸馏的模型训练方法、装置、设备及存储介质制造方法及图纸

技术编号:34961881 阅读:16 留言:0更新日期:2022-09-17 12:41
本发明专利技术适用于人工智能技术领域,尤其涉及一种基于知识蒸馏的模型训练方法、装置、设备及存储介质,该方法通过获取满足目标条件的第一模型和不满足目标条件的第二模型,根据第一模型的输出构建的优化损失函数,得到更新后的第二模型,计算并确定目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数,对第二模型进行训练,得到满足目标条件的第二模型。将第一模型与第二模型中网络层之间的表征相似度作为第二模型中损失函数的一部分,充分利用了网络中间层的信息,增加了第二模型学习第一模型的能力和范围,从而提高第二模型训练时的稳定性和收敛性。性。性。

【技术实现步骤摘要】
基于知识蒸馏的模型训练方法、装置、设备及存储介质


[0001]本专利技术涉及人工智能领域,尤其涉及一种基于知识蒸馏的模型训练方法、装置、设备及存储介质。

技术介绍

[0002]目前,深度学习神经网络已成功应用于各种计算机视觉应用,如图像分类、对象检测和语义分割,大型的深度学习模型训练必须从非常大的、高度冗余的数据集中训练得到,但是数据集的数据量较大时模型训练需要占据大量的时间和存储空间,因此,为了缩短训练时间和减少资源占据,使用知识蒸馏方法对大型深度学习网络进行压缩得到了广泛运用,知识蒸馏方法对教师网络与学生网络的匹配度要求较高,而当前的知识蒸馏方法只会对训练样本集的标签进行优化,无法应对教师网络与学生网络的匹配程度不高的情况,导致的训练过程不稳定、不收敛的问题。因此,如何改进知识蒸馏的训练过程,以提高学生网络训练过程的稳定性、收敛性成为亟待解决的问题。

技术实现思路

[0003]基于此,有必要针对上述技术问题,提供一种基于知识蒸馏的模型训练方法、装置、设备及存储介质,以解决训练过程中不稳定、不收敛的问题。
[0004]本申请实施例的第一方面提供了一种基于知识蒸馏的模型训练方法,所述方法包括:
[0005]获取满足目标条件的第一模型和不满足目标条件的第二模型,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数;
[0006]根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型;
[0007]计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
[0008]根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数;
[0009]使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
[0010]本申请实施例的第二方面提供了一种基于知识蒸馏的模型训练装置,所述装置包括:
[0011]获取模型模块,用于获取满足目标条件的第一模型和不满足所述目标条件的第二模型,所述第一模型包括M个网络层,所述第二模型包括N个网络层,N、M均为大于零的整数;
[0012]更新模块,用于根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型;
[0013]目标相似度确定模块,用于计算第一模型中M个网络层分别与更新后的第二模型
中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
[0014]目标损失函数确定模块,用于根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数;
[0015]训练模块,用于使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
[0016]第三方面,本专利技术实施例提供一种计算机设备,所述计算机设备包括处理器、存储器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如第一方面所述的基于知识蒸馏的模型训练方法。
[0017]第四方面,本专利技术实施例提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如第一方面所述的基于知识蒸馏的模型训练方法。
[0018]本专利技术与现有技术相比存在的有益效果是:
[0019]本专利技术通过获取满足目标条件的第一模型和不满足目标条件的第二模型,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数,根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型,计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数,使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。将第一模型与第二模型中网络层之间的表征相似度作为第二模型中损失函数的一部分,充分利用了网络中间层的信息,增加了第二模型学习第一模型的能力和范围,从而提高第二模型训练时的稳定性和收敛性。
附图说明
[0020]为了更清楚地说明本专利技术实施例的技术方案,下面将对本专利技术实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本专利技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0021]图1是本专利技术一实施例提供的一种基于知识蒸馏的模型训练方法的一应用环境示意图;
[0022]图2是本专利技术一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图;
[0023]图3是本专利技术一实施例提供的一种基于知识蒸馏的模型训练装置的结构示意图;
[0024]图4是本专利技术一实施例提供的一种计算机设备的结构示意图。
具体实施方式
[0025]下面将结合本专利技术实施例中的附图,对本专利技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本专利技术一部分实施例,而不是全部的实施例。基于本专利技术中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本专利技术保护的范围。
[0026]应当理解,当在本专利技术说明书和所附权利要求书中使用时,术语“包括”指示所描
述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
[0027]还应当理解,在本专利技术说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
[0028]如在本专利技术说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
[0029]另外,在本专利技术说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
[0030]在本专利技术说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本专利技术的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的模型训练方法,其特征在于,包括:获取满足目标条件的第一模型和不满足所述目标条件的第二模型,所述第一模型包括M个网络层,所述第二模型包括N个网络层,N、M均为大于零的整数;根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型;计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;根据所述目标相似度,构建相似度损失函数,并将所述相似度损失函数与所述优化损失函数的和作为目标损失函数;使用训练集对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型。2.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型,包括:将带有原始标签的第一训练样本输入至所述第一模型中,以所述第一模型输出的新标签更新所述第一训练样本对应的原始标签,得到第二训练样本;利用所述第一训练样本与所述第二训练样本,分别对第二模型进行训练,得到第一知识蒸馏损失函数与第二知识蒸馏损失函数;通过所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数,构建优化损失函数;根据所述优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型。3.如权利要求2所述的基于知识蒸馏的模型训练方法,其特征在于,所述通过所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数,构建优化损失函数,包括:对所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数设置不同的初始参数,得到初始蒸馏损失函数;使用梯度下降算法对所述初始蒸馏损失函数进行参数更新,得到目标参数,使用所述目标参数更新初始蒸馏损失函数,得到优化损失函数。4.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,包括:分别获取所述第一模型中M个网络层与所述更新后的第二模型中N个网络层中每个网络层的特征矩阵;计算所述第一模型中M个网络层的特征矩阵分别与所述更新后的第二模型中N个网络层的特征矩阵的表征相似度,得到所述第一模型中每个网络层对应的表征相似度序列;从所述第一模型中每个网络层对应的表征相似...

【专利技术属性】
技术研发人员:张楠王健宗瞿晓阳
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1