当前位置: 首页 > 专利查询>中山大学专利>正文

基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置制造方法及图纸

技术编号:34860484 阅读:34 留言:0更新日期:2022-09-08 08:03
本发明专利技术公开了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置,方法包括下述步骤:获取无监督目标域自然样本集;构建鲁棒无监督域自适应图像分类框架,包括非鲁棒目标域教师模型和鲁棒目标域学生模型;使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练;构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,输出图像分类结果。本方法将知识蒸馏和对抗训练结合起来,在源域数据完全缺失的情况下,只使用非鲁棒源域模型获得目标域上的鲁棒模型,在保持对目标域自然样本分类性能的同时,有效地提升了对目标域对抗样本的分类性能和模型的鲁棒性。和模型的鲁棒性。和模型的鲁棒性。

【技术实现步骤摘要】
基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置


[0001]本专利技术属于计算机图像分类的
,具体涉及一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置。

技术介绍

[0002]无监督域自适应学习能将知识从标记的源域转移到未标记的目标域,在标签稀缺或注释繁琐的场景中推进模型转移。然而,由于数据隐私和安全问题,在域适应阶段可能无法访问源域数据,并且同时使用目标域数据与大规模源数据训练目标域上模型,在计算上也很棘手。因此,无源域数据无监督域自适应学习应运而生,如一种无源域数据的无监督域自适应学习模型(SHOT模型)。尽管现有研究取得了显著的进展,但大多数现有的无监督域自适应或无源域数据无监督域自适应方法都忽略了深度学习模型的鲁棒性,这些模型对输入图片中难以察觉的扰动很敏感并且在对抗样本面前表现十分脆弱;特别地,由于它们是在没有对目标域进行精确监督的情况下进行的乐观训练,因此(无源域数据的)无监督域自适应模型可能对这种扰动更加敏感,加剧了模型的脆弱性并对安全敏感的应用程序构成巨大威胁。
[0003]现有研究中,一方面通过从鲁棒的源域模型或鲁棒的预训练模型转移鲁棒性来训练鲁棒的无监督域自适应模型,以此提高鲁棒性;但在许多实际应用中,假设鲁棒源域模型或鲁棒预训练的可用性是不切实际的,因此很难直接应用到分类任务中。另一方面提高模型鲁棒性的方法是进行对抗训练,但是对抗训练会导致严重的过拟合现象,极大地影响了分类结果有效性。

技术实现思路

[0004]本专利技术的主要目的在于克服现有技术的缺点与不足,提供一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置,该方法将知识蒸馏和对抗训练结合起来,在源域数据完全缺失的情况下,只使用非鲁棒源域模型获得目标域上的鲁棒模型,在保持对目标域自然样本分类性能的同时,有效地提升了对目标域对抗样本的分类性能和模型的鲁棒性。
[0005]为了达到上述目的,本专利技术采用以下技术方案:
[0006]一方面,本专利技术提供了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,包括下述步骤:
[0007]获取无监督目标域自然样本集;
[0008]构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
[0009]使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
[0010]基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域
自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
[0011]作为优选的技术方案,所述鲁棒无监督域自适应图像分类框架的目标函数是基于间隔差异散度在无源域数据条件下进行推导而得,具体为:
[0012]根据间隔学习理论可得,对于任意一个得分函数f,都满足:
[0013][0014]其中是一个理想间隔损失,是得分函数f在目标对抗域上基于0

1损失的分类误差,是得分函数f在源域上以常数ρ为间隔的分类误差,是以常数ρ为间隔的源域和目标域的间隔差异散度,是以常数ρ为间隔的目标域和目标对抗域的间隔差异散度;
[0015]令在目标对抗域上基于0

1损失的分类误差达到最小的最优得分函数f,故根据(1)式的右端项得:
[0016][0017]在源域数据完全缺失的条件下可知,得分函数f在源域上以常数ρ为间隔的分类误差是常数,故根据(2)式推导出鲁棒无监督域自适应图像分类框架的目标函数为:
[0018][0019]其中,为非鲁棒目标域教师模型的目标函数,为鲁棒目标域学生模型的目标函数。
[0020]作为优选的技术方案,所述得到训练好的非鲁棒目标域教师模型,具体为:
[0021]采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
[0022]使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
[0023]将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。
[0024]作为优选的技术方案,所述得到训练好的鲁棒目标域学生模型,具体为:
[0025]采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;
[0026]根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
[0027]进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
[0028]作为优选的技术方案,基于面向鲁棒性与准确性权衡的对抗训练方法TRADES进行
对抗蒸馏训练;
[0029]所述对抗样本的生成公式为:
[0030][0031]所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:
[0032][0033]其中φ表示鲁棒目标域学生模型,φ
T
表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,是KL散度损失函数,β是常系数,p是p

范数,∈是常数范围。
[0034]作为优选的技术方案,使用投影梯度下降的对抗训练方法PGD进行对抗蒸馏训练;
[0035]所述对抗样本的生成公式为:
[0036][0037]所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:
[0038][0039]其中φ表示鲁棒目标域学生模型,φ
T
表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,是KL散度损失函数,β是常系数,p是p

范数,∈是常数范围。
[0040]另一方面,本专利技术还提供了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类系统,其特征在于,应用于上述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,包括数据获取模块、分类框架构建模块、教师模型训练模块及学生模型训练模块;
[0041]所述数据获取模块用于获取无监督目标域自然样本集;
[0042]所述分类框架构建模块用于构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
[0043]所述教师模型训练模块用于使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
[0044]所述学生模型训练模块用于基于训练好的非鲁棒目标本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,包括下述步骤:获取无监督目标域自然样本集;构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。2.根据权利要求1所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,所述鲁棒无监督域自适应图像分类框架的目标函数是基于间隔差异散度在无源域数据条件下进行推导而得,具体为:根据间隔学习理论可得,对于任意一个得分函数f,都满足:其中是一个理想间隔损失,是得分函数f在目标对抗域上基于0

1损失的分类误差,是得分函数f在源域上以常数ρ为间隔的分类误差,是以常数ρ为间隔的源域和目标域的间隔差异散度,是以常数ρ为间隔的目标域和目标对抗域的间隔差异散度;令在目标对抗域上基于0

1损失的分类误差达到最小的最优得分函数f,故根据(1)式的右端项得:在源域数据完全缺失的条件下可知,得分函数f在源域上以常数ρ为间隔的分类误差是常数,故根据(2)式推导出鲁棒无监督域自适应图像分类框架的目标函数为:其中,为非鲁棒目标域教师模型的目标函数,为鲁棒目标域学生模型的目标函数。3.根据权利要求2所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,所述得到训练好的非鲁棒目标域教师模型,具体为:采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。4.根据权利要求3所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,所述得到训练好的鲁棒目标域学生模型,具体为:
采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。5.根据权利要求4所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,基于面向鲁棒性与准确性权衡的对抗训练方法TRADES进行对抗蒸馏训练;所述对抗样本的生成公式为:所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:其中φ表示鲁棒目标域学生模型,φ

【专利技术属性】
技术研发人员:肖遥罗彬陈宇恒魏朋旭林倞
申请(专利权)人:中山大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1