【技术实现步骤摘要】
本专利技术属于分布式机器学习领域,具体涉及一种模型联合训练方法、装置、电子设备及存储介质。
技术介绍
1、传统的大规模数据分析依赖于分布式服务器架构,尽管分布式服务器架构提供了强大的计算能力,但它也存在一些限制。例如,数据在服务器之间的传输可能会造成延迟,服务器之间的负载可能不平衡等,这些因素限制了计算资源在分散环境中的有效利用。
2、为了应对这些挑战,协作学习方法(例如联邦学习)应运而生,旨在通过联合训练多个数据源来增强模型性能;在联合训练过程中,现有的孪生网络虽然可以在不直接共享原始数据的情况下工作,但其在训练过程中仍然需要访问成对的输入数据,这些数据可能包含敏感信息。在数据隐私和法规日益严格的今天,这种对数据的直接访问可能引发隐私泄露的风险,无法确保在训练过程中数据的机密性。
技术实现思路
1、为了解决现有技术在模型联合训练过程中数据的安全性较差的问题,本专利技术提供了一种模型联合训练方法、装置、电子设备及存储介质。
2、为了实现上述目的,本专利技术提供如下技术方案:
3、一种模型联合训练方法,包括:
4、构建包含两个子网络的伪孪生网络架构,将伪孪生网络架构部署于分布式网络的任意两个计算节点;每个计算节点包含对应的私有数据;
5、通过计算节点的私有数据分别对两个子网络进行预训练,对预训练的两个子网络的性能进行比较,确定两个子网络中的教师模型和学生模型;
6、以教师模型的预测结果作为软标签,结合真实标签训练
7、交换教师模型和学生模型,迭代训练学生模型的过程,直至两个子网络的性能达到预设性能标准。
8、可选地,以教师模型的预测结果作为软标签,结合真实标签训练学生模型,包括:
9、将目标数据集分别输入教师模型和学生模型,分别得到教师模型和学生模型的预测概率分布;
10、将学生模型的预测概率分布分别与教师模型的预测概率分布和真实标签进行比较,计算损失值;基于损失值更新学生模型的参数。
11、可选地,将学生模型的预测概率分布分别与教师模型的预测概率分布和真实标签进行比较,计算损失值,包括:
12、根据教师模型的预测概率分布和学生模型的预测概率分布计算kl散度;将kl散度作为距离度量,通过对比损失函数计算对比损失;
13、将学生模型的预测概率分布与真实标签进行比较,通过交叉熵损失函数计算交叉熵损失;
14、将交叉熵损失和对比损失进行加权求和得到总损失值。
15、可选地,通过计算节点的私有数据分别对两个子网络进行预训练,包括:
16、将两个计算节点的私有数据分别输入两个子网络,通过前向传播得到预测值;
17、通过交叉熵损失函数分别计算两个子网络预测值与私有数据真实标签之间的损失值;
18、通过两个子网络预测值与私有数据真实标签之间的损失值更新两个子网络的参数;
19、迭代前向传播、计算损失值和更新参数的过程,直至达到预设的停止条件,得到预训练的两个子网络。
20、可选地,在通过计算节点的私有数据分别对两个子网络进行预训练的过程中,包括:
21、对私有数据进行主成分分析提取主成分数据,计算主成分数据的正交投影数据;
22、交替使用主成分数据和正交投影数据进行参数更新。
23、可选地,通过奇异值分解计算主成分数据的正交投影数据,包括:
24、对私有数据进行中心化预处理得到数据矩阵;
25、通过奇异值分解将数据矩阵分解为左奇异向量组成的矩阵、奇异值组成的对角矩阵和右奇异向量组成的矩阵的转置;
26、选择奇异值中前k个主成分,通过构造投影矩阵将数据投影到主成分空间中,得到正交投影数据。
27、本专利技术还提供一种模型联合训练装置,其特征在于,包括:
28、构建模块,用于构建包含两个子网络的伪孪生网络架构,将伪孪生网络架构部署于分布式网络的任意两个计算节点;每个计算节点包含对应的私有数据;
29、第一训练模块,用于通过计算节点的私有数据分别对两个子网络进行预训练,对预训练的两个子网络的性能进行比较,确定两个子网络中的教师模型和学生模型;
30、第二训练模块,用于以教师模型的预测结果作为软标签,结合真实标签训练学生模型;
31、迭代模块,用于交换教师模型和学生模型,迭代训练学生模型的过程,直至两个子网络的性能达到预设性能标准。
32、本专利技术提供的模型联合训练方法具有以下有益效果:
33、本专利技术构建包含两个子网络的伪孪生网络架构,将伪孪生网络架构部署于分布式网络的任意两个计算节点;通过计算节点中的私有数据分别对两个子网络进行单独训练;在两个子网络中确定教师模型和学生模型,通过教师模型指导学生模型的训练,实现两个子网络的相互学习;由于伪孪生网络允许两个子网络具有相同的或相似的架构,但不共享权重,使得伪孪生网络能够在不直接共享数据或参数的情况下,对两个输入进行独立的处理,并通过某种方式(如比较输出)来协同工作;其次,教师模型和学生模型相互学习,每个模型都充当教师模型和学生模型的双重角色。一方面,它将自身的知识蒸馏给其他模型;另一方面,它也从其他模型的蒸馏中学习。这种相互学习使得两个模型能够在不直接共享数据或参数的情况下,从彼此的预测中学习到有用的信息。因此,本专利技术采用伪孪生网络架构,通过私有数据对两个子网络进行预训练,通过相互知识蒸馏在预训练后使两个模型相互学习,能够在模型联合训练过程中,不共享训练数据和网络参数,只需要在教师模型指导学生模型的输出过程中共享真实的标签和预测的标签,降低了联合训练过程中数据泄露的风险。
本文档来自技高网...【技术保护点】
1.一种模型联合训练方法,其特征在于,包括:
2.根据权利要求1所述的模型联合训练方法,其特征在于,以教师模型的预测结果作为软标签,结合真实标签训练学生模型,包括:
3.根据权利要求2所述的模型联合训练方法,其特征在于,将学生模型的预测概率分布分别与教师模型的预测概率分布和真实标签进行比较,计算损失值,包括:
4.根据权利要求1-3任一项所述的模型联合训练方法,其特征在于,通过计算节点的私有数据分别对两个子网络进行预训练,包括:
5.根据权利要求1-3任一项所述的模型联合训练方法,其特征在于,在通过计算节点的私有数据分别对两个子网络进行预训练的过程中,包括:
6.根据权利要求5所述的模型联合训练方法,其特征在于,通过奇异值分解计算主成分数据的正交投影数据,包括:
7.一种模型联合训练装置,其特征在于,包括:
8.一种计算机设备,其特征在于,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现权利要求1-6任一所述的模型联合训练方法。
9.一种计
...【技术特征摘要】
1.一种模型联合训练方法,其特征在于,包括:
2.根据权利要求1所述的模型联合训练方法,其特征在于,以教师模型的预测结果作为软标签,结合真实标签训练学生模型,包括:
3.根据权利要求2所述的模型联合训练方法,其特征在于,将学生模型的预测概率分布分别与教师模型的预测概率分布和真实标签进行比较,计算损失值,包括:
4.根据权利要求1-3任一项所述的模型联合训练方法,其特征在于,通过计算节点的私有数据分别对两个子网络进行预训练,包括:
5.根据权利要求1-3任一项所述的模型联合训练方法,其特征在于,在通过计算节点的私有...
【专利技术属性】
技术研发人员:单东晶,黄志伟,邹贝加,钟丽莎,罗亚梅,
申请(专利权)人:西南医科大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。