一种模型训练方法、装置、电子设备及存储介质制造方法及图纸

技术编号:35654550 阅读:21 留言:0更新日期:2022-11-19 16:50
本发明专利技术的实施例提供了一种模型训练方法、装置、电子设备及存储介质,方法包括:通过确定教师模型和学生模型,确定初始训练样本数据,逐步将学生模型中的第二模块将教师模型中的第一模块进行替换,每次替换后均进行训练,得到新的教师模型,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型,实现逐步用学生模型的模块替换掉教师模型中的模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度。生模型的精度。生模型的精度。

【技术实现步骤摘要】
一种模型训练方法、装置、电子设备及存储介质


[0001]本专利技术涉及模型训练
,具体而言,涉及一种模型训练方法、装置、电子设备及存储介质。

技术介绍

[0002]随着人工智能技术的发展,知识蒸馏技术在模型训练过程中的应用越来越广泛。其中,知识蒸馏是一种采用预先训练好的结构复杂的教师模型(Teacher Model)来训练结构简单的学生模型(Student Model),以实现将教师模型功能赋予学生模型的技术。那么,如何基于知识蒸馏技术,高精度的训练学生模型至关重要。

技术实现思路

[0003]本专利技术的目的在于提供一种模型训练方法、装置、电子设备及存储介质,能够提高训练学生模型的精度。
[0004]为了实现上述目的,本专利技术实施例采用的技术方案如下:
[0005]第一方面,本专利技术实施例提供了一种模型训练方法,所述方法包括:
[0006]确定教师模型和学生模型;
[0007]确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
[0008]将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
[0009]基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
[0010]返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
[0011]在可选的实施方式中,所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型的步骤,包括:
[0012]基于伯努利分布方式,控制将所述教师模型中的第一模块替换为与所述学生模型中对应的第二模块的替换概率;
[0013]基于所述替换概率,将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型。
[0014]在可选的实施方式中,所述伯努利分布方式满足以下公式:
[0015]p
d
=min(1,θ(t))=min(1,kt=b);
[0016]其中,b是初始替换率,k是大于0的系数,t是替换次数。
[0017]在可选的实施方式中,所述方法还包括:
[0018]将待检测数据输入至所述目标模型,得到预测数据;
[0019]将所述预测数据进行清洗,得到第一训练数据;
[0020]基于所述第一训练数据对所述目标模型进行训练。
[0021]在可选的实施方式中,所述基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型的步骤,包括:
[0022]基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失;
[0023]基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型;
[0024]返回执行所述基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失至所述基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型的步骤,直至达到预设训练次数,得到新的教师模型。
[0025]在可选的实施方式中,所述交叉熵损失函数满足以下公式:
[0026]L=


j∈|X|

c∈C
[[z
j
=c]·
log P(z
j
=c|x
j
)];
[0027]其中x
j
∈X为第j个初始训练样本,X为初始训练样本集合,z
j
为初始训练样本的真实标签,c为初始样本的类标签,C为初始训练样本集合的类标签集合,P为初始训练样本的真实标签与预测标签的概率差值。
[0028]在可选的实施方式中,所述将所述预测数据进行清洗,得到第一训练数据的步骤,包括:
[0029]确定所述预测数据的置信度值;
[0030]将置信度小于阈值的第一预测数据进行人工审核;
[0031]接收人工审核后的第一预测数据;
[0032]将人工审核后的所述第一预测数据,作为第一训练数据。
[0033]第二方面,本专利技术实施例提供了一种模型训练装置,所述装置包括:
[0034]第一确定模块,用于确定教师模型和学生模型;
[0035]第二确定模块,用于确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
[0036]替换模块,用于将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
[0037]训练模块,用于基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
[0038]执行模块,用于返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
[0039]第三方面,本专利技术实施例提供了一种电子设备,包括存储器和处理器,所述存储器
存储有计算机程序,所述处理器执行所述计算机程序时实现所述模型训练方法的步骤。
[0040]第四方面,本专利技术实施例提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现所述模型训练方法的步骤。
[0041]本专利技术具有以下有益效果:
[0042]本专利技术通过确定教师模型和学生模型,确定初始训练样本数据,将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型,基于初始训练样本数据对更新后的教师模型进行训练,得到新的教师模型,返回执行将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型至基于初始训练样本数据对更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型,实现逐步用学生模型的模块替换掉教师模型中的模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度。
附图说明
[0043]为了更清楚地说明本专利技术实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:确定教师模型和学生模型;确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。2.根据权利要求1所述的方法,其特征在于,所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型的步骤,包括:基于伯努利分布方式,控制将所述教师模型中的第一模块替换为与所述学生模型中对应的第二模块的替换概率;基于所述替换概率,将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型。3.根据权利要求2所述的方法,其特征在于,所述伯努利分布方式满足以下公式:p
d
=min(1,θ(t))=min(1,kt=b);其中,b是初始替换率,k是大于0的系数,t是替换次数。4.根据权利要求1所述的方法,其特征在于,所述方法还包括:将待检测数据输入至所述目标模型,得到预测数据;将所述预测数据进行清洗,得到第一训练数据;基于所述第一训练数据对所述目标模型进行训练。5.根据权利要求1所述的方法,其特征在于,所述基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型的步骤,包括:基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失;基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型;返回执行所述基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失至所述基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型的步骤,直至达到预设训练次数,得到新的教师模型。6.根据权利要求5所...

【专利技术属性】
技术研发人员:牟波
申请(专利权)人:成都知道创宇信息技术有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1