对象分类模型训练方法、对象分类预测方法及装置制造方法及图纸

技术编号:39437399 阅读:11 留言:0更新日期:2023-11-19 16:20
本说明书的实施例提供对象分类模型训练方法、对象分类预测方法及装置。在进行模型训练时,将对象图数据样本的对象节点特征数据和图拓扑结构信息提供给教师图神经网络模型来确定出对象图数据样本的用于指导对象分类模型训练的第一分类标签,对象分类模型包括感知机模型。随后,使用对象图数据的对象节点特征数据,在调整后的第一分类标签的指导下训练对象分类模型,对象图数据样本的第一分类标签根据该对象图数据样本的真实标签和该对象图数据样本经过对象分类模型后的第二分类标签进行调整。行调整。行调整。

【技术实现步骤摘要】
对象分类模型训练方法、对象分类预测方法及装置


[0001]本说明书实施例通常涉及人工智能领域,尤其涉及基于定制化知识蒸馏的对象分类模型训练方法、对象分类预测方法及装置。

技术介绍

[0002]近年来,图神经网络(Graph Neural Networks,GNN)模型被越来越广泛地应用于图数据的各类任务处理。GNN模型的实现依赖于消息传递范式,该范式通过迭代传播图上的图节点特征来提取图信息,进而学习图节点特征表示。然而,随着传播次数的增加,节点邻居数量将呈指数增长。邻居数量爆炸式增长导致高额时间和算力开销,从而阻碍了GNN模型在对推理时延具有严格要求的应用场景中的部署和应用。

技术实现思路

[0003]本说明书实施例提供基于定制化知识蒸馏的对象分类模型训练方法、对象分类预测方法及装置。利用该对象分类模型训练方法及对象分类预测方法,可以利用定制化的知识蒸馏策略来修正教师图神经网络所提供的指导知识,使得所训练出的基于感知机结构的对象分类预测模型具有与图神经网络模型相当的模型性能,并且保持高效推理速度。
[0004]根据本说明书实施例的一个方面,提供一种基于定制化知识蒸馏的对象分类模型训练方法,包括:将对象图数据样本的对象节点特征数据和图拓扑结构信息提供给教师图神经网络模型来确定出所述对象图数据样本的用于指导对象分类模型训练的第一分类标签,所述对象分类模型包括感知机模型;以及使用对象图数据样本的对象节点特征数据,在调整后的第一分类标签的指导下训练所述对象分类模型,所述对象图数据样本的第一分类标签根据该对象图数据样本的真实标签和该对象图数据样本经过所述对象分类模型后的第二分类标签进行调整。
[0005]可选地,在上述方面的一个示例中,所述使用对象图数据样本的对象节点特征数据,在调整后的第一分类标签的指导下训练所述对象分类模型可以包括:将对象图数据样本的对象节点特征数据提供给所述对象分类模型,得到对象图数据样本的第二分类标签;根据对象图数据样本的真实标签和第二分类标签调整对象图数据样本的第一分类标签;以及根据对象图数据样本的真实标签、经过调整后的第一分类标签以及第二分类标签确定模型损失函数,并根据所述模型损失函数调整所述对象分类模型的模型参数。
[0006]可选地,在上述方面的一个示例中,所述对象分类模型训练时所使用的模型损失函数包括第一损失项和第二损失项,所述第一损失项包括对象图数据样本的经过调整的第一分类标签和第二分类标签之间的知识蒸馏损失项,所述第二损失项包括基于对象图数据样本的真实标签和第二分类标签确定的模型损失项。
[0007]可选地,在上述方面的一个示例中,所述知识蒸馏损失项包括KL散度损失项,以及所述模型损失项包括交叉熵损失项。
[0008]可选地,在上述方面的一个示例中,所述对象图数据样本的真实标签包括经过标
签平滑正则化后的真实标签。
[0009]可选地,在上述方面的一个示例中,所述第一分类标签的调整因子基于所述教师图神经网络和所述对象分类模型的交叉熵确定。
[0010]可选地,在上述方面的一个示例中,所述对象分类模型训练方法还可以包括:对所述对象图数据样本的对象节点特征数据进行图增强处理。相应地,使用对象图数据样本的对象节点特征数据,在所述调整后的第一分类标签的指导下训练所述对象分类模型可以包括:使用对象图数据样本的经过图增强处理后的对象节点特征数据,在所述调整后的第一分类标签的指导下训练所述对象分类模型。
[0011]可选地,在上述方面的一个示例中,所述图增强处理包括基于广义PageRank算法的图增强处理。
[0012]可选地,在上述方面的一个示例中,所述基于广义PageRank算法的图增强处理使用基于卷积矩阵维度和对象特征维度加权的多项式图滤波器。
[0013]根据本说明书的实施例的另一方面,提供一种对象分类预测方法,包括:从对象图数据中获取待分类对象的对象节点特征数据;以及将所述对象节点特征数据提供给对象分类模型来进行对象分类预测,所述对象分类模型按照如上所述的对象分类模型训练方法训练出。
[0014]可选地,在上述方面的一个示例中,所述对象分类预测方法还可以包括:对所获取的对象节点特征数据进行图增强处理。相应地,将所述对象节点特征数据提供给对象分类模型来进行对象分类预测可以包括:将经过图增强处理的对象节点特征数据提供给对象分类模型来进行对象分类预测。
[0015]可选地,在上述方面的一个示例中,所述对象分类预测方法还可以包括:对所获取的对象节点特征数据进行基于一阶邻居的聚合特征的近似邻居特征聚合处理。相应地,将所述对象节点特征数据提供给对象分类模型来进行对象分类预测可以包括:将经过近似邻居特征聚合处理的对象节点特征数据提供给对象分类模型来进行对象分类预测。
[0016]根据本说明书的实施例的另一方面,提供一种基于定制化知识蒸馏的对象分类模型训练装置,包括:指导知识确定单元,将对象图数据样本的对象节点特征数据和图拓扑结构信息提供给教师图神经网络模型来确定出所述对象k图数据样本的用于指导对象分类模型训练的第一分类标签,所述对象分类模型包括感知机模型;以及对象分类模型训练单元,使用对象图数据样本的对象节点特征数据,在调整后的第一分类标签的指导下训练所述对象分类模型,所述对象图数据样本的第一分类标签根据该对象图数据样本的真实标签和该对象图数据样本经过所述对象分类模型后的第二分类标签进行调整。
[0017]可选地,在上述方面的一个示例中,所述对象分类模型训练单元包括:模型预测模块,将对象图数据样本的对象节点特征数据提供给所述对象分类模型,得到对象图数据样本的第二分类标签;指导知识调整模块,根据对象图数据样本的真实标签和第二分类标签调整对象图数据样本的第一分类标签;以及模型调整模块,根据对象图数据样本的真实标签、经过调整后的第一分类标签以及第二分类标签确定模型损失函数,并根据所述模型损失函数调整所述对象分类模型的模型参数。
[0018]可选地,在上述方面的一个示例中,所述模型损失函数包括第一损失项和第二损失项,所述第一损失项包括所述调整后的第一分类标签和所述对象分类模型的第二分类标
签之间的知识蒸馏损失项,所述第二损失项包括基于所述对象图数据样本的真实标签和所述第二分类标签的模型损失项。
[0019]可选地,在上述方面的一个示例中,所述对象分类模型训练装置还可以包括:图增强处理单元,对所述对象图数据样本的对象节点特征数据进行图增强处理。相应地,所述对象分类模型训练单元使用所述对象图数据样本的经过图增强处理后的对象节点特征数据,在所述调整后的第一分类标签的指导下训练所述对象分类模型。
[0020]可选地,在上述方面的一个示例中,所述图增强处理使用基于卷积矩阵维度和对象特征维度加权的多项式图滤波器,所述多项式图滤波器的参数与所述对象分类模型的参数一起根据所述模型损失函数调整。
[0021]根据本说明书的实施例的另一方面,提供一种对象分类预测装置,包括:节点特征数据获取单元本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于定制化知识蒸馏的对象分类模型训练方法,包括:将对象图数据样本的对象节点特征数据和图拓扑结构信息提供给教师图神经网络模型来确定出所述对象图数据样本的用于指导对象分类模型训练的第一分类标签,所述对象分类模型包括感知机模型;以及使用对象图数据样本的对象节点特征数据,在调整后的第一分类标签的指导下训练所述对象分类模型,所述对象图数据样本的第一分类标签根据该对象图数据样本的真实标签和该对象图数据样本经过所述对象分类模型后的第二分类标签进行调整。2.如权利要求1所述的对象分类模型预测方法,其中,使用对象图数据样本的对象节点特征数据,在调整后的第一分类标签的指导下训练所述对象分类模型包括:将对象图数据样本的对象节点特征数据提供给所述对象分类模型,得到对象图数据样本的第二分类标签;根据对象图数据样本的真实标签和第二分类标签调整对象图数据样本的第一分类标签;以及根据对象图数据样本的真实标签、经过调整后的第一分类标签以及第二分类标签确定模型损失函数,并根据所述模型损失函数调整所述对象分类模型的模型参数。3.如权利要求1所述的对象分类模型训练方法,其中,所述对象分类模型训练时所使用的模型损失函数包括第一损失项和第二损失项,所述第一损失项包括对象图数据样本的经过调整的第一分类标签和第二分类标签之间的知识蒸馏损失项,所述第二损失项包括基于对象图数据样本的真实标签和第二分类标签确定的模型损失项。4.如权利要求3所述的对象分类模型训练方法,其中,所述知识蒸馏损失项包括KL散度损失项,以及所述模型损失项包括交叉熵损失项。5.如权利要求1所述的对象分类模型训练方法,其中,所述对象图数据样本的真实标签包括经过标签平滑正则化后的真实标签。6.如权利要求1所述的对象分类模型训练方法,其中,所述第一分类标签的调整因子基于所述教师图神经网络和所述对象分类模型的交叉熵确定。7.如权利要求1所述的对象分类模型训练方法,还包括:对所述对象图数据样本的对象节点特征数据进行图增强处理,使用对象图数据样本的对象节点特征数据,在所述调整后的第一分类标签的指导下训练所述对象分类模型包括:使用对象图数据样本的经过图增强处理后的对象节点特征数据,在所述调整后的第一分类标签的指导下训练所述对象分类模型。8.如权利要求7所述的对象分类模型训练方法,其中,所述图增强处理包括基于广义PageRank算法的图增强处理。9.如权利要求8所述的对象分类模型训练方法,其中,所述基于广义PageRank算法的图增强处理使用基于卷积矩阵维度和对象特征维度加权的多项式图滤波器。10.一种对象分类预测方法,包括:从对象图数据中获取待分类对象的对象节点特征数据;以及将所述对象节点特征数据提供给对象分类模型来进行对象分类预测,所述对象分类模型按照如权利要求1到9中任一所述的对象分类模型训练方法训练出。
11.如权利要求10所述的对象分类预测方法,还包括:对所获取的对象节点特征数据进行图增强处理,将所述对象节点特征数据提供给对象分类模型来进行对象分类预测包括:将经过图增强处理的对象节点特征数据提供给对象分类模型来进行对象分类预测。12.如权利要求10所述的对象分类预测方法,还包括:对所获取的对象节点特征数据进行基于一阶邻居的聚合特征的近似邻居特征聚合处理,将所述对象节点特征数据提供给对象分类模型来进行对象分类预测包括:将经过近似邻居特征聚合处理的对象节点特征数据提供给对象分类模型来进行对象分类预测。1...

【专利技术属性】
技术研发人员:韦绍玮吴郑伟张志强周俊
申请(专利权)人:支付宝杭州信息技术有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1