一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法技术

技术编号:33705696 阅读:20 留言:0更新日期:2022-06-06 08:27
本发明专利技术公开了一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,包括:选择ETH行人轨迹数据集作为数据源,选择Social GAN的方式,借助长短期记忆网络LSTM对行人历史轨迹和当前位置建模,实现行人轨迹预测;其中,生成器使用长短期记忆网络对数据中的每一个行人的历史轨迹进行特征分析;鉴别器部则通过数个全连接层进行输入特征的提取,同样通过LSTM网络进行历史轨迹的特征记忆;考虑到官方的ETH数据集中并不包含人物ID标签信息,通过对ETH数据集和其补充数据集的使用,成功对模型进行了训练,并选择常用轨迹预测指标ADE和FDE作为性能评价指标。FDE作为性能评价指标。FDE作为性能评价指标。

【技术实现步骤摘要】
一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法


[0001]本专利技术涉及行人轨迹预测
,特别是涉及一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法。

技术介绍

[0002]与汽车和骑自行车者等交通场景中的其它移动目标相比,行人轨迹预测具有更大的挑战性,因为行人既不像在机动车那样在规定车道内行驶,也不像非机动车那样保持在车道边界内前进,行人在运动时通常是没有规律地流动。由于有已经定义好的一系列“道路规则”,汽车与自行车的运动状态都会因车主遵守规则而受到约束。但行人却不一样,没有相应的法律法规来限制行人应当按照什么样的轨迹运动,当交通场景中出现没有交通信号灯的路口,或行人数量较多时,行人运动就变得更加复杂。因此,需要有效的行人运动轨迹预测算法来解决这些挑战。
[0003]预测行人的运动轨迹需要考虑多种可能影响行人运动的因素。近年来,多数研究从行人行为的角度入手,例如,考虑行人在无交通信号灯的路口对迎面而来的汽车所做出的反应,以此探寻车辆与行人进行交互的机制;预测行人何时过街;此外,对行人行为进行在线预测,需要从传感器中获取数据并提取多种线索,例如,使用机器视觉技术获取多种类型的上下文线索。
[0004]1、基于静态环境线索的预测:
[0005]有的学者提出了行为CNN,采用神经网络对拥挤场景中的行人行为进行建模并证明其有效性。也有的学者从固定场景内的历史轨迹信息中学习常微分方程的加权和,提出了一种新的行人位置预测方法并验证了其良好的效果。
[0006](2)基于动态环境线索的预测:
[0007]行人行为也会受到场景中其它动态目标的影响。有的学者研究了一种微观模拟模型,在无交通信号灯的路口对骑车人的行为进行分析。除了要考虑其他道路使用者的存在以外,交通参与者还应就通行顺序进行协商以协调交通行为,也有的学者提供了一个用于研究交通参与者行为的新数据集,并重点研究了驾驶人与行人的交流方式以避免与对方碰撞。
[0008](3)基于目标线索的预测:
[0009]人们不可能时刻都对周围的环境非常熟悉,而行人注意力不集中往往是交通事故发生的重要原因。可以依靠行人头部方向来判断行人是否注意到正在驶来的车辆,具体做法是使用多个离散定向分类器的结果,添加物理约束和时间滤波来提高鲁棒性,获得连续的头部方向估计。此外,还可以使用神经网络来实现人体全身骨架的实时2D估计,进而达到检测图像中多个人姿态的效果。行人的全身外观也可以用于轨迹预测,例如,对目标进行分类并预测特定类别目标的轨迹或姿态7叫1;还可在行人边界框周围采用密集光流特征来估计过街行人是否会停在路边。

技术实现思路

[0010]有鉴于此,本专利技术的目的在于提供一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,用以训练一个新的神经网络模型,提升了预测的性能。
[0011]为了实现上述目的,本专利技术采用如下技术方案:
[0012]一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,所述行人轨迹预测方法包括如下步骤:
[0013]步骤S1、获取ETH数据集以及ETH补充数据集,并将二者进行匹配处理,获得ETH数据集中每一个人物对象的明确轨迹,以该每一个人物对象的明确轨迹作为训练与测试用的数据集;
[0014]步骤S2、将步骤S1中得到的训练用的数据集按照一定比例划分为训练集和测试集,并且进行预处理;
[0015]步骤S3、构建行人轨迹预测网络,其包括:该行人轨迹预测网络包括生成器网络和鉴别器网络,其中,生成器网络与鉴别器网络构成生成对抗网络;
[0016]步骤S4、将步骤S2中得到的训练集输入至步骤S3中构建的行人轨迹预测网络中,进行模型训练,经过多轮迭代训练,直到损失函数收敛,固定网络参数;
[0017]步骤S5、进行行人轨迹预测,其包括:将步骤S2中得到测试集输入至步骤S4中得到的行人轨迹预测模型中进行预测,得到预测结果。
[0018]进一步的,所述步骤S1具体包括:
[0019]步骤S101、首先获取ETH数据集,该ETH数据集是以时间标签,人ID标签,人位置点坐标(x,y)构成的参数,然后再通过字符串形式读取该ETH数据集中idl标签文件,接着通过Python中的通用代码re模块来读取idl标签文件中的有效信息,最后将有效信息以csv表格的形式导出;
[0020]步骤S102、首先获取ETH补充数据集,其中,该数据集中不存在额外的标签数据文件,只有所有人物图片的分割;然后参照ETH数据集中的标签信息,对该ETH补充数据集中的各个图片进行命名,同时将所有的图片按照人的分类划分到不同的文件夹中;最后通过Python中的通用代码os模块获取各个文件夹的文件目录;
[0021]步骤S103、将步骤S101中得到的csv表格,以及步骤S102中得到的各个文件目录,在Python中的数据列表中进行逐个匹配,得到ETH数据集中每一个人物对象的明确轨迹;
[0022]步骤S104、将轨迹数据划分为观测轨迹数据,预测轨迹数据,时间列表,ID列表,以此构建模型训练与测试用的数据集。
[0023]进一步的,所述步骤S2包括:
[0024]步骤S201、将步骤S104中构建的数据集按划分比例4:1把数据划分成训练集和测试集;
[0025]步骤S202、将所有数据进行标准化,把所有数据标准化到0与1之间。
[0026]进一步的,生成器网络还包括:编码器、解码器、社会特征嵌入层以及社会池化层,其中,编码器和解码器均包括全连接层和LSTM层;
[0027]在生成器网络中,解码器和社会特征嵌入层设置在前,其输入均为轨迹数据,解码器和社会特征嵌入层的输出均传递至社会特征嵌入层中,社会特征嵌入层的输出与噪声进行合并输入至解码器中,解码器的输出传递至鉴别器网络中,其中,在编码器中,全连接层
设置在前,其后连接有多个LSTM网络,相应的,在解码器中,多个LSTM网络设置在前,其后均连接一个全连接层。
[0028]进一步的,在鉴别器网络中,一个全连接层设置在前,另外一个全连接层设置在后,在该两个全连接层之间设置有多个LSTM网络。
[0029]进一步的,所述鉴别器网络的计算公式如下所示:
[0030][0031][0032][0033][0034]在公式(1)

公式(4)中,t=1,...,t
obs
,...,t
obs
+t
prep
,T
i
表示真假轨迹的并集,to
bs
为过去轨迹时间长度,t
prep
为未来轨迹时间长度,X
i
,Y
i
为位置坐标,δ为一个全连接层,用于将二维坐标转化为特征向量,W
δ
为该全连接层参数;LSTM层将每个时刻的特征向量进行编码,直到t=t
obs
+t
prep...

【技术保护点】

【技术特征摘要】
1.一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,其特征在于,所述行人轨迹预测方法包括如下步骤:步骤S1、获取ETH数据集以及ETH补充数据集,并将二者进行匹配处理,获得ETH数据集中每一个人物对象的明确轨迹,以该每一个人物对象的明确轨迹作为训练与测试用的数据集;步骤S2、将步骤S1中得到的训练用的数据集按照一定比例划分为训练集和测试集,并且进行预处理;步骤S3、构建行人轨迹预测网络,其包括:该行人轨迹预测网络包括生成器网络和鉴别器网络,其中,生成器网络与鉴别器网络构成生成对抗网络;步骤S4、将步骤S2中得到的训练集输入至步骤S3中构建的行人轨迹预测网络中,进行模型训练,经过多轮迭代训练,直到损失函数收敛,固定网络参数;步骤S5、进行行人轨迹预测,其包括:将步骤S2中得到测试集输入至步骤S4中得到的行人轨迹预测模型中进行预测,得到预测结果。2.根据权利要求1所述的一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,其特征在于,所述步骤S1具体包括:步骤S101、首先获取ETH数据集,该ETH数据集是以时间标签,人ID标签,人位置点坐标(x,y)构成的参数,然后再通过字符串形式读取该ETH数据集中idl标签文件,接着通过Python中的通用代码re模块来读取idl标签文件中的有效信息,最后将有效信息以csv表格的形式导出;步骤S102、首先获取ETH补充数据集,其中,该数据集中不存在额外的标签数据文件,只有所有人物图片的分割;然后参照ETH数据集中的标签信息,对该ETH补充数据集中的各个图片进行命名,同时将所有的图片按照人的分类划分到不同的文件夹中;最后通过Python中的通用代码os模块获取各个文件夹的文件目录;步骤S103、将步骤S101中得到的csv表格,以及步骤S102中得到的各个文件目录,在Python中的数据列表中进行逐个匹配,得到ETH数据集中每一个人物对象的明确轨迹;步骤S104、将轨迹数据划分为观测轨迹数据,预测轨迹数据,时间列表,ID列表,以此构建模型训练与测试用的数据集。3.根据权利要求2所述的一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,其特征在于,所述步骤S2包括:步骤S201、将步骤S104中构建的数据集按划分比例4:1把数据划分成训练集和测试集;步骤S202、将所有数据进行标准化,把所有数据标准化到0与1之间。4.根据权利要求3所述的一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,其特征在于,生成器网络还包括:编码器、解码器、社会特征嵌入层以及社会池化层,其中,编码器和解码器均包括全连接层和LSTM层;在生成器网络中,解码器和社会特征嵌入层设置在前,其输入均为轨迹数据,解码器和社会特征嵌入层的输出均传递至社会特征嵌入层中,社会特征嵌入层的输出与噪声进行合并输入至解码器中,解码器的输出传递至鉴别器网络中,其中,在编码器中,全连接层设置在前,其后连接有多个LSTM网络,相应的,在解码器中,多个LSTM网络设置在前,其后均连接一个全连接层。
5.根据权利要求4所述的一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,其特征在于,在鉴别器网络中,一个全连接层设置在前,另外一个全连接层设置在后,在该两个全连接层之间设置有多个LSTM网络。6.根据权利要求5所述的一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法,其特征在于,所述鉴别器网络的计算公式如下所示:法,其特征在于,所述鉴别器网络的计算公式如下所示:法,其特征在于,所述鉴别器网络的计算公式如下所示:法,其特征在于,所述鉴别器网络的计算公式如下所示:在公式(1)
...

【专利技术属性】
技术研发人员:王翔辰杨欣樊江锋
申请(专利权)人:南京航空航天大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1