当前位置: 首页 > 专利查询>四川大学专利>正文

一种基于自蒸馏的图像分类方法技术

技术编号:38770283 阅读:23 留言:0更新日期:2023-09-10 10:43
本发明专利技术公开了一种基于自蒸馏的图像分类方法,涉及计算机视觉和人工智能领域。方法包括:(1)将残差网络分为四部分,在每部分后结合瓶颈层和全连接层设置一个分类器,共四个分类器,浅层分类器通过蒸馏法训练,同时使用标签和最深层分类器监督,并通过最深层分类器的特征监督浅层分类器的特征,将下采样提前,其中,残差连接处的下采样使用平均池化代替;(2)在(1)所述框架中的第一个残差块和第二个残差块之间添加位置注意力模块;(3)在(1)所述框架中引入由SPC模块和SE Weight模块构成的注意力模块。本发明专利技术能够有效地将提取图像特征,提高了图像分类的准确率,在图像分析等领域具有开阔的应用前景。阔的应用前景。

【技术实现步骤摘要】
一种基于自蒸馏的图像分类方法


[0001]本专利技术涉及一种基于自蒸馏的图像分类方法,属于计算机视觉领域。

技术介绍

[0002]图像分类是计算机视觉领域方向的研究热点之一,也是实现目标检测、姿态估计、人脸识别等应用的重要基础,因此图像分类具有很高的学术研究价值。图像分类,即输入一幅图片,通过分类算法来判断该图片的类别。近年来,图像分类领域已取得较好的成就,但是许多应用场景下数据收集困难,无法满足网络对数据量的要求。
[0003]对于的图像分类,由于数据量较少,当前大多数方法采用基于元学习的方法和基于迁移学习的方法,而基于元学习的方法具有较好的性能和较强的泛化能力,但是其需要提取足够准确的元知识,由于人们对元知识的理解不够充分,目前提取的元知识存在欠缺。基于迁移学习的方法中,较为典型的是知识蒸馏,但是传统的知识蒸馏需要耗费大量时间和数据预训练一个教师模型,需求大量的时间成本和数据量,而且存在两个问题,一个问题是知识转移效率低,学生模型很难优于教师模型;另一个问题是设计一个合适的教师网络需要大量的努力和实验。

技术实现思路

[0004]为了解决现有技术的不足,本专利技术提出了一种基于自蒸馏的图像分类方法,目的在于设计类似多分类器体系的包含位置注意力模块和金字塔拆分注意力模块的自蒸馏框架,提高图像分类的准确率。
[0005]本专利技术采用以下技术方案:一种基于自蒸馏的图像分类方法,该方法包括以下步骤:
[0006](1)将残差网络(以下以ResNet50为例)分为四部分,在每部分后结合瓶颈层和全连接层设置一个分类器,共四个分类器,浅层分类器通过蒸馏法训练,同时使用标签和最深层分类器监督,并通过最深层分类器的特征监督浅层分类器的特征,并将ResNet50中的7*7卷积改为3个3*3卷积,将所有下采样提前,其中,残差连接处的下采样使用平均池化代替;
[0007](2)在(1)所述框架中的第一个残差块和第二个残差块之间添加位置注意力模块;
[0008](3)在(1)所述框架中将瓶颈层中的3*3卷积改为金字塔拆分注意力模块。
[0009]与现有技术相比,本专利技术的有益效果在:
[0010]1、本专利技术采用自监督框架,浅层分类器通过蒸馏法训练,而不是仅仅通过标签进行训练,这使得自蒸馏法的准确率更高,同时,因为网络不需要预训练,节约了大量时间成本和数据成本;
[0011]2、本专利技术加入位置注意力模块,通过把位置信息嵌入到通道注意力,使网络获取了更大区域的信息同时避免了较大的计算开销;
[0012]3、本专利技术加入金字塔注意力拆分模块,将校正后的注意力向量作用于多尺度特征图并将结果作为输出,使得输出具有丰富的多尺度信息;
[0013]4、本专利技术采用多个小卷积并将下采样提前,采用多个小卷积在感受野大小不变的前提下减少参数并且有更多的非线性,使得判决函数更加具有判决性,且多个小卷积可以表达出数据集中更多的强力特征,下采样提前可以减少信息丢失。
附图说明
[0014]图1为本专利技术图像分类方法整体网络框图;
[0015]图2为本专利技术位置注意力模块结构图;
[0016]图3为本专利技术金字塔注意力拆分模块结构图;
[0017]图4为本专利技术SPC模块结构图;
[0018]图5为本专利技术SE Weight模块结构图。
具体实施方式
[0019]为了使本专利技术的目的、技术方案及优点更加清楚明白,以下结合附图,对本专利技术进一步详细说明。应当理解,此处所描述具体实施方式仅仅用以解释本专利技术,并不用于限定本专利技术。
[0020]如图1所示,一种基于自蒸馏的图像分类方法,包括以下部分:
[0021](1)通过对ResNet50的4个残差块的前3个残差块添加瓶颈层及全连接层,每个残差块和残差块后的瓶颈层及全连接层各自构成一个分类器,共4个分类器,在第一个残差块和第二个残差块之间添加位置注意力模块;
[0022](2)将ResNet50中的7*7卷积改为3个3*3卷积,将所有下采样提前,其中,残差连接处的下采样使用平均池化代替,将瓶颈层中的3*3卷积改为金字塔拆分注意力模块;
[0023](3)该网络包含三个损失,分别为标签和分类器之间的交叉熵损失、学生和教师网络之间的KL散度损失和用于计算最深层分类器和浅层分类器之间特征图差异的二次损失。
[0024]详细阐述如下:
[0025]1.关于整个网络框架及损失函数,具体说明如下:
[0026]自蒸馏框架将ResNet50分为四部分,在每部分后结合瓶颈层和全连接层设置分类器,采用类似于深度监督网络的多分类器体系结构,不同之处在于浅层分类器通过蒸馏法训练,而不是仅仅通过标签进行训练,这使得自蒸馏法的准确率更高。自蒸馏法不需要提前预训练模型,节约了大量时间成本和数据成本。
[0027]该框架中,浅段结合的瓶颈层和全连接层只用于训练,瓶颈层主要是为了减轻每个浅分类器之间的影响。训练时,将所有具有对应非分类器的浅层部分从最深层的部分提炼出来作为学生模型。
[0028]从M个类中给定N个例子对应的标签为y
i
∈{1,2,...,M},神经网络中的分类器记为C为神经网络中分类器的数量,在每个分类器后面设置一个softmax层:其中,z是全连接层后的输出,T表示蒸馏温度,表示分类器θ
c/C
的分为i类的概率。
[0029]为了提高学生模型的性能,在训练过程中引入如下三种损失:
[0030](1)交叉熵损失,用于计算标签和各个分类器之间的损失,即:
[0031](1

α)
·
Cross Entropy(q
i
,y)
[0032](2)KL散度(Kullback

Leibler Divergence,KLD),计算学生和教师网络之间的损失,即:
[0033]α
·
KL(q
i
,q
c
)
[0034](3)二次损失(L2

norm loss function,L2),用于计算最深分类器和每个浅层分类器的特征图之间的损失,即:
[0035][0036]其中,二次损失使用最深层分类器即教师模型隐含层的输出指导学生模型的学习,减少特征映射在浅分类器和最深分类器之间的距离。最终,神经网络的总损失为:
[0037][0038]2.关于位置注意力模块,具体说明如下:
[0039]如图2所示,首先利用两个全局池化将垂直方向和水平方向的输入特征分别聚合成两个单独的位置感知,使用连接和卷积压缩通道,然后通过BN和Non

linear将具有嵌入的特定方向信息的两个特征图分别编码,得到两个注意力图,每个注意力图都沿同一个空间方向获取输入特征图的远距离依存关系。位置信息被保存在生成的注意力图中,然后通过乘法将两个注意力图都本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于自蒸馏的图像分类方法,其特征在于,包括以下步骤:(1)将残差网络(以下以ResNet50为例)分为四部分,在每部分后结合瓶颈层和全连接层设置一个分类器,共四个分类器,浅层分类器通过蒸馏法训练,同时使用标签和最深层分类器监督,并通过最深层分类器的特征监督浅层分类器的特征,并将下采样提前,其中,残差连接处的下采样使用平均池化代替;(2)在(1)所述框架中的第一个残差块和第二个残差块之间添加位置注意力模块;(3)在(1)所述框架中引入由SPC模块和SE Weight模块构成的注意力模块。2.根据权利要求1所述的一种基于自蒸馏的图像分类方法,其特征在于,所述步骤(1)中网络为:首先,通过对ResNet50的4个残差块的前3个残差块添加瓶颈层及全连接层,每个残差块和残差块后的瓶颈层及全连接层各自构成一个分类器,共4个分类器;其次,将ResNet50中的7*7卷积改为3个3*3卷积,将所有下采样提前,其中,残差连接处的下采样使用平均池化代替;最后,该网络包含三个损失,分别为标签和分类器之间的交叉熵损失、学生和教师网络之间的KL散度损失和用于计算最深层分类器和浅层分类器之间特征图差异的二次损失。3.根据权利要求1所述的一种基于自蒸馏的图像分类方法,其特征在于,所述步骤(2)中在第一个残差块和第二个残差块之间引入了位置注意力模块,该模块网络结构为:首先利用两个全局池化将垂直方向和水平方向的输入特征分别聚合成两个单独的位置感知,使用连接和卷积压缩通道,然后通过BN和Non

linear将具有嵌入的特定方向信息的两个特征图分别编码,得到两个注意力图,每个注意力图都沿同一个空间方向获取输入特征图的远距离依存关系;位置信息被保存在生成的注意力图中,然后通过乘法将两个注意力图都应用于输入特征图,以此强调注意区域。4.根据权利要求1所述的一种基于自蒸馏的图像分类方法,其特征在于,所述步骤(3)中将原始ResNet瓶颈层中的3*3卷积改为由SPC模块和SE Weight模块构成的注意力模块;首先,利用SPC模块构建多尺度特征;然后,通过SE方式得到通道级注意力向量已提取不同尺度特征;其次,采用Softmax对上述所得通道注意力向量进行重校正;最后,将校正后的注意力向量作用于多尺度特征图并将结果作为输出;所述步骤(3)中的SPC模块如下:假设输入为X,现将其拆分为S部分{X0,X1,...,X
S
‑1},每个部分通道数为第i个特征映射尺寸为X
i
∈R
C
′×
H
×
W
,如公式1所示:Split(X)=[X0,X1,...,X
S
‑1]
ꢀꢀꢀꢀ
(1)为了在不增加计算量的前提下处理不同核尺度的输入张量,采用了一种群卷积方法,其中,多尺度核大小与组大小的关系如公式2所示:其中,K是卷积核大小,G为组大...

【专利技术属性】
技术研发人员:何小海李雨婷刘强曾王明卿粼波陈洪刚吴晓红
申请(专利权)人:四川大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1