使用torch.nn.functional.grid_sample对flownet得到的光流结果进行对齐

flownet2 计算视频中前后两帧的光流信息

    def resample(self, image, flow):    
        '''
        image: 上一帧的图片,torch.Size([1, 3, 256, 256])
        flow: 光流, torch.Size([1, 2, 256, 256])
        final_grid:  torch.Size([1, 2, 256, 256])
        '''
        b, c, h, w = image.size()
        grid = get_grid(b, h, w, gpu_id=flow.get_device())    
        flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)    
        final_grid = (grid + flow).permute(0, 2, 3, 1).cuda(image.get_device())
        output = torch.nn.functional.grid_sample(image, final_grid, mode='bilinear', padding_mode='border')
        return output

Reference:
1.crop pooling
2.What is the equivalent of torch.nn.functional.grid_sample in Tensorflow / Numpy?

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容