图像识别模型训练方法、识别方法、设备、介质及产品技术

技术编号:38257023 阅读:17 留言:0更新日期:2023-07-27 10:19
本申请提供一种图像识别模型训练方法、识别方法、设备、介质及产品,涉及深度学习技术领域。该方法包括:获取当前次训练的样本图像以及识别标签;根据教师模型和学生模型粉笔对样本图像进行识别,得到第一识别结果和第二识别结果;根据第一识别结果和第二识别结果计算蒸馏损失;根据第二识别结果和识别标签计算任务损失;对蒸馏损失的梯度和蒸馏损失的累积梯度加权求和,对任务损失的梯度和任务损失的累积梯度加权求和,蒸馏损失的累积梯度为第一权重,任务损失的累积梯度为第二权重,第一权重大于第二权重;根据新的蒸馏损失的累积梯度和新的任务损失的累积梯度计算的总梯度更新学生模型的参数。本申请可以提高学生模型的训练效果。效果。效果。

【技术实现步骤摘要】
图像识别模型训练方法、识别方法、设备、介质及产品


[0001]本专利技术涉及深度学习
,具体而言,涉及一种图像识别模型训练方法、识别方法、设备、介质及产品。

技术介绍

[0002]知识蒸馏是深度学习领域常用的模型压缩方法。
[0003]在将在计算机设备上训练好的体积较大的模型部署到移动终端设备进行图像识别时,由于移动终端设备的计算能力和速度都比计算机设备差很多,移动终端设备可以无法支持体积较大的模型进行图像识别,因此,需要采用知识蒸馏方法,将体积较大的模型作为教师模型,辅助训练一个体积较少的学生模型,实现与教师模型相同的图像识别功能,将该学生模型部署在移动终端设备上,即可以保证移动终端设备的稳定运行,也可以实现图像识别。
[0004]但是现有的知识蒸馏方法在利用教师模型训练学生模型时,存在学生模型收敛效果差导致学生模型的图像识别准确度不佳的问题。

技术实现思路

[0005]本专利技术的目的在于,针对上述现有技术中的不足,提供一种图像识别模型训练方法、识别方法、设备、介质及产品,以便提高学生模型的训练效果。
[0006]为实现上述目的,本申请实施例采用的技术方案如下:
[0007]第一方面,本申请实施例提供了一种图像识别模型训练方法,所述方法包括:
[0008]获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签;
[0009]根据预先训练的教师模型对当前次训练的样本图像进行识别,得到第一识别结果;以及,根据学生模型对当前次训练的样本图像进行识别,得到第二识别结果;其中,所述学生模型为待训练的图像识别模型;
[0010]根据所述第一识别结果和所述第二识别结果计算表征二者差异的当前次训练的蒸馏损失;以及,根据所述第二识别结果和所述识别标签计算表征二者差异的当前次训练的任务损失;
[0011]对当前次训练的蒸馏损失的梯度和蒸馏损失的累积梯度加权求和,得到新的蒸馏损失的累积梯度;以及,对当前次训练的任务损失的梯度和任务损失的累积梯度加权求和,得到新的任务损失的累积梯度;其中,蒸馏损失的累积梯度对应的加权权重为第一权重,任务损失的累积梯度对应的加权权重为第二权重,且所述第一权重大于所述第二权重;
[0012]根据新的蒸馏损失的累积梯度和新的任务损失的累积梯度计算总梯度,并根据所述总梯度更新所述学生模型的参数。
[0013]可选的,所述图像识别模型的训练过程包括多次训练,每次训练均执行从所述获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签,至所述
根据所述总梯度更新所述学生模型的参数的步骤。
[0014]可选的,所述图像识别模型的训练过程包括多次训练,以是否满足目标条件为分界线,对于满足所述目标条件之后的每次训练,执行从所述获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签,至所述根据所述总梯度更新所述学生模型的参数的步骤,对于满足所述目标条件之前的每次训练,执行以下步骤:
[0015]获取当前次训练的样本图像以及识别标签;
[0016]根据预先训练的教师模型对当前次训练的样本图像进行识别,得到第一识别结果;以及,根据待训练学生模型对当前次训练的样本图像进行识别,得到第二识别结果;
[0017]根据所述第一识别结果和所述第二识别结果计算当前次训练的蒸馏损失;以及,根据所述第二识别结果和所述识别标签计算当前次训练的任务损失;
[0018]对当前次训练的蒸馏损失的梯度和蒸馏损失的累积梯度求和,得到新的蒸馏损失的累积梯度;以及,对当前次训练的任务损失的梯度和任务损失的累积梯度求和,得到新的任务损失的累积梯度;
[0019]根据新的蒸馏损失的累积梯度和新的任务损失的累积梯度计算总梯度,并根据所述总梯度更新所述学生模型的参数。
[0020]可选的,所述目标条件包括以下至少一个:
[0021]当前次训练对应的训练次数大于训练次数阈值;
[0022]当前次训练对应的训练时间大于训练时间阈值;
[0023]当前次训练的蒸馏损失小于第一损失阈值;
[0024]当前次训练的任务损失小于第二损失阈值。
[0025]可选的,所述第一权重和所述第二权重均根据经验权重确定,所述经验权重小于所述第一权重且大于所述第二权重。
[0026]可选的,所述第一权重为所述经验权重和目标偏差之和,所述第二权重为所述经验权重和目标偏差之差。
[0027]第二方面,本申请实施例还提供一种图像识别方法,所述方法包括:
[0028]获取待识别图像;
[0029]采用图像识别模型对所述待识别图像进行识别,得到图像识别结果;其中,所述图像识别模型利用如第一方面任一项所述的方法训练得到。
[0030]第三方面,本申请实施例提供了一种图像识别模型训练装置,所述装置包括:
[0031]样本获取模块,用于获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签;
[0032]目标识别模块,用于根据预先训练的教师模型对当前次训练的样本图像进行识别,得到第一识别结果;以及,根据学生模型对当前次训练的样本图像进行识别,得到第二识别结果;其中,所述学生模型为待训练的图像识别模型;
[0033]损失计算模块,用于根据所述第一识别结果和所述第二识别结果计算表征二者差异的当前次训练的蒸馏损失;以及,根据所述第二识别结果和所述识别标签计算表征二者差异的当前次训练的任务损失;
[0034]累积梯度更新模块,用于对当前次训练的蒸馏损失的梯度和蒸馏损失的累积梯度加权求和,得到新的蒸馏损失的累积梯度;以及,对当前次训练的任务损失的梯度和任务损
失的累积梯度加权求和,得到新的任务损失的累积梯度;其中,蒸馏损失的累积梯度对应的加权权重为第一权重,任务损失的累积梯度对应的加权权重为第二权重,且所述第一权重大于所述第二权重;
[0035]模型优化模块,用于根据新的蒸馏损失的累积梯度和新的任务损失的累积梯度计算总梯度,并根据所述总梯度更新所述学生模型的参数。
[0036]可选的,所述图像识别模型的训练过程包括多次训练,每次训练均采用所述样本获取模块、所述目标识别模块、所述损失计算模块、所述累积梯度更新模块和所述模型优化模块执行从所述获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签,至所述根据所述总梯度更新所述学生模型的参数的步骤。
[0037]可选的,所述图像识别模型的训练过程包括多次训练,以是否满足目标条件为分界线,对于满足所述目标条件之后的每次训练,采用所述样本获取模块、所述目标识别模块、所述损失计算模块、所述累积梯度更新模块和所述模型优化模块执行从所述获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签,至所述根据所述总梯度更新所述学生模型的参数的步骤,对于满足所述目标条件之前的每次训练,所述样本获取模块、所述目标识别模块、所述本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种图像识别模型训练方法,其特征在于,所述方法包括:获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签;根据预先训练的教师模型对当前次训练的样本图像进行识别,得到第一识别结果;以及,根据学生模型对当前次训练的样本图像进行识别,得到第二识别结果;其中,所述学生模型为待训练的图像识别模型;根据所述第一识别结果和所述第二识别结果计算表征二者差异的当前次训练的蒸馏损失;以及,根据所述第二识别结果和所述识别标签计算表征二者差异的当前次训练的任务损失;对当前次训练的蒸馏损失的梯度和蒸馏损失的累积梯度加权求和,得到新的蒸馏损失的累积梯度;以及,对当前次训练的任务损失的梯度和任务损失的累积梯度加权求和,得到新的任务损失的累积梯度;其中,蒸馏损失的累积梯度对应的加权权重为第一权重,任务损失的累积梯度对应的加权权重为第二权重,且所述第一权重大于所述第二权重;根据新的蒸馏损失的累积梯度和新的任务损失的累积梯度计算总梯度,并根据所述总梯度更新所述学生模型的参数。2.根据权利要求1所述的方法,其特征在于,所述图像识别模型的训练过程包括多次训练,每次训练均执行从所述获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签,至所述根据所述总梯度更新所述学生模型的参数的步骤。3.根据权利要求1所述的方法,其特征在于,所述图像识别模型的训练过程包括多次训练,以是否满足目标条件为分界线,对于满足所述目标条件之后的每次训练,执行从所述获取当前次训练的样本图像以及表征所述样本图像对应的真实识别结果的识别标签,至所述根据所述总梯度更新所述学生模型的参数的步骤,对于满足所述目标条件之前的每次训练,执行以下步骤:获取当前次训练的样本图像以及识别标签;根据预先训练的教师模型对当前次训练的样本图像进行识别,得到第一识别结果;以及,根据待训练学生模型对当前次训练的样本图像进行识别,得到第二识别结果;根据所述第一识别结果和所述第二识别结果计算当...

【专利技术属性】
技术研发人员:赵博睿宋仁杰梁嘉骏
申请(专利权)人:南京旷云科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1