针对分布于未知域的图像的分类预测问题,源数据集和目标数据集的不匹配分布将导致源模型在目标域的性能显著下降,而目前已提出的针对跨域视觉表示的分布对齐方法没有考虑到跨域内部数据结构的差异。本发明专利技术通过样本结构特征的相似性,利用样本的CNN特征构建密集连接实例图。每个节点对应样本的CNN特征,该特征由标准卷积网络提取。然后,将图卷积网络应用于实例图,并将图结构信息沿着设计的网络学习加权图的边缘进行传播以更新节点。本发明专利技术利用类均值构造类原型进行分类,还考虑了实例节点的比较监督学习,以学习实例节点上类语义信息。本发明专利技术为了更好地学习和减少领域之间类别语义信息的差异,采用软标签进行领域之间知识蒸馏。蒸馏。蒸馏。
【技术实现步骤摘要】
一种基于图原型网络和实例对比的领域泛化方法
[0001]本专利技术属于机器学习域泛化领域,具体涉及一种基于图原型网络和实例对比的领域泛化方法。
技术介绍
[0002]通常,大多数机器学习模型先在源域数据集上进行训练,然后将训练结果在目标域数据集上进行预测,其中往往隐含地假设源域数据集和目标域数据集都遵循相同的分布。然而,此类假设在现实世界中往往不能成立。例如,对于基于不同角度、设备、环境等条件收集的多领域图像,在一个领域上通过训练获得的分类器在其他领域的运用效果不佳。这里,将某一领域的知识迁移到其他不可见领域的过程被称为领域泛化。在迁移学习中,域泛化问题的困难主要来自于两方面,其一是不同源数据集的分布差异,其二是目标域的不可知性。域泛化旨将在源域数据集上通过训练获得的模型直接推广到具有不同分布的不可见目标域,而无需在目标域数据集上进行再训练或微调。领域泛化解决针对分布于未知域的图像的分类预测问题。源数据集和目标数据集的不匹配分布将导致源模型在目标域的性能显著下降。目前已提出的针对跨域视觉表示的分布对齐方法没有考虑到跨域内部数据结构的差异,并受制于不充分的对齐跨域表示。例如,深度对抗性自适应方法仅迫使全域分布的对齐,但可能会丢失每个类别的关键语义类标签信息,同时必须在训练中使用域标签进行监督学习。即便使用完美的混淆对齐,也不能保证特征空间中具有相同类标签的非同域样本的相邻映射。然而,对于传统的与数据结构分布对齐相关的方法,虽然可以减少域之间分布差异,并保留原始空间属性,但很难有效地模型化数据结构信息并集成到现有的深度网络中。
技术实现思路
[0003]基于图像原型网络和案例比较网络的图像分类方法的总体框架如图1所示。为了对深度网络下的数据结构进行建模,通过样本结构特征的相似性,利用样本的CNN特征构建密集连接实例图。每个节点对应样本的CNN特征,该特征由标准卷积网络(例如ResNet)提取。然后,将图卷积网络(GCN)应用于实例图,并将图结构信息沿着设计的网络学习加权图的边缘进行传播以更新节点。一方面,利用类均值构造类原型进行分类;另一方面,还考虑了实例节点的比较监督学习,以学习实例节点上类语义信息。同时,为了更好地学习和减少领域之间类别语义信息的差异,采用软标签(logit)进行领域之间知识蒸馏,即缩小Kullback
‑
Leibler(KL)散度。知识蒸馏将具有相同类别标签但不同域的数据的预测分布集合与每个预测分布相匹配,通过使用多个域累积的有意义误差的集合惩罚样本的预测来增加模型预测的熵,鼓励模型收敛到宽局部最小值。本专利技术提出的基于图原型网络和实例对比的领域泛化方法的具体步骤如下:
[0004]步骤1:获取图像样本及其标签,并构建图像特征提取模型;
[0005]获取图像样本构建初始图像数据集,将所述初始图像数据集划分为源域数据集M
={M1,...M
i
,...,M
m
}和目标域数据集T,其中M
m
表示第m域数据集;所述目标域数据集在所述图像特征提取模型的训练过程中是不可访问的;
[0006]源域数据集M划分为训练集和验证集,将所述源域数据集M中的图像进行数据增强;
[0007]获取预训练模型,基于预训练模型构建所述图像特征提取模型;
‑
通过所述图像特征提取模型,提取源域数据集M中的特征,作为图输入特征X;
[0008]步骤2:建立图卷积网络并获取类原型表示;
[0009]将提取源域数据集M的特征的图结构信息定义为G=<V,E,Z>,其中V={v1,...,v
n
}是n个节点的集合,是通过两层GCN层提取获得的节点特征,E={e
11
,...,e
ij
,...,e
nn
}表示节点之间距离;其中,采用余弦相似度表示节点i和节点j之间距离;
[0010]通过节点间距离E构造包含n个节点的无向图邻接矩阵A,将所述无向图邻接矩阵A转换其中,为度矩阵,j为节点i的邻接节点编号;
[0011]根据节点之间相似度,构建归一化后的邻接矩阵其中,I是单位矩阵;
[0012]对于一个给定的包含n个节点的无向图邻接矩阵A∈R
n
×
n
,图卷积的线性变换取决于图输入特征X∈R
k
×
n
与滤波器W∈R
k
×
d
;
[0013]其中,图输入特征X中的列向量X
i
∈R
k
是节点的集合V中第i个节点的特征表示,d表示输出的特征维度;
[0014]按如下式所示的方法,进行两层的GCN处理得到嵌入特征
[0015][0016]其中,σ为激活函数,表示为第i个节点在第l层的输出,是图卷积输入;
[0017]之后利用图卷积网络生成的嵌入特征计算类原型P∈R
c
×
d
的表示,表示图卷积网络的第m源域中第i个节点输出;
[0018]所述类原型的定义为被同一类的节点紧密包围,这样同一类的节点就可以表示自己的类;第m源域的第c类的原型通过以下方式计算:
[0019][0020]其中PROTP是计算类原型P的表示的方式,是第m源域中第i个节点的表示,m
c
为第m源域的第c类,v
i
为第m域的第c类的第i个节点,具体公式如下:
[0021][0022]一般计算类原型时假设每个类只使用一个原型来表示,但原型分布不是单峰时,这种类表示是不充分的。此时,每个类可以使用多原型来表示,并用置换不变函数代替均值(如K
‑
means聚类)。为了简便起见,按平均值进行计算。
[0023]将所述节点从原始嵌入空间投影到另一个距离空间来学习一个距离度量表示;
[0024]步骤3.通过比较节点的学习距离度量表示与类原型的距离度量表示进行分类;
[0025]计算距离度量损失:
[0026]由图卷积学习到的嵌入节点计算每个节点到每个类原型的距离度量表示:
[0027][0028]其中,为第m源域中每个节点和每个类原型之间嵌入差异;
[0029]将节点嵌入差异联系到所有类原型,并按如下式(5)所示的方法,应用线性变换f对嵌入差异的不同维度给予不同程度的关注,同时自适应地提取嵌入差异信息,如下式所示:;
[0030][0031]距离度量表示g表示节点v到所有类原型的距离信息,用于定义了第m源域中节点与所有类原型的相对位置,c∈C表示第c个类;按如下式所示的方法,将距离度量表示通过连接层concat连接起来,以计算在所有源域M中的类原型和节点的距离度量表示:
[0032]G=concat(g1,
…
,g
m
)(6)
[0033]然后计算第i个节点v
i本文档来自技高网...
【技术保护点】
【技术特征摘要】
1.一种基于图原型网络和实例对比的领域泛化方法,其特征在于,包括以下步骤:步骤1:获取图像样本及其标签,并构建图像特征提取模型;获取图像样本构建初始图像数据集,将所述初始图像数据集划分为源域数据集M={M1,...M
i
,...,M
m
}和目标域数据集T,其中M
m
表示第m域数据集;所述目标域数据集在所述图像特征提取模型的训练过程中是不可访问的;源域数据集M划分为训练集和验证集,将所述源域数据集M中的图像进行数据增强;获取预训练模型,基于预训练模型构建所述图像特征提取模型;通过所述图像特征提取模型,提取源域数据集M中的特征,作为图输入特征X;步骤2:建立图卷积网络并获取类原型表示;将提取源域数据集M的特征的图结构信息定义为G=<V,E,Z>,其中V={v1,...,v
n
}是n个节点的集合,是通过两层GCN层提取获得的节点特征,E={e
11
,...,e
ij
,...,e
nn
}表示节点之间距离;其中,采用余弦相似度表示节点i和节点j之间距离;通过节点间距离E构造包含n个节点的无向图邻接矩阵A,将所述无向图邻接矩阵A转换其中,为度矩阵,j为节点i的邻接节点编号;根据节点之间相似度,构建归一化后的邻接矩阵其中,I是单位矩阵;对于一个给定的包含n个节点的无向图邻接矩阵A∈R
n
×
n
,图卷积的线性变换取决于图输入特征X∈R
k
×
n
与滤波器W∈R
k
×
d
;其中,图输入特征X中的列向量X
i
∈R
k
是节点的集合V中第i个节点的特征表示,d表示输出的特征维度;按如下式所示的方法,进行两层的GCN处理得到嵌入特征进行两层的GCN处理得到嵌入特征其中,σ为激活函数,表示为第i个节点在第l层的输出,是图卷积输入;之后利用图卷积网络生成的嵌入特征计算类原型P∈R
c
×
d
的表示,表示图卷积网络的第m源域中第i个节点输出;所述类原型的定义为被同一类的节点紧密包围,这样同一类的节点就可以表示自己的类;第m源域的第c类的原型通过以下方式计算:其中PROTP是计算类原型P的表示的方式,是第m源域中第i个节点的表示,m
c
为第m源域的第c类,v
i
为第m域的第c类的第i个节点,具体公式如下:将所述节点从原始嵌入空间投影到另一个距离空间来学习一个距离度量表示;步骤3:通过比较节点的学习距离度量表示与类原型的距离度量表示进行分类;计算距离度量损失:由图卷积学习到的嵌入节点计算每个节点到每个类原型的距离度量
【专利技术属性】
技术研发人员:彭伟民,郭浩栋,
申请(专利权)人:杭州电子科技大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。