基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备技术

技术编号:37963724 阅读:4 留言:0更新日期:2023-06-30 09:39
本发明专利技术提供一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备:获取客户端本地数据集;获取初始模型,包括语义聚类模型和预训练得到的编码器网络;将本地数据集随机增强两次生成两个视图,输入初始编码器网络,提取特征向量并构建对比损失,训练得到编码器网络;将本地数据集输入编码器网络提取特征向量,并提取样本的Top

【技术实现步骤摘要】
基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备


[0001]本专利技术涉及人工智能
,尤其涉及一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备。

技术介绍

[0002]随着智能设备的普及,联邦学习已经成为最常用的一种隐私保护模型共享方法,并在用户习惯预测、个性化推荐和无线网络优化等许多场景中得到了广泛的应用。现有的联邦学习方法通常只考虑有监督的训练设置,其中客户端数据被完全标记。然而,包含复杂注释的本地数据对于物联网应用来说是不现实的,因为用户总是有不同的习惯和使用频率。示例性的,假设有一个照片分类器应用程序,可以实现自动对相册中的图片进行分类。在这种情况下,应用程序的用户若不愿意自己注释这些隐私和敏感的图片,则会导致服务提供商只能在中央服务器上使用有限的公共图片。因此,在许多现实的物联网场景中,客户端数据可能没有完全标记,只有少量标记数据在服务器上可用。
[0003]现有的联邦学习方法在缺乏标签数据场景下主要是采用联邦半监督学习。联邦半监督学习的目标是学习多个客户端之间的一致性。部分工作通过客户端间一致性损失,用于对标记数据和未标记数据进行分布式训练,或是考虑参数更新多样性的半监督训练多样性缩放聚合算法,在移动设备之间交换局部模型的输出,而不是典型框架中使用的模型参数交换。但是,当下游任务没有可用的标签时,这些方法的性能不佳。同时,与理想化的分布条件不同,由于用户的使用习惯和使用频率不同,物联网设备之间的数据通常是非独立的同分布,也会导致共享模型性能下降。
专利技术内容
[0004]鉴于此,本专利技术实施例提供了一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备,以消除或改善现有技术中存在的一个或更多个缺陷,解决现有联邦无监督学习性能较差且不适用非独立同分布场景的问题。
[0005]一方面,本专利技术提供一种基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
[0006]获取本地数据集,所述本地数据集包含多个样本,每个样本包含一张图像;
[0007]获取初始模型,所述初始模型包括语义聚类模型和预训练得到的编码器网络;其中,将所述本地数据集的样本进行两次随机增强,生成第一视图和第二视图;将所述第一视图和所述第二视图一同输入初始编码器网络,提取第一特征向量和第二特征向量;采用所述本地数据集对所述初始编码器网络进行训练,并构建第一特征向量和第二特征向量之间的对比损失,以得到训练好的编码器网络;将所述本地数据集按批输入所述语义聚类模型,利用所述编码器网络提取对应样本的特征向量;基于预设神经网络从所述特征向量中提取对应样本的Top

K近邻样本,通过预设Softmax函数计算对应样本所属于不同集群的向量
值,以得到对应样本的类别;
[0008]采用所述本地数据集对所述初始模型进行训练,并构建聚类损失,利用所述聚类损失对所述初始模型的参数进行迭代,以得到初始图像分类模型;
[0009]将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,以得到最终的图像分类模型;其中,所述共享模型还包括自标记模块,所述自标记模块为基于所述共享模型得到的类别设置伪标签,并构建基于所述共享模型得到的类别与相应伪标签之间的交叉熵损失,利用所述交叉熵损失更新所述共享模型参数。
[0010]在本专利技术的一些实施例中,将所述本地数据集的样本进行两次随机增强,所述随机增强至少包括空间变换裁剪、旋转、调节饱和度、调节对比度、调节色调、调节颜色、调节亮度和调节灰度中的一种或多种组合操作。
[0011]在本专利技术的一些实施例中,构建第一特征向量和第二特征向量之间的对比损失,所述对比损失采用归一化温度交叉熵损失。
[0012]在本专利技术的一些实施例中,所述对比损失的计算式为:
[0013][0014]其中,表示所述对比损失;i,j分别表示所述第一视图和所述第二视图;z
i
,z
j
分别表示所述第一特征向量和所述第二特征向量;sim(z
i
,z
j
)表示所述第一视图和所述第二视图的相似度度量;τ是温度因子;M表示所述本地数据集中样本的数量;m表示所述本地数据集中样本的序号。
[0015]在本专利技术的一些实施例中,采用所述本地数据集对所述初始模型进行训练,并构建聚类损失;所述聚类损失的计算式为:
[0016][0017]其中,表示所述聚类损失;x表示所述本地数据集x
c
中的单个样本;表示x的相邻样本集N
x
中的单个近邻图像样本;q(
·
)表示预设函数;<
·
>表示点积运算符号;λ表示权重;k表示集群;p
k
表示被分配到集群k的概率。
[0018]在本专利技术的一些实施例中,接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,计算式为:
[0019][0020]其中,q
c
表示客户端c的初始图像分类模型参数;q
g
表示所述共享模型参数;t表示第t轮所述共享模型参数聚合;μ表示预设阈值;ξ表示所述初始图像分类模型参数与所述共享模型参数在更新中分别占的权重。
[0021]在本专利技术的一些实施例中,还包括:
[0022]计算所述初始图像分类模型在训练时的模型散度,当所述模型散度大于所述预设阈值时,客户端使用所述共享模型的权重进行更新;当所述模型散度小于或等于所述预设
阈值时,客户端使用其初始图像分类模型和所述共享模型的权重组合进行更新。
[0023]在本专利技术的一些实施例中,基于预设置信阈值选择置信度大于所述预设置信阈值的样本,并为相应样本基于所述共享模型得到的类别设置伪标签,构建基于所述共享模型得到的类别与相应伪标签之间的交叉熵损失,所述交叉熵损失计算式为:
[0024][0025]其中,L
self
表示所述交叉熵损失;x表示所述全局服务器的数据集x
g
中的单个样本;σ表示所述预设置信阈值;p(x)表示所述共享模型的输出;表示样本x的伪标签;H(
·
)表示所述伪标签上的标准交叉熵损失。
[0026]另一方面,本专利技术提供一种基于语义聚类的联邦无监督图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
[0027]获取待分类的图像;
[0028]将所述图像输入如上文中任一项所述基于语义聚类的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
[0029]另一方面,本专利技术还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上文中本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:获取本地数据集,所述本地数据集包含多个样本,每个样本包含一张图像;获取初始模型,所述初始模型包括语义聚类模型和预训练得到的编码器网络;其中,将所述本地数据集的样本进行两次随机增强,生成第一视图和第二视图;将所述第一视图和所述第二视图一同输入初始编码器网络,提取第一特征向量和第二特征向量;采用所述本地数据集对所述初始编码器网络进行训练,并构建第一特征向量和第二特征向量之间的对比损失,以得到训练好的编码器网络;将所述本地数据集按批输入所述语义聚类模型,利用所述编码器网络提取对应样本的特征向量;基于预设神经网络从所述特征向量中提取对应样本的Top

K近邻样本,通过预设Softmax函数计算对应样本所属于不同集群的向量值,以得到对应样本的类别;采用所述本地数据集对所述初始模型进行训练,并构建聚类损失,利用所述聚类损失对所述初始模型的参数进行迭代,以得到初始图像分类模型;将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,以得到最终的图像分类模型;其中,所述共享模型还包括自标记模块,所述自标记模块为基于所述共享模型得到的类别设置伪标签,并构建基于所述共享模型得到的类别与相应伪标签之间的交叉熵损失,利用所述交叉熵损失更新所述共享模型参数。2.根据权利要求1所述的基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,将所述本地数据集的样本进行两次随机增强,所述随机增强至少包括空间变换裁剪、旋转、调节饱和度、调节对比度、调节色调、调节颜色、调节亮度和调节灰度中的一种或多种组合操作。3.根据权利要求1所述的基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,构建第一特征向量和第二特征向量之间的对比损失,所述对比损失采用归一化温度交叉熵损失。4.根据权利要求3所述的基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,所述对比损失的计算式为:其中,表示所述对比损失;i,j分别表示所述第一视图和所述第二视图;z
i
,z
j
分别表示所述第一特征向量和所述第二特征向量;sim(z
i
,z
j
)表示所述第一视图和所述第二视图的相似度度量;τ是温度因子;M表示所述本地数据集中样本的数量;m表示所述本地数据集...

【专利技术属性】
技术研发人员:高志鹏赵晨杨杨芮兰兰莫梓嘉俞新蕾熊子健
申请(专利权)人:北京邮电大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1