System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种模型剪枝方法、装置、设备及存储介质制造方法及图纸_技高网

一种模型剪枝方法、装置、设备及存储介质制造方法及图纸

技术编号:44133984 阅读:1 留言:0更新日期:2025-01-29 10:13
本发明专利技术公开了一种模型剪枝方法、装置、设备及存储介质,涉及人工智能技术领域,该方法包括:将测试数据集输入到训练完成的待剪枝网络模型中;针对所述待剪枝网络模型中的每个网络层,获取所述网络层基于各所述测试数据分别输出的网络向量集,并基于各所述测试标签和各所述网络向量集,对未训练完成的初始分类器进行训练得到与所述网络层对应的目标分类器;基于至少两个目标分类器分别对应的分类效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型,本发明专利技术实施例解决了传统的剪枝评估算法准确度不高的问题,在保证剪枝后的网络模型的模型性能的同时,尽可能提高了剪枝后的网络模型的推理速度。

【技术实现步骤摘要】

本专利技术涉及人工智能,尤其涉及一种模型剪枝方法、装置、设备及存储介质


技术介绍

1、深度学习模型越来越广泛地被应用在各种业务场景,如人脸识别、图像分类、图像分割等等。通过对业务数据进行预处理、对业务特征进行选择和加工、选择训练模型、调整训练模型的超参数、执行迭代训练等过程,得到能够解决实际业务需求的深度学习模型。

2、深度学习模型中可能存在冗余的网络层,即便将该网络层从深度学习模型中去除,模型性能几乎不会下降,这种现象被称为过参数化,而消除过参数化的技术被称为模型剪枝。

3、在实现本专利技术的过程中,发现现有技术中至少存在以下技术问题:

4、传统的模型剪枝方法对网络层采用的评估算法的准确度不高,导致剪枝后的网络模型的模型性能下降。


技术实现思路

1、本专利技术实施例提供了一种模型剪枝方法、装置、设备及存储介质,以解决传统的剪枝评估算法准确度不高的问题,在保证剪枝后的网络模型的模型性能的同时,尽可能提高剪枝后的网络模型的推理速度。

2、根据本专利技术一个实施例提供了一种模型剪枝方法,该方法包括:

3、将测试数据集输入到训练完成的待剪枝网络模型中;其中,所述测试数据集中包含至少两个测试标签分别对应的测试数据,所述待剪枝网络模型包括至少两个网络层;

4、针对每个网络层,获取所述网络层基于各所述测试数据分别输出的网络向量集,并基于各所述测试标签和各所述网络向量集,对未训练完成的初始分类器进行训练得到与所述网络层对应的目标分类器;

5、基于至少两个目标分类器分别对应的分类效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型。

6、根据本专利技术另一个实施例提供了一种模型剪枝装置,该装置包括:

7、测试数据集输入模块,用于将测试数据集输入到训练完成的待剪枝网络模型中;其中,所述测试数据集中包含至少两个测试标签分别对应的测试数据,所述待剪枝网络模型包括至少两个网络层;

8、初始分类器训练模块,用于针对每个网络层,获取所述网络层基于各所述测试数据分别输出的网络向量集,并基于各所述测试标签和各所述网络向量集,对未训练完成的初始分类器进行训练得到与所述网络层对应的目标分类器;

9、剪枝网络模型确定模块,用于基于至少两个目标分类器分别对应的分类效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型。

10、根据本专利技术另一个实施例提供了一种电子设备,该电子设备包括:

11、至少一个处理器;以及

12、与所述至少一个处理器通信连接的存储器;其中,

13、所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行本专利技术任一实施例所述的模型剪枝方法。

14、根据本专利技术另一个实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现本专利技术任一实施例所述的模型剪枝方法。

15、本专利技术实施例的技术方案,通过针对待剪枝网络模型中的每个网络层,获取网络层基于输入的至少两种测试数据分别输出的网络向量集,并基于各测试数据分别对应的测试标签和网络向量集,对未训练完成的初始分类器进行训练得到与网络层对应的目标分类器,基于至少两个目标分类器分别对应的分类效果数据,对待剪枝网络模型执行剪枝操作得到剪枝网络模型,解决了传统的剪枝评估算法准确度不高的问题,在保证剪枝后的网络模型的模型性能的同时,尽可能提高了剪枝后的网络模型的推理速度。

16、应当理解,本部分所描述的内容并非旨在标识本专利技术的实施例的关键或重要特征,也不用于限制本专利技术的范围。本专利技术的其它特征将通过以下的说明书而变得容易理解。

本文档来自技高网...

【技术保护点】

1.一种模型剪枝方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,所述测试数据的数据类型为订单数据,所述测试标签为订单风险标签,所述待剪枝网络模型为风险预测模型。

3.根据权利要求1所述的方法,其特征在于,所述基于至少两个目标分类器分别对应的分类效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型,包括:

4.根据权利要求3所述的方法,其特征在于,所述基于各所述差值效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型,包括:

5.根据权利要求4所述的方法,其特征在于,所述方法还包括:

6.根据权利要求3所述的方法,其特征在于,所述基于各所述差值效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型,包括:

7.根据权利要求1所述的方法,其特征在于,所述基于各所述测试标签和各所述网络向量集,对未训练完成的初始分类器进行训练得到与所述网络层对应的目标分类器,包括:

8.根据权利要求1-7任一项所述的方法,其特征在于,所述分类效果数据包括至少一个测试标签分别对应的标签分类精度,或者,所述分类效果数据包括所述目标分类器的总分类精度。

9.一种模型剪枝装置,其特征在于,包括:

10.一种电子设备,其特征在于,所述电子设备包括:

11.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现权利要求1-8中任一项所述的模型剪枝方法。

...

【技术特征摘要】

1.一种模型剪枝方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,所述测试数据的数据类型为订单数据,所述测试标签为订单风险标签,所述待剪枝网络模型为风险预测模型。

3.根据权利要求1所述的方法,其特征在于,所述基于至少两个目标分类器分别对应的分类效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型,包括:

4.根据权利要求3所述的方法,其特征在于,所述基于各所述差值效果数据,对所述待剪枝网络模型执行剪枝操作得到剪枝网络模型,包括:

5.根据权利要求4所述的方法,其特征在于,所述方法还包括:

6.根据权利要求3所述的方法,其特征在于,所述基于各所述差值效果数据,对所述待剪枝网络模型执行...

【专利技术属性】
技术研发人员:夏晓华郝应涛黄志翔
申请(专利权)人:京东科技控股股份有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1