模型训练方法和装置、电子设备、存储介质制造方法及图纸

技术编号:35274649 阅读:16 留言:0更新日期:2022-10-19 10:52
本申请实施例提供了模型训练方法和装置、电子设备、存储介质,属于人工智能技术领域。该模型训练方法包括:获取待训练的原始图像数据,并获取服务器端发送的原始模型的原始训练参数;通过原始模型对原始图像数据进行标签预测得到预测标签,通过原始模型对原始图像数据进行特征提取得到初步图像特征;对初步图像特征进行投影聚类处理,得到域标签和域组合标签;根据域标签和域组合标签计算距离损失函数;根据预设的交叉熵损失函数和距离损失函数更新第一、二原始网络参数得到第一、二目标参数;将第一、二目标参数发送给服务器端,使服务器端更新原始模型得到全局模型。本申请实施例基于联邦学习进行模型训练可以提高模型的训练效率和准确率。练效率和准确率。练效率和准确率。

【技术实现步骤摘要】
模型训练方法和装置、电子设备、存储介质


[0001]本申请涉及人工智能
,尤其涉及模型训练方法和装置、电子设备、存储介质。

技术介绍

[0002]人工智能中的联邦学习技术,可以在不交换本地数据的情况下,多中心联合训练一个机器学习模型。然而,不同客户端的图像数据之间可能存在统计异质性,例如客户端之间的图像数据分布不同的数据异质性,该数据异质性可能导致模型训练效果不好。

技术实现思路

[0003]本申请实施例的主要目的在于提出模型训练方法和装置、电子设备、存储介质,提高模型的训练效率和准确率。
[0004]为实现上述目的,本申请实施例的第一方面提出了一种模型训练方法,应用于客户端,所述模型训练方法包括:
[0005]获取待训练的原始图像数据,并获取服务器端发送的原始模型的原始训练参数;其中,所述原始图像数据是无标注数据,所述原始训练参数包括所述原始模型的第一原始网络参数、第二原始网络参数、聚类数量;
[0006]将所述原始图像数据输入所述原始模型,通过所述原始模型对所述原始图像数据进行标签预测得到预测标签,通过所述原始模型对所述原始图像数据进行特征提取得到初步图像特征;
[0007]对所述初步图像特征进行投影聚类处理,得到域标签和域组合标签;其中,域标签和域组合标签的数量均为第一数量,且所述第一数量等于所述聚类数量;
[0008]根据所述域标签和所述域组合标签计算距离损失函数;
[0009]根据预设的交叉熵损失函数和所述距离损失函数更新所述第一原始网络参数得到第一目标参数,根据所述交叉熵损失函数和所述距离损失函数更新所述第二原始网络参数得到第二目标参数;其中,所述交叉熵损失函数由所述预测标签进行预先构建得到;
[0010]将所述第一目标参数和所述第二目标参数发送给所述服务器端;其中,所述第一目标参数和所述第二目标参数用于所述服务器端更新所述原始模型得到全局模型。
[0011]在一些实施例,所述原始模型包括图卷积网络和卷积神经网络,所述通过所述原始模型对所述原始图像数据进行标签预测得到预测标签,包括:
[0012]通过所述图卷积网络对所述原始图像数据进行参数提取,得到初步参数;
[0013]通过所述卷积神经网络对所述初步参数进行特征图提取,得到所述初步特征图;
[0014]通过预设的激活函数对所述初步特征图进行标签预测,得到所述预测标签。
[0015]在一些实施例,所述对所述初步图像特征进行投影聚类处理,得到域标签和域组合标签,包括:
[0016]将所述初步图像特征映射到同一维度,得到投影特征;
[0017]通过预设的聚类算法对所述投影特征进行聚类处理,得到所述域标签和所述域组合标签。
[0018]在一些实施例,所述通过预设的聚类算法对所述投影特征进行聚类处理,得到所述域标签和所述域组合标签,包括:
[0019]通过k

means++算法对所述投影特征进行聚类处理,得到所述第一数量的聚类中心;
[0020]从所述第一数量的聚类中心中选择一个作为参考聚类中心;
[0021]计算每一所述投影特征与所述参考聚类中心的距离,得到聚类距离;
[0022]根据所述聚类距离和预设系数计算得到所述第二数量的域组合标签。
[0023]在一些实施例,所述域组合标签包括所述预设系数,所述根据所述域标签和所述域组合标签计算距离损失函数:
[0024]根据所述域组合标签获取所述预设系数的最大值,得到目标系数;
[0025]根据所述目标系数从所述第一数量的聚类中心筛选出目标中心,从所述第一数量的聚类中心过滤所述目标中心得到第二数量的当前聚类中心;其中,所述第二数量等于所述第一数量减1;
[0026]计算所述投影特征与所述目标中心之间的距离得到第一距离,并计算所述投影特征与每一所述当前聚类中心之间的距离得到所述第二数量的第二距离;
[0027]将所述第一距离与每一所述第二距离进行求差计算,得到所述第二数量的距离差;
[0028]根据所述第二数量的距离差计算所述距离损失函数。
[0029]在一些实施例,所述方法还包括:构建所述交叉熵损失函数,具体包括:
[0030]获取所述原始图像数据的原始标签;所述原始标签是所述原始图像数据的真实标签;
[0031]将所述预测标签与所述真实标签进行比对,得到标签比对结果;
[0032]根据所述预测标签和所述对比结果构建所述交叉熵损失函数。
[0033]为实现上述目的,本申请实施例的第二方面提出了一种模型训练方法,应用于服务器端,所述模型训练方法包括:
[0034]向客户端发送原始模型的训练参数;其中,所述原始训练参数包括所述原始模型的第一原始网络参数、第二原始网络参数;
[0035]获取所述客户端对所述第一原始网络参数更新得到的第一目标参数和对所述第二原始网络参数更新得到的第二目标参数;其中,所述第一目标参数和所述第二目标参数是根据如第一方面所述的模型训练方法训练得到;
[0036]将所述第一目标参数和所述第二目标参数进行整合处理,得到全局模型参数;
[0037]根据所述全局模型参数更新所述原始模型,得到全局模型。
[0038]为实现上述目的,本申请实施例的第三方面提出了一种模型训练装置,应用于客户端,所述模型训练装置包括:
[0039]原始图像获取模块,用于获取待训练的原始图像数据,并获取服务器端发送的原始模型的原始训练参数;其中,所述原始图像数据是无标注数据,所述原始训练参数包括所述原始模型的第一原始网络参数、第二原始网络参数、聚类数量;
[0040]模型处理模块,用于将所述原始图像数据输入所述原始模型,通过所述原始模型对所述原始图像数据进行标签预测得到预测标签,通过所述原始模型对所述原始图像数据进行特征提取得到初步图像特征;
[0041]聚类模块,用于对所述初步图像特征进行投影聚类处理,得到域标签和域组合标签;其中,域标签和域组合标签的数量均为第一数量,且所述第一数量等于所述聚类数量;
[0042]距离损失函数构建模块,用于根据所述域标签和所述域组合标签计算距离损失函数;
[0043]参数更新模块,用于根据预设的交叉熵损失函数和所述距离损失函数更新所述第一原始网络参数得到第一目标参数,根据所述交叉熵损失函数和所述距离损失函数更新所述第二原始网络参数得到第二目标参数;其中,所述交叉熵损失函数由所述预测标签进行预先构建得到;
[0044]参数发送模块,用于将所述第一目标参数和所述第二目标参数发送给所述服务器端;其中,所述第一目标参数和所述第二目标参数用于所述服务器端更新所述原始模型得到全局模型。
[0045]为实现上述目的,本申请实施例的第四方面提出了一种模型训练装置,应用于服务器端,所述模型训练装置包括:
[0046]训练参数发送模块,用于向客户端发送原始模型的训练参本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,应用于客户端,其特征在于,所述模型训练方法包括:获取待训练的原始图像数据,并获取服务器端发送的原始模型的原始训练参数;其中,所述原始图像数据是无标注数据,所述原始训练参数包括所述原始模型的第一原始网络参数、第二原始网络参数、聚类数量;将所述原始图像数据输入所述原始模型,通过所述原始模型对所述原始图像数据进行标签预测得到预测标签,通过所述原始模型对所述原始图像数据进行特征提取得到初步图像特征;对所述初步图像特征进行投影聚类处理,得到域标签和域组合标签;其中,域标签和域组合标签的数量均为第一数量,且所述第一数量等于所述聚类数量;根据所述域标签和所述域组合标签计算距离损失函数;根据预设的交叉熵损失函数和所述距离损失函数更新所述第一原始网络参数得到第一目标参数,根据所述交叉熵损失函数和所述距离损失函数更新所述第二原始网络参数得到第二目标参数;其中,所述交叉熵损失函数由所述预测标签进行预先构建得到;将所述第一目标参数和所述第二目标参数发送给所述服务器端;其中,所述第一目标参数和所述第二目标参数用于所述服务器端更新所述原始模型得到全局模型。2.根据权利要求1所述的模型训练方法,其特征在于,所述原始模型包括图卷积网络和卷积神经网络,所述通过所述原始模型对所述原始图像数据进行标签预测得到预测标签,包括:通过所述图卷积网络对所述原始图像数据进行参数提取,得到初步参数;通过所述卷积神经网络对所述初步参数进行特征图提取,得到所述初步特征图;通过预设的激活函数对所述初步特征图进行标签预测,得到所述预测标签。3.根据权利要求1所述的模型训练方法,其特征在于,所述对所述初步图像特征进行投影聚类处理,得到域标签和域组合标签,包括:将所述初步图像特征映射到同一维度,得到投影特征;通过预设的聚类算法对所述投影特征进行聚类处理,得到所述域标签和所述域组合标签。4.根据权利要求3所述的模型训练方法,其特征在于,所述通过预设的聚类算法对所述投影特征进行聚类处理,得到所述域标签和所述域组合标签,包括:通过k

means++算法对所述投影特征进行聚类处理,得到所述第一数量的聚类中心;从所述第一数量的聚类中心中选择一个作为参考聚类中心;计算每一所述投影特征与所述参考聚类中心的距离,得到聚类距离;根据所述聚类距离和预设系数计算得到所述第二数量的域组合标签。5.根据权利要求4所述的模型训练方法,其特征在于,所述域组合标签包括所述预设系数,所述根据所述域标签和所述域组合标签计算距离损失函数:根据所述域组合标签获取所述预设系数的最大值,得到目标系数;根据所述目标系数从所述第一数量的聚类中心筛选出目标中心,从所述第一数量的聚类中心过滤所述目标中心得到第二数量的当前聚类中心;其中,所述第二数量等于所述第一数量减1;计算所述投影特征与所述目标中心之间的距离得到第一距离,并计算所述投影特征与
每一所述当前聚类中心之间的距离得到所述第二数量的第二距离;将所述第一距离与每一所述第二距离进行求差计算...

【专利技术属性】
技术研发人员:李泽远王健宗曹康养
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1