强化学习模型的训练方法及装置制造方法及图纸

技术编号:39158716 阅读:13 留言:0更新日期:2023-10-23 15:01
本公开涉及计算机技术领域,提供了一种强化学习模型的训练方法及装置。该方法包括:获取使用设定强化学习算法对第一强化学习模型进行训练得到的第二强化学习模型;将相同的训练数据分别输入到第一强化学习模型和第二强化学习模型,对应得到第一输出数据组和第二输出数据组;根据训练数据、第一输出数据组、第二输出数据组和设定的总损失函数获取总损失函数值,其中,总损失函数的自蒸馏损失函数部分根据第一输出数据组和第二输出数据组的距离得到;根据总损失函数值调整第一强化学习模型,直到第一强化学习模型收敛,得到训练好的目标强化学习模型。本公开的技术方案可以提高强化学习模型在实际应用中的泛化能力和可复现性。现性。现性。

【技术实现步骤摘要】
强化学习模型的训练方法及装置


[0001]本公开涉及计算机
,尤其涉及一种强化学习模型的训练方法及装置。

技术介绍

[0002]相关技术中,强化学习的应用场景逐渐增多。常用的深度强化学习,是一种使用深度神经网络的强化学习方法,对复杂问题的解决较为友好,比如运行游戏,控制机器人,控制自动驾驶等。它的主要思想是,通过不断的学习和实践,通过优化智能体在环境中的行为,来最大化未来的奖励。深度强化学习使用深度神经网络来实现,它可以处理高维度和非线性的环境,并能够更好地学习和表现。
[0003]但是,深度强化学习等强化学习算法存在一定的问题。强化学习算法的强化学习模型在训练中容易在模型参数量较大的情形下过拟合,对训练数据产生记忆行为,而无法很好地对有变化的环境很好适应,学习效果受影响,泛化能力较差。此外,强化学习模型在参数量较大的情形下对环境的采样并不稳定,从而对环境奖励机制建模出现较大的方差,从而导致模型的Q值估计不准,造成模型训练不稳定,使得强化学习得到的强化学习模型效果可复现性较差。
[0004]强化学习模型在实际应用中的泛化能力和可复现性较差是当前亟需解决的技术问题。

技术实现思路

[0005]有鉴于此,本公开实施例提供了一种强化学习模型的训练方法、装置、电子设备及计算机可读存储介质,以解决现有技术中强化学习模型在实际应用中的泛化能力和可复现性较差的技术问题。
[0006]本公开实施例的第一方面,提供了一种强化学习模型的训练方法,该方法包括:获取使用设定强化学习算法对第一强化学习模型进行训练得到的第二强化学习模型;将相同的训练数据分别输入到第一强化学习模型和第二强化学习模型,对应得到第一输出数据组和第二输出数据组;根据训练数据、第一输出数据组、第二输出数据组和设定的总损失函数获取总损失函数值,其中,总损失函数包括强化学习损失函数部分和自蒸馏损失函数部分,自蒸馏损失函数部分根据第一输出数据组和第二输出数据组的距离得到;根据总损失函数值调整第一强化学习模型,直到第一强化学习模型收敛,得到训练好的目标强化学习模型。
[0007]本公开实施例的第二方面,提供了一种强化学习模型的训练装置,该装置包括:模型获取模块,用于获取使用设定强化学习算法对第一强化学习模型进行训练得到的第二强化学习模型;输入模块,用于将相同的训练数据分别输入到第一强化学习模型和第二强化学习模型,对应得到第一输出数据组和第二输出数据组;损失函数获取模块,用于根据训练数据、第一输出数据组、第二输出数据组和设定的总损失函数获取总损失函数值,其中,总损失函数包括强化学习损失函数部分和自蒸馏损失函数部分,自蒸馏损失函数部分根据第一输出数据组和第二输出数据组的距离得到;调整模块,用于根据总损失函数值调整第一
强化学习模型,直到第一强化学习模型收敛,得到训练好的目标强化学习模型。
[0008]本公开实施例的第三方面,提供了一种电子设备,包括存储器、处理器以及存储在存储器中并且可在处理器上运行的计算机程序,该处理器执行计算机程序时实现上述方法的步骤。
[0009]本公开实施例的第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机程序,该计算机程序被处理器执行时实现上述方法的步骤。
[0010]本公开实施例与现有技术相比存在的有益效果是:本公开实施例的技术方案通过对第一强化学习模型进行预训练得到第二强化学习模型,并将训练数据分别输入到第一强化学习模型和第二强化学习模型中,根据第一强化学习模型和第二强化学习模型输出的动作分布的距离获取损失函数,用以调整第一强化学习模型,最终得到目标强化学习模型,可以提高强化学习模型在实际应用中的泛化能力,以及强化学习模型训练过程中的稳定性和效率。
附图说明
[0011]为了更清楚地说明本公开实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
[0012]图1是现有技术中一种强化学习组件关系的示意图;图2是现有技术中一种强化学习过程中的轨迹的示意图;图3是本公开实施例提供的一种强化学习模型的训练方法的流程示意图;图4是本公开实施例提供的另一种强化学习模型的训练过程的示意图;图5是本公开实施例提供的一种强化学习模型的训练装置的结构示意图;图6是本公开实施例提供的一种电子设备的结构示意图。
具体实施方式
[0013]以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本公开实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本公开。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本公开的描述。
[0014]DQN(Deep Q Network,深度Q网络)是一种基于神经网络的深度强化学习算法,它可以用来解决回合制强化学习问题。DQN可以通过学习一个价值函数来指导智能体选择动作,从而达到最优化智能体行为的目的。DQN算法可以用来解决多种类型的强化学习问题,例如游戏控制、机器人控制、自动驾驶等。其主要的应用场景如下:机器人控制:强化学习可以用于控制机器人,以实现机器人自主学习和行动。
[0015]自动驾驶:强化学习可以用于自动驾驶,使车辆能够在复杂的环境中安全驾驶。
[0016]游戏:强化学习可以用于游戏,让游戏角色能够从自身的行为中学习,从而更好地控制游戏。
[0017]无人机:强化学习可以用于控制无人机,让无人机能够自主学习和行动,以实现更
好的空中飞行姿态控制。
[0018]金融:强化学习可以用于金融交易,以提高金融交易的准确性和效率。
[0019]DQN算法的核心思想是将环境的状态和动作映射到一个行为值即Q值,从而指导智能体选择动作。DQN算法使用神经网络来学习这个Q值,通过反向传播算法来更新神经网络的参数,从而使得智能体能够学习最优的行为策略。
[0020]DQN算法的主要过程包括:进行环境描述、进行状态描述、构建模型、训练和测试。其中,在进行环境描述时,定义好环境,包括环境状态、可能的动作、奖励等;在进行状态表示时,将环境中的状态进行表示,以便模型可以接受和处理;在构建模型时,构建DQN模型,包括网络结构、优化器、loss函数等;在训练时,使用经验回放(Experience Replay)和目标网络(Target Network)来收集经验,并训练模型以更新Q值;在测试时,使用训练好的模型,在测试环境中运行,看看算法的表现如何。
[0021]深度强化学习等现有强化学习技术存在模型泛化能力差,容易收敛在局部最优点,算法训练稳定性差等问题。
[0022]为解决以上问题,本公开实施例提供一种强化学习模型的训练方案,通过使用自蒸馏的方法提高强化学习模型的训练效率和稳定性。
[0本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种强化学习模型的训练方法,其特征在于,所述方法包括:获取使用设定强化学习算法对第一强化学习模型进行训练得到的第二强化学习模型;将相同的训练数据分别输入到所述第一强化学习模型和所述第二强化学习模型,对应得到第一输出数据组和第二输出数据组;根据所述训练数据、所述第一输出数据组、所述第二输出数据组和设定的总损失函数获取总损失函数值,其中,所述总损失函数包括强化学习损失函数部分和自蒸馏损失函数部分,所述自蒸馏损失函数部分根据所述第一输出数据组和所述第二输出数据组的距离得到;根据所述总损失函数值调整所述第一强化学习模型,直到所述第一强化学习模型收敛,得到训练好的目标强化学习模型。2.根据权利要求1所述的方法,其特征在于,所述总损失函数根据以下公式得到:;其中,为强化学习损失函数部分,为自蒸馏损失函数部分,M为互动的局数,N为每一局的步数,为第一强化学习模型在第j局第t步的输出,为第二强化学习模型在第j局第t步的输出,α为调节系数。3.根据权利要求1所述的方法,其特征在于,所述自蒸馏损失函数部分的函数值小于等于所述强化学习损失函数部分的函数值的0.2倍。4.根据权利要求1所述的方法,其特征在于,所述设定强化学习算法包括以下任一种:策略梯度算法、异步优势演员

评论员算法、邻近策略优化算法和信任域策略优化算法。5.根据权利要求1所述的方法,其特征在于,所述第一强化学习模型包括特征提取器,所述特征提取器的网络结构包括以下任一种:卷积神经网络、长短期记忆网络和变换器。6.根据权利要求1所述的...

【专利技术属性】
技术研发人员:杜梦雪暴宇健
申请(专利权)人:深圳须弥云图空间科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1