一种基于知识蒸馏的轻量化花卉识别方法技术

技术编号:34183146 阅读:9 留言:0更新日期:2022-07-17 13:35
本发明专利技术公开了一种基于知识蒸馏的轻量化花卉识别方法,包括以下步骤:S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;S2.选定教师网络和学生网络;S3.对教师网络初始化和训练,得到成熟的教师网络;S4.对学生网络进行初始化;S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;S6.将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。本发明专利技术使得轻量级花卉识别模型在模型大幅压缩的同时还能保持较高的准确率。能保持较高的准确率。能保持较高的准确率。

【技术实现步骤摘要】
一种基于知识蒸馏的轻量化花卉识别方法


[0001]本专利技术涉及花卉识别,特别是涉及一种基于知识蒸馏的轻量化花卉识别方法。

技术介绍

[0002]在农、林业发展中,花卉种类的快速准确鉴别具有重要的意义。传统的花卉识别方法易受到花卉形态多样性、背景环境复杂性及光照条件多变性的影响,其准确率与泛化性能有待提升。而深层卷积神经网络(Deep convolutional neural network,DCNN)在高速计算设备的辅助下可以自动学习视觉目标语义特征的特点,解决了复杂环境下的视觉目标的鲁棒性识别问题,在花卉识别应用中具有较大潜力。但在实际应用中,人们更希望能够利用便携式设备及时获得花卉的种类信息,从而在数据产生地点实时进行分析,以便于最有效地对花卉资源进行开发利用。因此在算力弱、存储成本高但是便于携带的AI边缘计算设备上高效运行DCNN花卉分类模型对于户外实时花卉识别具有重大的研究价值与意义。目前,相关研究人员已构建出多种CNN模型来进行花卉的识别;
[0003]为了追求更好的分类效果,大多数的网络模型结构变得愈发庞杂。虽然相关任务准确率得到了提升,但通过加深网络来提高准确率会增加较大的参数量,导致网络的运算量增加,需要花费极大的运算资源,使得其难以应用到AI边缘计算设备上。轻量级DCNN模型的优势主要在于构建出更加高效的卷积网络计算方式,在模型大幅压缩的同时兼顾良好的网络性能。
[0004]相较于重量级网络而言,轻量级网络的预测时间、运算力需求以及模型储存占用量都得到了极大减少,使得该类网络更加适合于移动平台的应用。但是经过实验对比发现,轻量级网络在识别的准确率上和重量级网络还有明显的差距。

技术实现思路

[0005]本专利技术的目的在于克服现有技术的不足,提供一种基于知识蒸馏的轻量化花卉识别方法,使用知识蒸馏的算法,利用重量级网络辅助训练轻量级网络,在模型大幅压缩的同时尽量减低准确率方面的损失,以此得到一个模型大幅压缩而且保持较高准确率的轻量级花卉识别模型。
[0006]本专利技术的目的是通过以下技术方案来实现的:一种基于知识蒸馏的轻量化花卉识别方法,包括以下步骤:
[0007]S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;
[0008]所述花卉数据集中包含m张花卉图片,根据每一张花卉图片的花类别,构建该图片的真实标签;所述真实标签由N个数字构成数组:若花卉图片属于第n个花类别,则真实标签的第n个数字为1,其余数字为0;花卉数据集中共有N个花类别,即花卉数据集中共有N个不同的真实标签;并且在所述花卉数据集中,每个花类别具有至少两张花卉图片;
[0009]在本申请的实施例中,所使用的花卉数据集为牛津大学制作并提供公开下载的Oxford

Flower102数据集或Oxford

Flower17数据集。其中Oxford

Flower102数据集包含
102个花类别,每个类包含40到258个图片,共8189张图片;Oxford

Flower17数据集,包含17个花类别,每个类别80张图片,共1360张图片。
[0010]将花卉数据集划分为训练集和测试集,并使得训练集和测试集均包含N个花类别的花卉图片;
[0011]S2.选定教师网络和学生网络;
[0012]S3.对教师网络初始化和训练,得到成熟的教师网络;
[0013]S4.对学生网络进行初始化;
[0014]S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;
[0015]S6.将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。
[0016]其中,所述步骤S2中,选定一个模型较大准确率较高的神经网络作为教师网络,模型较小准确率较低的神经网络作为学生网络;
[0017]所述模型较大准确率较高的神经网络包括SeNet152网络或MobilNetV3

Large网络;
[0018]所述模型较小准确率较低的神经网络包括MobilNetV3

Small网络。
[0019]其中,所述步骤S3包括:
[0020]S301.教师网络加载预先设定的ImageNet预训练权重(ImageNet预训练权重为由Pytorch官方提供的ImageNet预训练权重),并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
[0021]利用新建的全连接层替换教师网络原有的最后一个连接层,完成教师网络的初始化;当图片输入教师网络时,教师网络的全连接层输出的是:该图片为各个花卉类别的概率;
[0022]S302.对于训练集中的任一张图片,将该图片输入教师网络做前向运算得到教师网络的输出y:
[0023]设教师网络共有K层,其中第i层的输入输出表示为
[0024]y
i
=σ
i
(x
i
*w
i
+b
i
)
[0025]其中i=1,2,

K;y
i
表示教师网络的第i层输出,x
i
表示教师网络的第i层的输入,σ
i
表示教师网络第i层所用的激活函数;设教师网络最后一层的输出为y,教师网络最后一层的输出也叫作教师网络的输出,其中包含了输入图片为各个花类别的概率;
[0026]通过CrossEntropyLoss函数计算y和真实标签label之间的硬损失L
hard
_t,
[0027]L
hard
_t=CrossEntroyLoss(y,lable)
[0028]其中,label表示当前输入图片的真实标签,
[0029]使用L
hard
_t对教师网络进行反向传播并结合Adam优化器,更新教师网络的参数:
[0030]W
i
,B
i
=Adam(L
hard_t
,w
i
,b
i
,lr)
[0031]其中,Adam优化器表示为Adam函数,w
i
,b
i
表示教师网络第i层更新前的参数,W
i
,B
i
表示教师网络第i层更新后的参数,lr为学习率;
[0032]S303.对于训练集的每一张图片,重复执行步骤S302,对教师网络参数进行更新,所有图像下的更新完成时,得到训练后的教师网络;
[0033]S304.对于测试集每一张图片,将该图片输入S303训练后的教师网络做前向运算得到教师网络的预测输出y,将y和真实标签对本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:包括以下步骤:S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;所述花卉数据集中包含m张花卉图片,根据每一张花卉图片的花类别,构建该图片的真实标签;所述真实标签由N个数字构成数组:若花卉图片属于第n个花类别,则真实标签的第n个数字为1,其余数字为0;花卉数据集中共有N个花类别,即花卉数据集中共有N个不同的真实标签;并且在所述花卉数据集中,每个花类别具有至少两张花卉图片;将花卉数据集划分为训练集和测试集,并使得训练集和测试集均包含N个花类别的花卉图片;S2.选定教师网络和学生网络;S3.对教师网络初始化和训练,得到成熟的教师网络;S4.对学生网络进行初始化;S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;S6.将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。2.根据权利要求1所述的一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:所述步骤S2中,选定一个模型较大准确率较高的神经网络作为教师网络,模型较小准确率较低的神经网络作为学生网络;所述模型较大准确率较高的神经网络包括SeNet152网络或MobilNetV3

Large网络;所述模型较小准确率较低的神经网络包括MobilNetV3

Small网络。3.根据权利要求1所述的一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:所述步骤S3包括:S301.教师网络加载预先设定的ImageNet预训练权重,并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;利用新建的全连接层替换教师网络原有的最后一个连接层,完成教师网络的初始化;当图片输入教师网络时,教师网络的全连接层输出的是:该图片为各个花卉类别的概率;S302.对于训练集中的任一张图片,将该图片输入教师网络做前向运算得到教师网络的输出y:设教师网络共有K层,其中第i层的输入输出表示为y
i
=σ
i
(x
i
*w
i
+b
i
)其中i=1,2,

K;y
i
表示教师网络的第i层输出,x
i
表示教师网络的第i层的输入,σ
i
表示教师网络第i层所用的激活函数;设教师网络最后一层的输出为y,教师网络最后一层的输出也叫作教师网络的输出,其中包含了输入图片为各个花类别的概率;通过CrossEntropyLoss函数计算y和真实标签label之间的硬损失L
hard
_t,L
hard
_t=CrossEntroyLoss(y,lable)其中,label表示当前输入图片的真实标签,使用L
hard
_t对教师网络进行反向传播并结合Adam优化器,更新教师网络的参数:W
i
,B
i
=Adam(L
hard_t
,w
i
,b
i
,lr)
其中,Adam优化器表示为Adam函数,w
i
,b
i<...

【专利技术属性】
技术研发人员:韦旭东张红雨李博史长凯韩欢钟山王曦
申请(专利权)人:电子科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1