来源: arXiv:2103.14829v1
这篇文章的目标是利用transformer实现真正的端到端多目标跟踪器的训练,这里的端到端是指给定一段图像序列,网络能够自动的处理轨迹的产生和终止以及生长。具体而言,提出的MO3TR模型使用temporal transformer实现每个轨迹历史特征的融合并预测当前时刻该轨迹的特征,另外使用spatial transformer刻画object之间的位置关系以及object与image之间的关系。 object之间的关系能够刻画目标之间可能存在的交互,而object与image的交互能够实现潜在的关联,从而完成轨迹生长。
注意: 其实该方法思路很简单,是在TrackerFormer上的扩展,其难点在于模型的训练,如何有效的提供训练数据。
MO3TR: Multi-Object TRacking using spatial TRansformers and temporal TRansformers.
MO3TR的框架图如下:
Temporal Transformer
首先来介绍下Temporal Transformers,这个模块并行作用在每个track上,目的是有效的利用每条轨迹的历史状态预测当前时刻用于检测和关联的隐状态。给定某条轨迹, 其中表示轨迹的起始时刻, 表示轨迹m在某一历史时刻的历史状态,那么使用transformer时需要提供每个embedding的pos编码,而query是当前帧,即状态下的位置编码, 于是轨迹m在当前帧中的预测状态为:
表示包含pos编码的每个时刻的特征向量。
这里每条轨迹的长度是可变的,同样通过迭代可以预测多步之后的状态。
作者认为temporal transformer生成轨迹预测的优势:1)利用attention的方式能够有效的筛选有用的历史信息,使特征鲁棒;2)利用位置编码能够学到有效的运动信息,而不单单是表观特征变化。
Spatial Transformers
这一部分和TrackFormer相似,以预测的轨迹状态和可学习的embedding作为queries,以当前图像的特征encoder作为key和value,本质就是self-attention + cross-attention. 同样的decoder包含多个decoder layer。
其中可学习的的embedding主要用来检测新产生的目标,而预测状态则表示已经存在的轨迹。
对于每个query都有一个输出特征, 对该特征进行分类和回归任务。这部分其实和TrackFormer类似,不同点在于分类后的query,通过阈值决定了那些轨迹已经终止和哪些检测是轨迹的新起点,从而把当前时刻的新状态更新到历史集合中。
注意:这里提到的right-aligning 嵌入集合,本质上是对当前时刻存在的轨迹在当前时刻对齐,它影响的其实是pos编码。
训练
训练部分值得注意的包含两点;
- 如何构建训练数据集,因为模型要建模轨迹的初始化、轨迹的终止和遮挡等情况,因此该类的数据应该足够,只能通过数据增广的方式实现。
- 因为预测的query既包括已有轨迹的持续生长,也包括新轨迹的初始化,所以应当确定如何给不同的query分配标签。
针对于问题1. 的做法:首先训练一个detr检测器,那么检测器就能提供第k帧图像的前K帧图像中的目标的embedding,可以将该embedding作为历史状态。这时候采用的损失函数和标签匹配采用的和detr相同。 为了刻画轨迹的初始化,轨迹的终止和轨迹的遮挡,使用了三种对应的数据增广:a.随机的丢掉部分embedding,使得模型更加鲁棒,而丢失的是最近时刻的embedding时相当于目标的重识别;b. 随机的插入一些positive examples, 增加对遮挡带来的模糊的处理能力;c. 随机选择轨迹的长度,使模型能够处理不同长度的预测。
在标签分配上,检测的标签分配和DETR相同,而对于tracking部分的query,其label是历史label相同。没有匹配目标label的query的标签赋值背景标签。
实验部分
- 细节。
图像特征抽取的backbone是ImageNet上预训练的,首先使用CrowdHuman, ETH和CUHJ-SYSU训练一个行人检测器,其次在MOT17上训练temporal transformer和整个模型。 模型遵循DETR的策略,训练300个epoch, 初始学习率1e-4, 每100个epoch以0.1衰减一次。历史状态的长度是1~30. 为了更好的训练temporal transformers, 连续预测未来10步的状态,然后计算均值损失。(这部分细节不是很清楚) - 实验结果
a. 在public 检测条件下,性能相对于TrackFormer提升了1.4个点的MOTA,在FP, FN和IDS上都有提升。
b. 分析实验验证了不同的数据增广的有效性。
c. temporal attention部分的可视化解释描述的不是很清楚。
总结
文章的思路挺清晰,主要是temporal transformer模块对多帧图像的处理,其难点在于模型训练部分。
文章对于temporal transformer模块的描述不是很清楚,轨迹终止的时刻不是很清楚,在当前帧如果track query没有检测到是否表示轨迹终止了呢?如果是的话那轨迹重新找回是怎么刻画的呢?
该模型的整体速度应该还可以,因为轨迹历史状态都已经是存储的,只需要当前帧图像的特征抽取,和两种transformers的处理。