本发明专利技术公开了一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,在知识蒸馏的实现中,除了引入蒸馏损失用于学生模型拟合教师模型的logits输出,还使用输出概率分布与真实标签的交叉熵损失,以确保学生模型的输出与样本的真实标签相互匹配。这两部分损失共同构成目标函数,帮助学生模型从教师模型的“暗知识”中进行学习,优化模型的输出概率分布,从而提高剪枝模型的准确率。另外,本发明专利技术将知识蒸馏应用于LSTM模型的剪枝过程中,通过合理传递知识,使得剪枝后的模型具备更强的表征能力。使得剪枝后的模型具备更强的表征能力。使得剪枝后的模型具备更强的表征能力。
【技术实现步骤摘要】
基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法
[0001]本专利技术涉及人工智能
,尤其涉及一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法。
技术介绍
[0002]深度神经网络压缩和加速是针对资源受限环境下的深度神经网络模型进行优化的一系列技术和方法。这些技术的目标是减小模型大小、加快推理速度,并降低计算资源的消耗,从而使得模型能够高效地部署在移动设备、嵌入式系统和边缘设备等场景中。其中,权重剪枝是一种深度神经网络压缩的方法,它通过去除模型中冗余的、贡献较小的连接或参数来减小模型的大小。剪枝可以基于权重敏感性或梯度敏感性等标准进行,在训练前、训练中或训练后进行。
[0003]粗粒度剪枝和细粒度剪枝是深度神经网络压缩中常用的两种剪枝方法。粗粒度剪枝是一种相对较粗的剪枝方法,它通常在网络的层级或模块级别进行参数修剪。具体来说,该方法会选择整个层或模块中的一部分参数进行修剪,而不是单独剪枝每个参数。这样的剪枝方式使得整个层或模块中的一部分连接被去除,从而减小了网络的规模。由于粗粒度剪枝选择的是整个层或模块的一部分参数,因此它可能会导致一些信息损失,同时也可能降低模型的精度。但它的优点在于其相对简单和快速的操作,适用于硬件加速和高效推理;细粒度剪枝是一种更细致的剪枝方法,它会针对每个参数进行选择性修剪。具体来说,该方法会根据参数的敏感度或重要性,选择性地修剪网络中的一些参数,而保留其他参数。
[0004]细粒度剪枝相对于粗粒度剪枝来说,更加精细和精确,因为它可以更好地保留重要的连接和参数,减少信息损失,同时还有可能更好地保持模型的精度。然而,使用粗粒度剪枝容易导致模型精度下降,特别是在需要高压缩率的情况下。即使对于压缩效果较好的细粒度剪枝,当修剪参数的比例过大时,模型的精度仍然可能会降低到不足的水平。
[0005]知识蒸馏是另一种用于深度神经网络压缩的方法,通常通过将教师网络的输出(logits)作为模型内部隐藏的“暗知识”传递给较小的学生网络,从而使学生网络能够逼近教师网络的性能。相比于直接使用one
‑
hot标签进行训练,知识蒸馏策略在提高小型网络准确率方面表现更优。通过知识蒸馏,学生网络可以获得比直接使用one
‑
hot标签训练更多的信息,从而在相对较小的网络结构下,实现更好的准确率。这使得知识蒸馏成为一种有力的神经网络压缩方法,特别适用于资源受限的环境,如移动设备和嵌入式系统等。
[0006]通常情况下,过参数化的原始模型具有较强的学习和表达能力。然而,在进行剪枝操作后,网络规模变小或受到一定的约束,自身的学习能力可能无法获得复杂的表征,因此即使通过微调也难以完全恢复到原始模型的精度。传统的迭代剪枝方法中,修剪后的模型会经过微调来恢复模型的精度。然而,对于过度细粒度剪枝模型或粗粒度剪枝模型,微调可能会面临一定的困难,难以完全恢复其精度。
技术实现思路
[0007]为解决现有技术存在的局限和缺陷,本专利技术提供一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,包括:
[0008]步骤S1、根据得到的数据集训练长短期记忆模型,获得具有预设的泛化能力的原始模型,保存所述原始模型;
[0009]步骤S2、设置剪枝参数,所述剪枝参数包括权重剪枝方法、稀疏度的初始值、稀疏度的期望值;
[0010]步骤S3、根据所述权重剪枝方法评估连接或权重块的重要性,排序后根据所述稀疏度确定修剪比例,根据所述修剪比例将对应的参数置零,同时禁止已经置零的参数进行更新,得到剪枝模型;
[0011]步骤S4、使用知识蒸馏方法对所述剪枝模型进行训练,将所述原始模型作为教师,将所述剪枝模型作为学生,通过在损失函数中加入蒸馏损失,使得学生模型拟合教师模型的logits输出,迭代训练预设的次数之后,得到精度恢复的模型;
[0012]步骤S5、评估所述精度恢复的模型的精度,调整所述稀疏度,根据预设的精度损失范围增减所述稀疏度,返回步骤S3继续剪枝,直至达到所述稀疏度的期望值或满足预设的终止条件。
[0013]可选的,还包括:
[0014]获取在预设的任务上进行预设的微调的BERT模型;
[0015]将所述BERT模型作为教师,将所述剪枝模型作为学生,使用知识蒸馏方法对所述剪枝模型进行训练。
[0016]可选的,还包括:
[0017]使用均方误差损失直接比较logits输出的结果差异,以计算所述蒸馏损失,所述蒸馏损失的表达式如下:
[0018][0019]其中,z
T
为所述教师模型的logits输出,z
S
为所述学生模型的logits输出,n是预测类别的个数。
[0020]可选的,还包括:
[0021]使用输出概率分布与真实标签之间的交叉熵损失作为目标函数的一部分,最终的损失函数的表达式如下:
[0022][0023]其中,y
S
是所述学生模型预测的输出概率分布,由logits经过softmax函数得到;当样本来自原始的标记数据集时,t为标记的真值标签;当样本来自数据增强生成的数据集时,将BERT模型的预测结果作为真值标签;α为权重超参数。
[0024]本专利技术具有下述有益效果:
[0025]本专利技术在知识蒸馏的实现中,除了引入蒸馏损失用于学生模型拟合教师模型的logits输出,还使用输出概率分布与真实标签的交叉熵损失,以确保学生模型的输出与样本的真实标签相互匹配。这两部分损失共同构成目标函数,帮助学生模型从教师模型的“暗知识”中进行学习,优化模型的输出概率分布,从而提高剪枝模型的准确率。另外,本专利技术将知识蒸馏应用于LSTM模型的剪枝过程中,通过合理传递知识,使得剪枝后的模型具备更强的表征能力。
附图说明
[0026]图1为本专利技术实施例一提供的基于原始模型的压缩方法的流程图。
[0027]图2为本专利技术实施例一提供的基于BERT模型的压缩方法的流程图。
[0028]图3为本专利技术实施例一提供的知识蒸馏架构示意图。
[0029]图4为本专利技术实施例一提供的针对单句文本分类任务的BiLSTM网络示意图。
[0030]图5为本专利技术实施例一提供的针对句子对匹配任务的BiLSTM网络示意图。
[0031]图6为本专利技术实施例一提供的在句子对匹配任务微调BERT
‑
base的流程图。
[0032]图7为本专利技术实施例一提供的压缩模型与原始模型在单句分类任务上功耗和推理时间上的比较示意图。
具体实施方式
[0033]为使本领域的技术人员更好地理解本专利技术的技术方案,下面结合附图对本专利技术提供的基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法进行详细描述。
[0034]实施例一
[0035]本实施例旨在提升长短期记忆(Long Short
‑
Term Memory,LSTM)剪枝后模型的精度,为此引入了知识蒸馏的方法。通常情况下,过参数化的本文档来自技高网...
【技术保护点】
【技术特征摘要】
1.一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,其特征在于,包括:步骤S1、根据得到的数据集训练长短期记忆模型,获得具有预设的泛化能力的原始模型,保存所述原始模型;步骤S2、设置剪枝参数,所述剪枝参数包括权重剪枝方法、稀疏度的初始值、稀疏度的期望值;步骤S3、根据所述权重剪枝方法评估连接或权重块的重要性,排序后根据所述稀疏度确定修剪比例,根据所述修剪比例将对应的参数置零,同时禁止已经置零的参数进行更新,得到剪枝模型;步骤S4、使用知识蒸馏方法对所述剪枝模型进行训练,将所述原始模型作为教师,将所述剪枝模型作为学生,通过在损失函数中加入蒸馏损失,使得学生模型拟合教师模型的logits输出,迭代训练预设的次数之后,得到精度恢复的模型;步骤S5、评估所述精度恢复的模型的精度,调整所述稀疏度,根据预设的精度损失范围增减所述稀疏度,返回步骤S3继续剪枝,直至达到所述稀疏度的期望值或满足预设的终止条件。2.根据权利要求1所述的基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,其特征在于,还包括:获取在预设的任务上...
【专利技术属性】
技术研发人员:王思野,李元东,赵中原,梁步顺,徐文波,赖锦林,麦吉,
申请(专利权)人:北京邮电大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。