基于图表示学习的身份保持对抗训练方法、装置、介质制造方法及图纸

技术编号:33284117 阅读:16 留言:0更新日期:2022-04-30 23:47
本发明专利技术提供了一种基于图表示学习的身份保持对抗训练方法、装置、介质,所述方法包括:获取训练场景的图数据,定义图数据的每一个节点为用于表征训练场景的一个原样本,定义原样本的样本身份信息;生成每一个原样本对应的对抗样本;通过为对抗样本添加身份保持约束,将对抗样本保持原样本的样本身份信息;将对抗样本作为第一输入变量,输入至初始图表示学习模型,执行身份保持对抗训练;更新初始图表示学习模型,得到目标图表示学习模型,利用目标图表示学习模型预测训练场景中所述原样本在不同图挖掘任务下的输出。该方法将对抗样本与原样本保持相同的样本身份信息,提升了图表示学习在图结构数据分析中的精度,具有一定的普适性。性。性。

【技术实现步骤摘要】
基于图表示学习的身份保持对抗训练方法、装置、介质


[0001]本专利技术涉及图数据挖掘
,尤其涉及一种基于图表示学习的身份保持对抗训练方法、装置、介质。

技术介绍

[0002]图表示学习成为分析图结构数据的热门研究领域。在软件层面,图表示学习旨在学习一种编码函数,该函数充分利用图数据的优势,将具有复杂结构的图数据转换为保留多样化图属性和结构特征的低维空间中的密集表示。目前,图表示学习方法广泛应用于节点分类、异常检测、连边预测、标签推荐等各种图挖掘任务中。同时为实际生活中的大量应用问题带来了突破性的进展,例如在商品推荐场景中,用户对于商品的喜好程度的预测可以形式化为连边预测问题,其中节点代表用户或者商品,连边表示其喜好程度。药物靶点预测是连边预测任务对应的另一个重要的实际应用,其中节点表示药物或者蛋白质,连边表示两者之间是否能起作用。在金融风险控制场景中,已知借款人之间的交易关系,需要判断每个借款人的风险程度即其还款能力,防止将贷款发放给高风险人群。该场景是典型的节点分类问题,不同的类别表示不同的风险程度,每个借款人是一个节点,用连边表示借款人间的交易信息。
[0003]然而这些方法都忽略了现实世界中图数据的噪声,例如在推荐场景中用户可能存在误点击,对不喜欢的商品点了赞。同时在算法训练的过程中可能面临着过拟合的问题。基于对抗训练的图表示学习方法,通过在图表示学习方法之上引入对抗训练,将对抗样本及其对应的原始样本一同用于训练图表示学习模型,以提升训练数据的多样性从而解决上述问题。但该方法由于未保证对抗样本的质量,容易在对抗样本中引入错误信息并损害图表示学习模型性能,影响了在商品推荐场景、药物靶点预测、金融风险控制场景等实际训练场景中的节点分类、异常检测、连边预测、标签推荐等各种图挖掘任务中的图结构数据分析精度。

技术实现思路

[0004]本专利技术实施例的目的是提供一种基于图表示学习的身份保持对抗训练方法、装置、存储介质、电子设备,该方法生成的对抗样本与原样本保持相同的样本身份信息,提升了图表示学习在图挖掘任务中的图结构数据分析精度。
[0005]为了实现上述目的,本专利技术一方面提供一种基于图表示学习的身份保持对抗训练方法,包括:
[0006]获取训练场景的图数据,定义所述图数据的每一个节点为用于表征所述训练场景的一个原样本,定义所述原样本的样本身份信息;
[0007]生成每一个所述原样本对应的对抗样本;
[0008]通过为所述对抗样本添加身份保持约束,将所述对抗样本保持所述原样本的样本身份信息;
[0009]将所述对抗样本作为第一输入变量,输入至初始图表示学习模型,执行身份保持对抗训练;
[0010]更新所述初始图表示学习模型,得到目标图表示学习模型,利用所述目标图表示学习模型预测所述训练场景中所述原样本在不同图挖掘任务下的输出。
[0011]可选的,所述生成每一个所述原样本对应的对抗样本,包括:
[0012]将每一个所述原样本作为第二输入变量,输入至所述初始图表示学习模型进行训练,得到每一个所述原样本的表示;
[0013]为每一个所述原样本的表示添加扰动因子,生成每一个所述原样本对应的对抗样本。
[0014]可选的,所述为每一个所述原样本的表示添加扰动因子,生成每一个所述原样本对应的对抗样本之前,还包括:
[0015]生成所述扰动因子,
[0016]所述扰动因子包括扰动强度与扰动方向,所述扰动强度用于确定所述对抗样本与所述原样本之间的距离,所述扰动方向用于确定所述对抗样本的生成方向;
[0017]对于每一个所述原样本,生成的对抗样本的表示为:
[0018]h

i
=h
i
+r
i
*d
i
[0019]其中,r
i
、d
i
、h
i
分别为第i个所述原样本对应的扰动强度、扰动方向、表示。
[0020]可选的,所述生成所述扰动因子,包括:
[0021]确定所述扰动强度,包括:
[0022]将每一个节点的邻居节点的表示作为输入,输入至感知机模型,通过自适应学习得到每一个所述节点对应的扰动强度;
[0023]所述扰动强度表示为:
[0024]r
i
=ReLU(W2*ReLU(W1h
i
))
[0025]其中,W1、W2是可学习的参数,ReLU是激活函数;
[0026]构造所述扰动强度满足的约束条件,得到第一损失函数;
[0027]所述第一损失函数表示为:
[0028][0029]其中,N表示所述节点的数量。
[0030]可选的,所述生成所述扰动因子,还包括:
[0031]确定所述扰动方向,包括:
[0032]通过最大化所述初始图表示学习模型的原损失函数方向,确定所述扰动方向,所述扰动方向表示为:
[0033][0034]其中,L
basic
表示所述初始图表示学习模型的原损失函数。
[0035]可选的,所述通过为所述对抗样本添加身份保持约束,将所述对抗样本保持所述原样本的样本身份信息,包括:
[0036]构造所述身份保持约束,包括:
[0037]负采样每一个所述节点,构造所述身份保持约束,得到第二损失函数;
[0038]所述第二损失函数表示为:
[0039][0040]其中,h
k
为第k个所述节点的表示,i≠k;p(n)表示每一个所述节点被采样到的概率,按均匀分布负采样;K是负采样的个数;σ(h

i
,h
i
)表示原样本与其相对应的对抗样本之间的关系;σ(h

i
,h
k
)表示原样本与其不对应的对抗样本之间的关系;
[0041]采用二分类判别器,根据所述第二损失函数,将所述原样本每一个节点的表示对应一个类别。
[0042]可选的,所述更新所述初始图表示学习模型,得到目标图表示学习模型,利用所述目标图表示学习模型预测所述训练场景中所述原样本在不同图挖掘任务下的输出,包括:
[0043]确定整体损失函数,所述整体损失函数表示为:
[0044]L=L
basic
+λ1L
basic
+λ2L
id
+λ3L
norm
[0045]其中,λ1、λ2、λ3表示收敛因子;L
basic
表示所述初始图表示学习模型的原损失函数,L
norm
表示第一损失函数,L
id
表示第二损失函数,L

basic
表示将所述对抗样本作为第一输入变量,输入至所述初始图表示学习模型,执行身份保持对抗训练,得到的第三损失函数;
[0046]根据所述整体本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于图表示学习的身份保持对抗训练方法,其特征在于,包括:获取训练场景的图数据,定义所述图数据的每一个节点为用于表征所述训练场景的一个原样本,定义所述原样本的样本身份信息;生成每一个所述原样本对应的对抗样本;通过为所述对抗样本添加身份保持约束,将所述对抗样本保持所述原样本的样本身份信息;将所述对抗样本作为第一输入变量,输入至初始图表示学习模型,执行身份保持对抗训练;更新所述初始图表示学习模型,得到目标图表示学习模型,利用所述目标图表示学习模型预测所述训练场景中所述原样本在不同图挖掘任务下的输出。2.根据权利要求1所述的方法,其特征在于,所述生成每一个所述原样本对应的对抗样本,包括:将每一个所述原样本作为第二输入变量,输入至所述初始图表示学习模型进行训练,得到每一个所述原样本的表示;为每一个所述原样本的表示添加扰动因子,生成每一个所述原样本对应的对抗样本。3.根据权利要求2所述的方法,所述为每一个所述原样本的表示添加扰动因子,生成每一个所述原样本对应的对抗样本之前,还包括:生成所述扰动因子,所述扰动因子包括扰动强度与扰动方向,所述扰动强度用于确定所述对抗样本与所述原样本之间的距离,所述扰动方向用于确定所述对抗样本的生成方向;对于每一个所述原样本,生成的对抗样本的表示为:h

i
=h
i
+r
i
*d
i
其中,r
i
、d
i
、h
i
分别为第i个所述原样本对应的扰动强度、扰动方向、表示。4.根据权利要求3所述的方法,其特征在于,所述生成所述扰动因子,包括:确定所述扰动强度,包括:将每一个节点的邻居节点的表示作为输入,输入至感知机模型,通过自适应学习得到每一个所述节点对应的扰动强度;所述扰动强度表示为:r
i
=ReLU(W2*ReLU(W1h
i
))其中,W1、W2是可学习的参数,ReLU是激活函数;构造所述扰动强度满足的约束条件,得到第一损失函数;所述第一损失函数表示为:其中,N表示所述节点的数量。5.根据权利要求4所述的方法,其特征在于,所述生成所述扰动因子,还包括:确定所述扰动方向,包括:通过最大化所述初始图表示学习模型的原损失函数方向,确定所述扰动方向,所述扰动方向表示为:
其中,L
basic
表示所述初始图表示学习模型的原损失函数。6.根据权利要求5所述的方法,其特征在于,所述通过为所述对抗样本添加身份保持约束,将所述对抗样本保持所述原样本的样本身份信息,包括:构造所述身份保持约束,包括:负采样每一个所述节点,构造所述身份保持约束,得到第二损失函数;所述第二损失函数表示为:其中,h
k
为第k个所述节点的表示,i≠k;p(n)表示每一个所述节点被采样到的概率,按均匀分布负采样;K是负采样的个数;σ(h

i
,h
i
)表示原样本与其相对应的对抗样本之间的关系;σ(h

i
...

【专利技术属性】
技术研发人员:沈华伟岑科廷曹婍徐冰冰程学旗
申请(专利权)人:中国科学院计算技术研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1