System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种基于异构分支融合的知识蒸馏方法技术_技高网

一种基于异构分支融合的知识蒸馏方法技术

技术编号:44158126 阅读:2 留言:0更新日期:2025-01-29 10:29
本发明专利技术公开了一种基于异构分支融合的知识蒸馏方法,涉及计算机技术领域。本发明专利技术包括步骤1:获取原始数据集;步骤2:将原始数据集分为训练集和测试集,并进行预处理;步骤3:使用训练集对教师模型进行预训练,保存训练好的教师模型;步骤4:加载预训练权重,搭建对应的学生模型,并对学生模型进行训练;步骤5:在推理阶段,只保留目标分支作为基准模型。本发明专利技术能够克服教师知识单一和在师生能力差距较大时难以学习的局限性,并缓解多分支架构带来的均质化问题,且在没有额外成本增加的同时,有效提高模型的分类精度。

【技术实现步骤摘要】

本专利技术涉及计算机,具体为一种基于异构分支融合的知识蒸馏方法


技术介绍

1、深度神经网络在计算机视觉领域的众多任务中,如图像分类、目标检测及语义分割方面,均展现出了强大的性能。然而,这些高性能模型往往伴随着较大的参数量和计算成本,这在追求模型轻量化和高效部署的许多应用场景中构成了不小的挑战。为了平衡模型的参数量和性能,知识蒸馏技术应运而生。作为一种高效的知识迁移策略,知识蒸馏能够将大型复杂模型(教师模型)中的隐含知识提炼并传递给更为紧凑的轻量级模型(学生模型),有效提升这些轻量级模型在实际任务中的表现能力。

2、现有的知识蒸馏方法大致分为离线知识蒸馏和在线知识蒸馏。离线知识蒸馏采用两阶段训练方式:首先,预训练一个大型教师模型,然后将提取的知识转移到较小的学生模型中,以帮助学生学习教师模型中的复杂知识。在线知识蒸馏则采用单阶段训练方式,通过在训练过程中不断更新知识,直接优化目标模型,使学生模型充分利用来自多个输出的丰富信息。

3、在传统知识蒸馏方法中,学生很难完全学习教师提供的知识,原因是收敛后的教师模型与从头开始训练的学生模型之间存在较大的能力差距。此外,教师的固定知识无法充分提升学生的泛化能力,可以利用训练过程的实时信息作为知识来源。在线知识蒸馏策略的提出为解决这些问题提供了新的思路。


技术实现思路

1、本专利技术的目的在于提供一种基于异构分支融合的知识蒸馏方法,以克服传统知识蒸馏中,教师知识单一和在师生能力差距较大时难以学习的局限性,并缓解多分支架构带来的均质化问题。

2、为实现上述目的,本专利技术提供如下技术方案:一种基于异构分支融合的知识蒸馏方法,至少包括以下步骤:

3、步骤1:获取原始数据集;

4、步骤2:将原始数据集分为训练集和测试集,并进行预处理;

5、步骤3:使用训练集对教师模型进行预训练,保存训练好的教师模型;

6、步骤4:加载预训练权重,搭建对应的学生模型,并对学生模型进行训练;

7、步骤5:在推理阶段,只保留目标分支作为基准模型。

8、进一步的,所述步骤4至少包括以下步骤:

9、步骤4-1:冻结教师网络模型的预训练权重;

10、步骤4-2:构建基于异构分支的学生模型,采用递进特征融合模块逐步融合分支特征;

11、步骤4-3:使用协同学习目标和教师指导作为蒸馏损失来训练学生模型;

12、进一步的,所述步骤4-2至少包括以下步骤:

13、学生模型为多分支架构,其中每个分支在低级层共享参数,在高级层则具有独立的参数,作为多个并行训练的学生;

14、将第一个分支命名为branch-1,并设为目标分支,与低级层共同组成基准模型;

15、使用b来表示分支的总数量,b表示分支的序号b∈(1,2,...,b),并将第b个分支生成的特征表示为fb,递进融合模块生成的融合特征表示为fbm;

16、在目标分支的基础上逐步添加通道数为(b-1)·c的辅助块,来构建学生模型中的异构分支,其中c为高级层的通道数;

17、设每个递进融合特征模块有两个输入入口和一个输出出口;

18、将第一个分支即目标分支的特征图作为首次输入的融合特征,即f1m=f1;

19、在递进特征融合模块中,首先将第b个分支的特征fb和前一层次的融合特征进行初始融合;

20、然后,初始融合后的特征送入两条路径处理:

21、一条路径包含两个卷积核为1x1的卷积层,用于捕获局部信息;

22、另一条路径则通过额外的全局池化层来捕获全局信息;

23、将两条路径捕获的特征信息,通过sigmoid函数进行激活得到权重分数,并分别与特征fb,相乘,使用参数w来调节分支的重要程度,最终得到这一层次的融合特征

24、进一步的,采用所述递进特征融合模块的表达式为:

25、

26、其中d(·)表示融合函数。

27、进一步的,所述步骤4-3具体为:

28、通过加入超参数温度τ的softmax函数来软化logits,构建知识蒸馏损失函数;

29、教师和学生软化后的输出分别为和通过kullback-leibler散度计算蒸馏损失lkl;

30、结合真实标签与学生预测的交叉熵损失函数,以及平衡参数,构建完整的知识蒸馏损失函数lkd;

31、根据知识蒸馏损失函数,构建协同学习目标和教师指导两种损失,并通过ramp-up函数来控制知识蒸馏在训练过程中的权重,最终得到总损失对学生模型进行训练;

32、所述的协同学习目标和教师指导两种损失具体为:

33、协同学习目标由两部分组成:

34、(1)将融合特征进行输出,得到fl;

35、(2)将分支的logits进行集成,得到el,然后对fl和el进行软化处理分别得到和并计算蒸馏损失;

36、教师指导是通过计算教师模型与每个分支的蒸馏损失,教师和学生软化后的输出分别为和

37、进一步的,所述的协同学习目标损失函数和教师指导损失函数表示为:

38、

39、

40、其中lcl表示协同学习目标损失,ltg表示教师指导损失。τ2表示权重参数,用于平衡交叉熵损失和蒸馏损失;

41、所述的完整损失函数表示为:

42、

43、其中ltotal表示完整的损失函数,lce表示交叉熵损失函数。w(i)表示ramp-up函数,用于调节在不同epoch的蒸馏损失权重。

44、进一步的,所述步骤5至少包括以下步骤:

45、在学生推理阶段,去除预训练好的教师模型和增加辅助块的分支,只保留目标分支作为原始模型架构,没有额外成本增加。

46、与现有技术相比,本专利技术的有益效果是:

47、本专利技术能够克服教师知识单一和在师生能力差距较大时难以学习的局限性,并缓解多分支架构带来的均质化问题,且在没有额外成本增加的同时,有效提高模型的分类精度。

本文档来自技高网...

【技术保护点】

1.一种基于异构分支融合的知识蒸馏方法,其特征在于,至少包括以下步骤:

2.根据权利要求1所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述步骤4至少包括以下步骤:

3.根据权利要求2所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述步骤4-2至少包括以下步骤:

4.根据权利要求3所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:采用所述递进特征融合模块的表达式为:

5.根据权利要求2所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述步骤4-3具体为:

6.根据权利要求5所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述的协同学习目标损失函数和教师指导损失函数表示为:

7.根据权利要求6所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述步骤5至少包括以下步骤:

【技术特征摘要】

1.一种基于异构分支融合的知识蒸馏方法,其特征在于,至少包括以下步骤:

2.根据权利要求1所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述步骤4至少包括以下步骤:

3.根据权利要求2所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:所述步骤4-2至少包括以下步骤:

4.根据权利要求3所述的一种基于异构分支融合的知识蒸馏方法,其特征在于:采用...

【专利技术属性】
技术研发人员:李刚吕鹏飞徐传运蒋建忠阮子涵樊昕宇汪儒周正谭委邓江林周春宇
申请(专利权)人:重庆理工大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1