图神经网络的训练方法及装置制造方法及图纸

技术编号:34049257 阅读:16 留言:0更新日期:2022-07-06 15:24
本说明书实施例提供一种图神经网络的训练方法,涉及基于用户关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。更新所述当前图神经网络中的模型参数。更新所述当前图神经网络中的模型参数。

Figure training method and device of neural network

【技术实现步骤摘要】
图神经网络的训练方法及装置


[0001]本说明书一个或多个实施例涉及机器学习
,尤其涉及一种图神经网络的训练方法及装置。

技术介绍

[0002]关系网络图是对现实世界中实体之间的关系的描述,目前被广泛应用于各种业务处理中,如社交网络分析、化学键预测等。图神经网络(Graph Neural Networks,简称GNN)适用于处理关系网络图上的各种任务,然而,GNN的性能在很大程度上依赖标注数据的数量,通常,GNN的性能会随着标注数据的减少而迅速下降。
[0003]因此,需要一种方案,能够突破GNN训练时标注数据不足的限制,得到性能优异的GNN模型,从而有效提升业务处理结果的准确度。

技术实现思路

[0004]本说明书一个或多个实施例描述了一种图神经网络的训练方法及装置,利用未标注数据扩充标注数据,并引入信息增益缩小原始标注数据分布与扩充后标注数据分布所对应训练损失之间的差异,从而有效提升GNN模型的训练效果。
[0005]根据第一方面,提供一种图神经网络的训练方法,涉及基于用户关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
[0006]在一个实施例中,所述多个用户节点中包括第二数量的未标注节点,各个分类预测向量中包括与多个类别对应的多个预测概率;其中,基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签,包括:针对所述第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入所述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。
[0007]在一个实施例中,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益,包括:针对任意的第一未标注节点,利用其对应的第一分类预测向量和伪分类标签,训练所述当前图神经网络,并基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量;根据所述第一分类预测向量,确定第一信息熵;根据所述第二分类预测向量,确定第二信息熵;基于所述第二信息熵与所述第一信息熵的差值,得到所述信息增益。
[0008]在一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:在所述多个聚合层中的某个聚合层,对上一聚合层输出的针对所述多个用户节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个用户节点输出的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
[0009]在另一个具体的实施例中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:在所述多个聚合层中的某个聚合层,对所述用户关系图谱所对应邻接矩阵中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对所述多个用户节点的多个聚合向量,确定本聚合层针对所述多个用户节点的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。
[0010]进一步,在一个更具体的实施例中,基于训练出的第一图神经网络确定该未标注节点的第二分类预测向量,包括:多次执行确定所述第二分类预测向量的操作,对应得到多个第二分类预测向量;其中,根据所述第二分类预测向量,确定第二信息熵,包括:将所述多个第二分类预测向量所对应多个信息熵的均值,确定为所述第二信息熵。
[0011]在一个实施例中,根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数,包括:根据所述各个标注节点对应的分类预测向量和真实分类标签,确定第一损失项;针对所述各个未标注节点,根据其对应的分类预测向量和伪分类标签,确定第二损失项,并利用其对应的信息增益对所述第二损失项进行加权处理;根据所述第一损失项和加权处理后的第二损失项,更新所述模型参数。
[0012]在一个具体的实施例中,利用其对应的信息增益对所述第二损失项进行加权处理,包括:利用所述第一数量的未标注节点所对应第一数量的信息增益,对所述各个未标注节点的信息增益进行归一化处理,得到对应的加权系数;利用所述加权系数进行所述加权处理。
[0013]根据第二方面,提供一种图神经网络的训练方法,涉及基于预先构建的关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个业务对象节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个业务对象节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
[0014]根据第三方面,提供一种图神经网络的训练装置,所述装置通过以下单元,根据用户关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:分类预测单元,配置为利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;伪标签分配单元,配置为基于所述多个分类预测向量,为所述多个
用户节点中第一数量的未标注节点分配对应的伪分类标签;信息增益确定单元,配置为针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;参数更新单元,配置为根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。
[0015]根据第四方面,提供一种图神经网络的训练装置,所述装置通过以下单元,根据预先构建的关系图谱对图神经网络进行多轮次迭代更新中的任一轮次:分类预测单元,配置为利用当前图神经网络对所述关系图谱进行处理,得到与该关系图谱中多个业务对象节点对应的多个分类预测向量;伪标签分本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种图神经网络的训练方法,涉及基于用户关系图谱对图神经网络进行多轮次迭代更新,其中任一轮次包括:利用当前图神经网络对所述用户关系图谱进行处理,得到与该用户关系图谱中多个用户节点对应的多个分类预测向量;基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签;针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益;根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的分类预测向量、伪分类标签和信息增益,更新所述当前图神经网络中的模型参数。2.根据权利要求1所述的方法,其中,所述多个用户节点中包括第二数量的未标注节点,各个分类预测向量中包括与多个类别对应的多个预测概率;其中,基于所述多个分类预测向量,为所述多个用户节点中第一数量的未标注节点分配对应的伪分类标签,包括:针对所述第二数量的未标注节点中的各个节点,若其所对应分类预测向量中包含的最大预测概率达到预设标准,则将该节点归入所述第一数量的未标注节点,并将该最大预测概率所对应的类别确定为该节点的伪分类标签。3.根据权利要求1所述的方法,其中,针对所述第一数量的未标注节点中的各个未标注节点,确定利用其训练所述当前图神经网络而产生的信息增益,包括:针对任意的第一未标注节点,利用其对应的第一分类预测向量和伪分类标签,训练所述当前图神经网络,并基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量;根据所述第一分类预测向量,确定第一信息熵;根据所述第二分类预测向量,确定第二信息熵;基于所述第二信息熵与所述第一信息熵的差值,得到所述信息增益。4.根据权利要求3所述的方法,其中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:在所述多个聚合层中的某个聚合层,对上一聚合层输出的针对所述多个用户节点的多个聚合向量中的向量元素进行随机置零处理,并且,基于所述随机置零处理后的多个聚合向量,确定本聚合层针对所述多个用户节点输出的多个聚合向量;在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。5.根据权利要求3所述的方法,其中,所述第一图神经网络包括多个聚合层和输出层;其中,基于训练出的第一图神经网络确定该第一未标注节点的第二分类预测向量,包括:在所述多个聚合层中的某个聚合层,对所述用户关系图谱所对应邻接矩阵中的矩阵元素进行随机置零处理,并且,基于所述随机置零处理后的邻接矩阵,以及由上一聚合层输出的针对所述多个用户节点的多个聚合向量,确定本聚合层针对所述多个用户节点的多个聚合向量;
在所述输出层,对最后一个聚合层针对所述第一未标注用户节点输出的聚合向量进行处理,得到所述第二分类预测向量。6.根据权利要求4或5所述的方法,其中,基于训练出的第一图神经网络确定该未标注节点的第二分类预测向量,包括:多次执行确定所述第二分类预测向量的操作,对应得到多个第二分类预测向量;其中,根据所述第二分类预测向量,确定第二信息熵,包括:将所述多个第二分类预测向量所对应多个信息熵的均值,确定为所述第二信息熵。7.根据权利要求1所述的方法,其中,根据与所述多个用户节点中各个标注节点对应的分类预测向量和真实分类标签,以及与所述各个未标注节点对应的...

【专利技术属性】
技术研发人员:胡斌斌刘洪瑞张志强石川王啸周俊
申请(专利权)人:北京邮电大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1