System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种基于图神经网络架构挑选的知识蒸馏方法技术_技高网

一种基于图神经网络架构挑选的知识蒸馏方法技术

技术编号:42872705 阅读:3 留言:0更新日期:2024-09-27 17:32
本发明专利技术提供了一种基于图神经网络架构挑选的知识蒸馏方法,包括:获取预先定义的图神经网络架构的搜索空间,其定义架构搜索的范围;利用神经架构搜索技术和预设分类任务对应的验证集,采用强化学习机制从所述搜索空间中搜索执行该分类任务的最优图神经网络结构,所述预设分类任务与训练教师图模型时所对应的分类任务相同;基于预设的知识蒸馏方式,利用所述教师图模型指导采用所述最优图神经网络结构的学生图模型进行节点分类,得到经训练的学生图模型。

【技术实现步骤摘要】

本专利技术涉及图数据挖掘领域,具体来说涉及图神经网络知识蒸馏领域,更具体地说,涉及一种基于图神经网络架构挑选的知识蒸馏方法


技术介绍

1、图神经网络(graph neural networks,简称gnns)作为一种有效的图数据处理工具,在近年来受到了广泛关注。通过有效地捕捉图结构中的相互依赖关系,gnns能够对图中的节点和边进行高效地推断和预测,已经在社交网络分析、推荐系统、化学分子分析等领域展现出强大的潜力。然而,在实际应用中,如何构建一个性能优异的图神经网络模型依然是一个具有挑战性的问题。

2、近年来出现了图神经网络知识蒸馏(图蒸馏)方法,通过将预训练教师图模型的知识转移到学生图模型中,来增强学生表现。图蒸馏的核心思想是通过提取教师模型的表示学习能力和知识,将其传递给学生图模型,从而进一步提升学生图模型的预测准确性。这种方法在一定程度上为gnns在图数据分析和应用上,提供了一种有效的模型性能增强方案。

3、然而,目前的图蒸馏方法主要集中在知识的传递的优化上,所选择的学生图模型的结构可能并不适用教师图模型所应用的分类任务,即教师和学生的网络结构可能不匹配,以致影响学生图模型的性能和泛化能力。

4、需要说明的是:本
技术介绍
仅用于介绍本专利技术的相关信息,以便于帮助理解本专利技术的技术方案,但并不意味着相关信息必然是现有技术。相关信息与本专利技术方案一同提交和公开,在没有证据表明相关信息已在本专利技术的申请日以前公开的情况下,相关信息不应被视为现有技术。


技术实现思路

1、因此,本专利技术的目的在于克服上述现有技术的缺陷,提供一种基于图神经网络架构挑选的知识蒸馏方法。

2、本专利技术的目的是通过以下技术方案实现的:

3、根据本专利技术的第一方面,提供一种基于图神经网络架构挑选的知识蒸馏方法,包括:获取预先定义的图神经网络架构的搜索空间,其定义架构搜索的范围;利用神经架构搜索技术和预设分类任务对应的验证集,采用强化学习机制从所述搜索空间中搜索执行该分类任务的最优图神经网络结构,所述预设分类任务与训练教师图模型时所对应的分类任务相同;基于预设的知识蒸馏方式,利用所述教师图模型指导采用所述最优图神经网络结构的学生图模型进行节点分类,得到经训练的学生图模型。

4、可选的,强化学习机制包括:设置参数化的策略网络,利用策略网络基于所述搜索空间进行多次模拟,每次模拟得到一个仿真网络结构;利用仿真网络结构构建的模型在验证集上的性能指标作为奖励,指导策略网络优化网络参数;基于优化参数后的策略网络,从搜索空间中确定最优图神经网络结构。

5、可选的,所述预设的蒸馏方式包括:获取训练集,其包括多个样本和每个样本对应的标签,所述样本为包括节点和边的图数据,所述标签为图数据中至少部分节点所属的类别真值;将训练集中的样本输入教师图模型和学生图模型,得到教师图模型的分类预测值和学生图模型的分类预测值;利用预设的损失函数、教师图模型的分类预测值和学生图模型的分类预测值,指导更新学生图模型的参数,以降低预测损失。

6、可选的,预设的损失函数用于计算一次训练所采用的所有节点的节点损失的均值,单个节点损失采用以下kl散度损失函数计算:

7、lkl=kl(yt||ys)

8、其中,lkl表示kl散度损失函数,yt是教师图模型对输入样本中单个节点输出的分类预测值,ys是学生图模型对yt所对应的同一节点输出的分类预测值,kl(·)表示kl散度函数。

9、可选的,预设的损失函数用于计算一次训练所采用的所有节点的节点损失的均值,单个的节点损失采用以下加权损失函数计算:

10、ls=λlkl+(1-λ)lce

11、其中,ls表示加权损失函数,lkl表示用于计算教师图模型和学生图模型对同一节点输出的分类预测值间损失的kl散度损失函数,lce表示用于计算学生图模型对设有标签的节点输出的分类预测值和标签之间损失的交叉熵损失函数,λ表示lkl的权重。

12、可选的,kl散度损失函数为:

13、lkl=kl(yt||ys)

14、其中,lkl表示kl散度损失函数,yt是教师图模型对输入样本中单个节点输出的分类预测值,ys是学生图模型对yt所对应的同一节点输出的分类预测值,kl(·)表示kl散度函数。

15、可选的,图数据为论文引用关系图数据,节点代表论文,所述边代表论文之间的引用关系,所述标签为节点所属的主题类别;

16、所述学生图模型进行节点分类包括:从输入的论文引用关系图数据中提取各节点的节点特征,并根据节点特征预测节点的主题分类预测值。

17、根据本专利技术的第二方面,提供一种对图数据中节点进行分类的方法,包括:获取待分类的图数据以及按照第一方面所述的方法得到的经训练的学生图模型;利用所述经训练的学生图模型对待分类的图数据中的节点进行分类。

18、根据本专利技术的第三方面,提供一种电子设备,包括:一个或多个处理器;以及存储器,其中存储器用于存储可执行指令;所述一个或多个处理器被配置为经由执行所述可执行指令以实现第一方面和/或第二方面所述方法的步骤。

19、与现有技术相比,本专利技术的优点在于:

20、本专利技术实施例提供了一种基于图神经网络架构挑选的知识蒸馏方法,一方面,利用神经架构搜索技术和预设分类任务对应的验证集,采用强化学习机制从所述搜索空间中搜索执行该分类任务的最优图神经网络结构,所述预设分类任务与训练教师图模型时所对应的分类任务相同,不仅能够搜索出适应于固定任务的最优图神经网络结构,以解决师生网络结构不匹配的问题,又能够利用教师图模型的知识来指导学生图模型的训练,从而提高学生图模型的性能。

本文档来自技高网...

【技术保护点】

1.一种基于图神经网络架构挑选的知识蒸馏方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,所述强化学习机制包括:

3.根据权利要求1或2所述的方法,其特征在于,所述预设的蒸馏方式包括:

4.根据权利要求3所述的方法,其特征在于,所述预设的损失函数用于计算一次训练所采用的所有节点的节点损失的均值,单个节点损失采用以下KL散度损失函数计算:

5.根据权利要求3所述的方法,其特征在于,所述预设的损失函数用于计算一次训练所采用的所有节点的节点损失的均值,单个的节点损失采用以下加权损失函数计算:

6.根据权利要求5所述的方法,其特征在于,KL散度损失函数为:

7.根据权利要求3所述的方法,其特征在于,所述图数据为论文引用关系图数据,节点代表论文,所述边代表论文之间的引用关系,所述标签为节点所属的主题类别;

8.一种对图数据中节点进行分类的方法,包括:

9.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,所述计算机程序可被处理器执行以实现权利要求1至8之一所述方法的步骤

10.一种电子设备,其特征在于,包括:

...

【技术特征摘要】

1.一种基于图神经网络架构挑选的知识蒸馏方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,所述强化学习机制包括:

3.根据权利要求1或2所述的方法,其特征在于,所述预设的蒸馏方式包括:

4.根据权利要求3所述的方法,其特征在于,所述预设的损失函数用于计算一次训练所采用的所有节点的节点损失的均值,单个节点损失采用以下kl散度损失函数计算:

5.根据权利要求3所述的方法,其特征在于,所述预设的损失函数用于计算一次训练所采用的所有节点的节点损失的均值,单个的节...

【专利技术属性】
技术研发人员:刘静郝沁汾
申请(专利权)人:中国科学院计算技术研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1