当前位置: 首页 > 专利查询>清华大学专利>正文

模型训练方法、装置、存储介质及电子设备制造方法及图纸

技术编号:29759615 阅读:26 留言:0更新日期:2021-08-20 21:13
本公开关于一种模型训练方法、装置、存储介质及电子设备,上述模型训练方法应用于关键点检测模型,上述关键点检测模型包括:特征提取器,回归器和对抗回归器,上述模型训练方法包括:确定上述关键点检测模型中待更新的目标对象,其中,上述目标对象包括:上述特征提取器,上述回归器和上述对抗回归器;对上述目标对象的网络参数进行更新,以使上述关键点检测模型从源域向目标域迁移,以至少解决相关技术中的域自适应方法无法实现关键点检测模型在不同的数据域之间进行迁移的问题。

【技术实现步骤摘要】
模型训练方法、装置、存储介质及电子设备
本公开涉及模型训练
,尤其涉及一种模型训练方法、装置、存储介质及电子设备。
技术介绍
相关技术中,针对深度学习模型的分类标注,在真实场景中存在关键点标注的难度大、成本高,而虚拟场景下存在关键点标注难度小、成本低的问题,因此,目前的前景方向是将虚拟数据训练出的模型迁移到真实数据上。但是,由于虚拟数据和真实数据之间存在着巨大的领域差距(domaingap),而且当前的域自适应方法(domainadaptation)主要适用于分类问题,难以在关键点检测类的回归问题上带来帮助,无法实现关键点检测模型在不同的数据域之间进行迁移。
技术实现思路
本公开提供一种模型训练方法、装置、存储介质及电子设备,以至少解决相关技术中的域自适应方法无法实现关键点检测模型在不同的数据域之间进行迁移的问题。本公开的技术方案如下:根据本公开实施例的第一方面,提供一种模型训练方法,上述模型训练方法应用于关键点检测模型,上述关键点检测模型包括:特征提取器,回归器和对抗回归器,上述模型训练方法包括:确定上述关键点检测模型中待更新的目标对象,其中,上述目标对象包括:上述特征提取器,上述回归器和上述对抗回归器;对上述目标对象的网络参数进行更新,以使上述关键点检测模型从源域向目标域迁移。可选的,上述对上述目标对象的网络参数进行更新的步骤包括:对上述特征提取器,上述回归器和上述对抗回归器中的第一网络参数进行更新。可选的,上述对上述特征提取器,上述回归器和上述对抗回归器中的第一网络参数进行更新的步骤包括:对上述回归器和上述对抗回归器输出的热力图进行归一化处理,得到第一处理结果;计算上述源域上每个关键点标签对应的热力图与上述源域上全部关键点标签对应的热力图总和的比值,得到第二处理结果;在空间维度上对上述第一处理结果与上述第二处理结果进行散度计算,得到第三处理结果;基于上述第三处理结果更新上述特征提取器,上述回归器和上述对抗回归器中的第一网络参数,以重新获取上述第三处理结果,直至当上述第三处理结果满足第一预设条件时,停止更新上述第一网络参数。可选的,上述对上述目标对象的网络参数进行更新的步骤包括:对上述对抗回归器中的第二网络参数进行更新,以使得最大化上述对抗回归器在上述目标域的散度和上述回归器预测的散度的差异。可选的,上述对上述对抗回归器中的第二网络参数进行更新的步骤包括:获取上述回归器在上述目标域上输出的与多个关键点对应的多个热力图,并计算上述多个关键点中除目标关键点之外其余关键点对应的剩余热力图总和;计算上述目标关键点对应的第一热力图与上述剩余热力图总和的比值,得到第四处理结果;获取上述对抗回归器在上述目标域上输出的与上述第一热力图对应的第二热力图,并对上述第二热力图进行归一化处理,得到第五处理结果;在空间维度上对上述第四处理结果与上述第五处理结果进行散度计算,得到第六处理结果;基于上述第六处理结果更新上述对抗回归器中的第二网络参数,以重新获取上述第六处理结果,直至当上述第六处理结果满足第二预设条件时,停止更新上述第二网络参数。可选的,上述对上述目标对象的网络参数进行更新的步骤包括:对上述特征提取器中的第三网络参数进行更新,以使得最小化上述对抗回归器在上述目标域的散度和上述回归器预测的散度的差异。可选的,上述对上述特征提取器中的第三网络参数进行更新的步骤包括:获取上述回归器在上述目标域上输出的与多个关键点对应的多个热力图,并计算上述多个关键点中除目标关键点之外其余关键点对应的剩余热力图总和;计算上述目标关键点对应的第一热力图与上述剩余热力图总和的比值,得到第七处理结果;获取上述对抗回归器在上述目标域上输出的与上述第一热力图对应的第二热力图,并对上述第二热力图进行归一化处理,得到第八处理结果;在空间维度上对上述第七处理结果与上述第八处理结果进行散度计算,得到第九处理结果;基于上述第九处理结果更新上述特征提取器中的第三网络参数,以重新获取上述第九处理结果,直至当上述第九处理结果满足第三预设条件时,停止更新上述第三网络参数。根据本公开实施例的第二方面,提供一种模型训练装置,上述模型训练装置应用于关键点检测模型,上述关键点检测模型包括:特征提取器,回归器和对抗回归器,上述模型训练装置包括:确定单元,被配置为执行确定上述关键点检测模型中待更新的目标对象,其中,上述目标对象包括:上述特征提取器,上述回归器和上述对抗回归器;更新单元,被配置为执行对上述目标对象的网络参数进行更新,以使上述关键点检测模型从源域向目标域迁移。可选的,上述更新单元包括:第一更新子单元,被配置为执行对上述特征提取器,上述回归器和上述对抗回归器中的第一网络参数进行更新。可选的,上述第一更新子单元包括:第一处理子单元,被配置为执行对上述回归器和上述对抗回归器输出的热力图进行归一化处理,得到第一处理结果;第一计算子单元,被配置为执行计算上述源域上每个关键点标签对应的热力图与上述源域上全部关键点标签对应的热力图总和的比值,得到第二处理结果;第二计算单元,被配置为执行在空间维度上对上述第一处理结果与上述第二处理结果进行散度计算,得到第三处理结果;第三计算单元,被配置为执行基于上述第三处理结果更新上述特征提取器,上述回归器和上述对抗回归器中的第一网络参数,以重新获取上述第三处理结果,直至当上述第三处理结果满足第一预设条件时,停止更新上述第一网络参数。可选的,上述更新单元包括:第二更新子单元,被配置为执行对上述对抗回归器中的第二网络参数进行更新,以使得最大化上述对抗回归器在上述目标域的散度和上述回归器预测的散度的差异。可选的,上述第二更新子单元包括:获取子单元,被配置为执行获取上述回归器在上述目标域上输出的与多个关键点对应的多个热力图,并计算上述多个关键点中除目标关键点之外其余关键点对应的剩余热力图总和;第四计算单元,被配置为执行计算上述目标关键点对应的第一热力图与上述剩余热力图总和的比值,得到第四处理结果;第二处理子单元,被配置为执行获取上述对抗回归器在上述目标域上输出的与上述第一热力图对应的第二热力图,并对上述第二热力图进行归一化处理,得到第五处理结果;第五计算子单元,被配置为执行在空间维度上对上述第四处理结果与上述第五处理结果进行散度计算,得到第六处理结果;第三更新子单元,被配置为执行基于上述第六处理结果更新上述对抗回归器中的第二网络参数,以重新获取上述第六处理结果,直至当上述第六处理结果满足第二预设条件时,停止更新上述第二网络参数。可选的,上述更新单元包括:第四更新子单元,被配置为执行对上述特征提取器中的第三网络参数进行更新,以使得最小化上述对抗回归器在上述目标域的散度和上述回归器预测的散度的差异。可选的,上述第四更新子单元包括:第六计算子单元,被配置为执行获取上述回归器在上述目标域上输出的与多个关键点对应的多个热力图,并计算上述多个关键点中除目标关键点之外其余关键点对应的剩余热力图总和;第七计算子单元,被配置为执行计算上述目标关键点对应的第一热力图与上述剩余本文档来自技高网...

【技术保护点】
1.一种模型训练方法,其特征在于,所述模型训练方法应用于关键点检测模型,所述关键点检测模型包括:特征提取器,回归器和对抗回归器,所述模型训练方法包括:/n确定所述关键点检测模型中待更新的目标对象,其中,所述目标对象包括:所述特征提取器,所述回归器和所述对抗回归器;/n对所述目标对象的网络参数进行更新,以使所述关键点检测模型从源域向目标域迁移。/n

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述模型训练方法应用于关键点检测模型,所述关键点检测模型包括:特征提取器,回归器和对抗回归器,所述模型训练方法包括:
确定所述关键点检测模型中待更新的目标对象,其中,所述目标对象包括:所述特征提取器,所述回归器和所述对抗回归器;
对所述目标对象的网络参数进行更新,以使所述关键点检测模型从源域向目标域迁移。


2.根据权利要求1所述的模型训练方法,其特征在于,所述对所述目标对象的网络参数进行更新的步骤包括:
对所述特征提取器,所述回归器和所述对抗回归器中的第一网络参数进行更新。


3.根据权利要求2所述的模型训练方法,其特征在于,所述对所述特征提取器,所述回归器和所述对抗回归器中的第一网络参数进行更新的步骤包括:
对所述回归器和所述对抗回归器输出的热力图进行归一化处理,得到第一处理结果;
计算所述源域上每个关键点标签对应的热力图与所述源域上全部关键点标签对应的热力图总和的比值,得到第二处理结果;
在空间维度上对所述第一处理结果与所述第二处理结果进行散度计算,得到第三处理结果;
基于所述第三处理结果更新所述特征提取器,所述回归器和所述对抗回归器中的第一网络参数,以重新获取所述第三处理结果,直至当所述第三处理结果满足第一预设条件时,停止更新所述第一网络参数。


4.根据权利要求1所述的模型训练方法,其特征在于,所述对所述目标对象的网络参数进行更新的步骤包括:
对所述对抗回归器中的第二网络参数进行更新,以使得最大化所述对抗回归器在所述目标域的散度和所述回归器预测的散度的差异。


5.根据权利要求4所述的模型训练方法,其特征在于,所述对所述对抗回归器中的第二网络参数进行更新的步骤包括:
获取所述回归器在所述目标域上输出的与多个关键点对应的多个热力图,并计算所述多个关键点中除目标关键点之外其余关键点对应的剩余热力图总和;
计算所述目标关键点对应的第一热力图与所述剩余热力图总和的比值,得到第四处理结果;
获取所述对抗回归器在所述目标域上输出的与所述第一热力图对应的第二热力图,并对所述第二热力图进行归一化处理,得到第五处理结果;
在空间维度上对所述第四处理结果与所述第五处理结果进行散度计算...

【专利技术属性】
技术研发人员:龙明盛江俊广刘裕峰蔡东阳郭小燕郑文王建民
申请(专利权)人:清华大学北京达佳互联信息技术有限公司
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1