Social GAN:行人轨迹预测代码细节

paper:Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks
code:https://github.com/agrimgupta92/sgan

np.around() 返回四舍五入后的值,可指定精度。
np.transpose() 转置
np.polyfit()数据拟合
np.cumsum()累加
datasets返回值:
def __getitem__(self, index):
    start, end = self.seq_start_end[index]
    out = [
         self.obs_traj[start:end, :], self.pred_traj[start:end, :],
         self.obs_traj_rel[start:end, :], self.pred_traj_rel[start:end, :],
         self.non_linear_ped[start:end], self.loss_mask[start:end, :]
     ]
"""
obs_traj:过去序列,(peds_num,2,obs_len)
pred_traj:预测序列,(peds_num,2,pred_len)
obs_traj_rel:过去相对序列,(peds_num,2,obs_len)
pred_traj_rel:预测相对序列,(peds_num,2,pred_len)
non_linear_ped:非线性值
loss_mask:mask
"""

这里用到dataloader里的collate_fn参数,collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。感觉在目标检测的时候用的比较多,主要是一个batch里有好几张图片的box,那需要单独加个索引再concat起来,不然因为每张图片box数量不同,是无法load进去的。

def seq_collate(data):
    (obs_seq_list, pred_seq_list, obs_seq_rel_list, pred_seq_rel_list,
     non_linear_ped_list, loss_mask_list) = zip(*data)

    _len = [len(seq) for seq in obs_seq_list]
    cum_start_idx = [0] + np.cumsum(_len).tolist()
    seq_start_end = [[start, end]
                     for start, end in zip(cum_start_idx, cum_start_idx[1:])]

    # Data format: batch, input_size, seq_len
    # LSTM input format: seq_len, batch, input_size
    obs_traj = torch.cat(obs_seq_list, dim=0).permute(2, 0, 1)
    pred_traj = torch.cat(pred_seq_list, dim=0).permute(2, 0, 1)
    obs_traj_rel = torch.cat(obs_seq_rel_list, dim=0).permute(2, 0, 1)
    pred_traj_rel = torch.cat(pred_seq_rel_list, dim=0).permute(2, 0, 1)
    non_linear_ped = torch.cat(non_linear_ped_list)
    loss_mask = torch.cat(loss_mask_list, dim=0)
    seq_start_end = torch.LongTensor(seq_start_end)
    out = [
        obs_traj, pred_traj, obs_traj_rel, pred_traj_rel, non_linear_ped,
        loss_mask, seq_start_end
    ]

    return tuple(out)

通过调用collate_fn,dataloader的返回值由(batch,peds_num,2,seq_len)->(seq_len,batch*peds_num,2)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容