本公开公开了一种模型获取方法、装置、电子设备、存储介质和程序产品,涉及计算机视觉和深度学习技术领域。具体实现方案为:获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应;依据所述M个第一软标签对超网络中的子网络的中间节点的连接参数进行更新,得到目标模型;其中,所述子网络的第i个中间节点是基于第i个目标连接层对应的第一软件标签进行更新,所述第i个中间节点位于所述超网络包括的M个连接层中的第i个连接层。本公开可以提高更新得到的目标模型与超网络之间性能的一致性。
【技术实现步骤摘要】
模型获取方法、装置、电子设备、存储介质和程序产品
本公开涉及计算机
,尤其涉及计算机视觉和深度学习技术等人工智能领域。
技术介绍
随着深度学习的不断发展,其在众多领域都取得了巨大的成功,且逐渐向全自动机器学习发展。例如,神经网络结构搜索技术(NeuralArchitectureSearch,NAS)作为全自动机器学习的研究热点之一,通过设计高效的搜索方法,自动获取泛化能力强,硬件要求友好的神经网络,大量的解放了相关研究人员的创造力。传统的NAS方法需要独立采样并评估模型结构的性能,这种方式会造成很大的性能开销。为降低性能开销,基于梯度的超网络训练方法得以研究。其中,超网络可以适用于多种不同的网络结构应用。基于梯度的超网络训练方法,在超网络训练过程中,逐步删除权重最低的连接,随着连接的逐步删除,搜索空间会逐步减少,最终收敛到最优的结构中。
技术实现思路
本公开提供了一种模型获取方法、装置、电子设备、存储介质和程序产品。根据本公开的一方面,提供了一种模型获取方法,包括:获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应;依据所述M个第一软标签对超网络中的子网络的中间节点的连接参数进行更新,得到目标模型;其中,所述子网络的第i个中间节点是基于第i个目标连接层对应的第一软件标签进行更新,所述第i个中间节点位于所述超网络包括的M个连接层中的第i个连接层。根据本公开的另一方面,提供了一种模型获取装置,包括:获取模块,用于获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应;更新模块,用于依据所述M个第一软标签对超网络中的子网络的中间节点的连接参数进行更新,得到目标模型;其中,所述子网络的第i个中间节点是基于第i个目标连接层对应的第一软件标签进行更新,所述第i个中间节点位于所述超网络包括的M个连接层中的第i个连接层。根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开提供的模型获取方法。根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行本公开提供的模型获取方法。根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开提供的模型获取方法。根据本公开的技术方案,通过基于参考模型中的第一软标签,对超网络进行更新,这样,可以提高更新得到的目标模型与超网络之间性能的一致性。应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。附图说明附图用于更好地理解本方案,不构成对本公开的限定。其中:图1是本公开提供的一种模型获取方法的流程图;图2是本公开提供的超网络中的部分网络结构示意图;图3是本公开提供的目标模型中的部分网络结构示意图;图4是本公开提供的一种模型获取装置的结构图之一;图5是本公开提供的一种模型获取装置的结构图之二;图6是本公开提供的一种电子设备的示意性框图。具体实施方式以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。近年来,深度学习技术在很多方向上都取得了巨大的成功,深度学习技术中,神经网络结构的好坏对目标模型的效果有非常重要的影响。人工设计神经网络结构需要非常丰富的经验和众多尝试,并且众多参数会产生爆炸性的组合,常规的随机搜索几乎不可行,因此NAS成为研究热点。传统的NAS方法需要独立采样并评估模型结构的性能,这种方式会造成很大的性能开销。为降低性能开销,基于超网络的模型训练方法通过参数共享的方式,大大加速的模型结构的搜索过程。然而,一致性问题是所有基于超网络的模型训练方案最大的问题,如果不解决一致性问题会导致搜索结果与预期结果存在非常大的性能差异。其中,一致性问题具体为:当将基于超网络的训练方法得到目标模型应用于具体的场景时,经常使得目标模型无法达到该场景对应的独立网络结构的性能,即训练得到的目标模型与超网络之间存在性能差异,也就是说,目前基于超网络的训练方法得到目标模型的性能较差。基于超网络的模型训练方案包括基于梯度的超网络训练方案和基于一步法(oneshot)的超网络训练方案。本实施例旨在解决基于梯度的超网络训练方案的一致性问题。目前,基于梯度的超网络训练方案,在超网络训练过程中,逐步删除权重最低的连接方式,随着连接方式的逐步删除,搜索空间会逐步缩小,最终收敛到最优的结构中。但上述方案无法保证删除权重最低的连接方式对超网络整体性能的影响,从而导致超网络性能与独立训练得到的网络结构之间的差异;此外,由于删除掉连接方式存在的差异,会导致超网络的性能也无法达到最优。请参见图1,图1是本公开提供的一种模型获取方法,包括:步骤S101、获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应。具体地,可以将所述参考模型作为教师模型,将超网络作为学生模型。所述教师模型可以是用于对图像/视频数据进行编解码处理的大模型,可以根据实际需求的不同,具体表现为不同类型的模型,例如,卷积神经网络、深度神经网络、长短期记忆网络、生成对抗网络等,本实施例中,所述教师模型可以采用resnet101_vd等结构,教师模型在imagenettop1acc超过80。应当说明的是,上述教师模型应当是预先训练好且具有良好性能的网络模型结构,以便于基于教师模型对所述学生模型进行训练,例如,可以基于所述教师模型对所述学生模型进行蒸馏训练,以提高训练到的目标模型结构与超网络之间的一致性。上述目标连接层可以是所述参考模型的中间连接层,即除输出层和输出层之外的连接层,例如,所述目标连接层可以是所述参考模型中的卷积层。上述第一软标签可以是参考模型之所以能够表现出高精度的原因,即教师模型所拥有的“知识”的表现,以人为例,知识可以具有抽象为文字表述的经验,但对于计算机模型而言,其知识通常表现为关键特征数据。由于参考模型中的每个中间连接层都能基于自身所拥有的“知识”对上层输出的数据进行处理,因此,参考模型的每个目标连接层均具有特定的“知识”,本实施例中,所述M个第一软标签分布表示不同参考模型中的M个不同目标连接层所拥有的“知识”。具体本文档来自技高网...
【技术保护点】
1.一种模型获取方法,包括:/n获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应;/n依据所述M个第一软标签对超网络中的子网络的中间节点的连接参数进行更新,得到目标模型;/n其中,所述子网络的第i个中间节点是基于第i个目标连接层对应的第一软件标签进行更新,所述第i个中间节点位于所述超网络包括的M个连接层中的第i个连接层。/n
【技术特征摘要】
1.一种模型获取方法,包括:
获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应;
依据所述M个第一软标签对超网络中的子网络的中间节点的连接参数进行更新,得到目标模型;
其中,所述子网络的第i个中间节点是基于第i个目标连接层对应的第一软件标签进行更新,所述第i个中间节点位于所述超网络包括的M个连接层中的第i个连接层。
2.根据权利要求1所述的方法,其中,所述子网络中的每个中间节点均存在K+1个连接,所述K为大于1的整数,所述依据所述M个第一软标签对超网络中的子网络的中间节点的连接参数进行更新,包括:
对所述子网络进行K轮迭代更新,其中,每轮迭代更新删除每个所述中间节点的一个连接。
3.根据权利要求2所述的方法,其中,所述超网络包括输出节点,所述K轮迭代更新中的第j轮更新,包括:
按照预设顺序分别对所述子网络中的中间节点的连接进行删除,其中,所述预设顺序为按照所述子网络中的中间节点与所述输出节点之间的距离,由小至大进行排序得到的顺序。
4.根据权利要求3所述的方法,其中,所述按照预设顺序分别对所述子网络中的中间节点的连接方式进行删除,包括:
对目标中间节点进行K+1次删除操作,得到K+1个中间超网络,其中,每次删除操作删除所述目标中间节点的一个不同连接,所述目标中间节点为所述子网络中的任意中间节点,且所述目标中间节点位于所述M个连接层中的第y个连接层,y为1至M中任一整数;
确定所述K+1个中间超网络中的第y个连接层输出的K+1个第二软标签;
基于所述K+1个第二软标签与第y个目标连接层所对应的目标第一软标签之间的距离,将所述K+1个中间超网络中的目标超网络确定为更新后的超网络。
5.根据权利要求4所述的方法,其中,所述目标超网络为所述K+1个中间超网络中输出的第二软标签与所述目标第一软标签之间的距离最小的超网络。
6.根据权利要求1所述的方法,其中,所述参考模型包括的连接层的数量为所述M的整数倍,所述M个目标连接层中任意两个目标连接层之间间隔的连接层数相同。
7.一种模型获取装置,包括:
获取模块,用于获取参考模型输出的M个第一软标签,其中,所述M个第一软标签分别与所述参考模型的M个目标连接层一一对应;
更新模块,用于依据所述M个第一软标签对超网络中的子网络的中间节点的连接参...
【专利技术属性】
技术研发人员:希滕,张刚,温圣召,
申请(专利权)人:北京百度网讯科技有限公司,
类型:发明
国别省市:北京;11
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。