System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind()
【技术实现步骤摘要】
本专利技术属于图像分类,具体涉及一种基于自适应对比学习的开放世界图像分类方法。
技术介绍
1、图像分类是计算机视觉的核心任务,其目标是根据图像内容将其归类为一个或者多个类别标签。简单来说,图像分类任务是让计算机“看”一张图像,并判断它所属的类别,该任务被广泛应用于自动驾驶、医疗诊断、智能安防、社交媒体内容推荐等多个领域。
2、基于深度学习的图像分类通常假设训练数据和测试数据具有相同分布,即独立同分布假设。然而,在真实场景中这一假设往往不成立,例如,自动驾驶系统的摄像头采集的数据会因天气、光线变化或突发状况而与训练数据有显著差异;在医学图像分析中,因不同设备的分辨率、病人个体差异等,导致测试数据可能与训练数据分布不一致,这样的分布不一致问题被称为域转移问题。由于域偏移的存在,模型在训练数据上表现良好,但是在真实环境下的测试数据中分类性能会显著下降,尤其是出现新的或者从未见过类别(outofdistribution,简称ood数据)的开放世界情境下更加显著。对于开放世界的图像分类,不仅要求模型能准确分类已知类别的图像,还要能分类未见过类别的图像,以提高分类的鲁棒性和准确性。
3、为了解决域转移问题,测试时训练(test time training,ttt)方法应运而生。ttt在测试阶段利用无标签的目标域数据对模型进行自适应调整,以缩小训练数据和测试数据分布之间的差异,从而提高模型的分类准确性。然而,传统的ttt方法通常假设训练和测试阶段的类别集合保持一致,当测试阶段出现新的类别时,传统的ttt方法则无法很好地
技术实现思路
1、针对现有技术的不足,本专利技术拟解决的技术问题是,提供一种基于自适应对比学习的开放世界图像分类方法。
2、本专利技术解决所述技术问题采用如下的技术方案:
3、一种基于自适应对比学习的开放世界图像分类方法,其特征在于,该方法包括以下步骤:
4、第一步:利用源域数据对预训练分类模型进行训练,得到源域训练分类模型;对预训练分类模型从源域样本中提取的特征向量进行聚类,得到源域中各类别的聚类中心;
5、第二步:利用目标域数据对源域训练分类模型进行测试时训练,根据最终的损失函数计算训练损失;
6、2.1)对目标域样本进行数据增强,生成增强样本;目标域样本和与之对应的增强样本组成正样本对,目标域样本和与之不对应的增强样本组成负样本对,基于样本对的对比学习对源域训练分类模型进行优化;
7、2.2)正、负类别对的对齐;
8、根据式(6)计算目标域样本的强ood评分,将目标域样本的强ood评分与强ood判别阈值进行比较,若强ood评分大于强ood判别阈值,则该目标域样本为强ood样本,否则为弱ood样本;
9、
10、式中,表示目标域样本xi的强ood评分,f′(xi)表示源域训练分类模型f′(·)从目标域样本xi中提取的特征向量,表示源域中类别k的聚类中心,<·>表示余弦相似度,表示从源域类别的聚类中心中寻找与目标域样本所属类别最相似的聚类中心,ds表示源域类别的聚类中心集合;
11、对于弱ood样本,通过生成伪标签的方式将其嵌入到所属类别的聚类中心附近,根据式(9)计算弱ood类别的负对数似然损失lwea_pc;
12、
13、式中,f′(xi_wea)表示源域训练分类模型从弱ood样本xi_wea中提取的特征向量,dwea_a、dwea_b分别表示弱ood类别a、b的聚类中心,δ是超参数,表示类别的伪标签,ys表示源域类别集合;
14、当前批次和源域中相同的类别组成正类别对,不同的类别组成负类别对,根据式(12)计算类别对的nt-xent损失lnt;
15、
16、式中,表示当前批次和源域中类别k组成的正类别对的相似度,分别表示当前批次和源域中类别k的归一化聚类中心向量,表示当前批次中类别k和源域中类别r组成的负类别对的相似度,表示源域中类别r的归一化聚类中心向量,表示当前批次中类别e与源域中类别k组成的负类别对的相似度,表示当前批次中类别e的归一化聚类中心向量,α2表示超参数;
17、2.3)强ood类别聚类中心的对齐;
18、对于强ood样本,根据式(13)计算强ood样本的强ood评分;创建多个强ood类别的聚类中心,每个强ood类别对应一个强ood类别判别阈值;将强ood样本的强ood评分与各个强ood类别判别阈值进行比较,若强ood样本的强ood评分大于某个强ood类别判别阈值,则强ood样本属于该强ood类别;若强ood样本的强ood评分均小于等于现有的所有强ood类别判别阈值,则该强ood样本不属于现有的强ood类别,则创建新的强ood类别的聚类中心;遍历完所有强ood样本,得到所有强ood类别的聚类中心;
19、
20、式中,是强ood样本xi-str的强ood评分,di_str表示强ood样本xi-str的聚类中心,f′(xi_str)表示源域训练分类模型从强ood样本xi-str中提取的特征向量,dstr表示强ood类别的聚类中心集合,表示从源域类别和强ood类别的聚类中心中寻找与强ood样本xi-str所属类别最相似的聚类中心;
21、根据强ood类别的聚类中心,通过式(14)计算强ood类别的负对数似然损失lstr_pc;
22、
23、式中,dstr_c、dstr_d分别是强ood类别c、d的聚类中心,f′(xi_str)表示源域训练分类模型从强ood样本xi-str中提取的特征向量,ystr表示强ood类别集合;
24、2.4)采用kl散度损失度量源域和目标域特征分布之间的差异,根据式(15)计算kl散度损失lkld;
25、
26、式中,dkl表示kl散度分布,表示源域特征分布,表示目标域特征分布;
27、综上,最终的损失函数表示为:
28、lcs=lnt+lwea_pc+lstr_pc+lkld(16)
29、第三步:计算源域训练分类模型从目标域样本中提取的特征向量与各个目标域类别聚类中心的余弦相似度,最大余弦相似度对应的目标域类别即为目标域样本的预测类别标签。
30、与现有技术相比,本专利技术的有益效果是:本文档来自技高网...
【技术保护点】
1.一种基于自适应对比学习的开放世界图像分类方法,其特征在于,该方法包括以下步骤:
2.根据权利要求1所述的基于自适应对比学习的开放世界图像分类方法,其特征在于,在步骤2.1)中,将当前批次的目标域样本和增强样本分别输入到预训练分类模型和源域训练分类模型中提取特征向量,通过式(2)对特征向量进行归一化;
3.根据权利要求1或2所述的基于自适应对比学习的开放世界图像分类方法,其特征在于,在每个批次中,根据下式对强OOD判别阈值进行优化,得到最佳强OOD判别阈值;
4.根据权利要求3所述的基于自适应对比学习的开放世界图像分类方法,其特征在于,源域中类别k的聚类中心表示为:
5.根据权利要求1所述的基于自适应对比学习的开放世界图像分类方法,其特征在于,当前批次和源域中类别k的归一化聚类中心向量和表示为:
【技术特征摘要】
1.一种基于自适应对比学习的开放世界图像分类方法,其特征在于,该方法包括以下步骤:
2.根据权利要求1所述的基于自适应对比学习的开放世界图像分类方法,其特征在于,在步骤2.1)中,将当前批次的目标域样本和增强样本分别输入到预训练分类模型和源域训练分类模型中提取特征向量,通过式(2)对特征向量进行归一化;
3.根据权利要求1或2所述的基于自适应对比学习的...
【专利技术属性】
技术研发人员:黄义雄,汪梦竹,汪思嘉,陈政翰,苏厚成,任文浩,张宇,万齐,樊彦龙,尹楠,苏浩,
申请(专利权)人:河北工业大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。