光流warp原理 TensorFlow dense_image_warp源码解析

概述

光流与warp

参考https://blog.csdn.net/qq_33757398/article/details/106332814

函数定义

  • 源码中的函数定义为:
def dense_image_warp(
    image: types.TensorLike, flow: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor

官方给的定义给出了计算公式:

Apply a non-linear warp to the image, where the warp is specified by a
dense flow field of offset vectors that define the correspondences of
pixel values in the output image back to locations in the source image.
Specifically, the pixel value at output[b, j, i, c] is
images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].

  • 接收三个参数。image是图像的tf数组,flow是其对应的光流数组,name用于tf.name_scope,如果为None则tf.name_scope默认为dense_image_warp。源码中的注释为:

image: 4-D float Tensor with shape [batch, height, width, channels].
flow: A 4-D float Tensor with shape [batch, height, width, 2].
name: A name for the operation (optional).

图像和光流不必是同一类型的tf数组。

Note that image and flow can be of type tf.half, tf.float32, or
tf.float64, and do not necessarily have to be the same type.

  • 返回值是和输入图像大小相同的warp后的图像。

光流warp流程

在看源码实现之前,我们先对dense warp的流程进行梳理。dense warp过程比较简单,可以分为以下几个步骤:

  • 生成meshgrid
  • 将光流flowmeshgrid叠加,得到query_points
  • 用得到的query_points在图像image上索引对应的值

接下来我们详细来看每一步的原理:

生成meshgrid

meshgrid是索引坐标的数组。在TensorFlow中可以用meshgrid函数直接生成。

x = [1, 2, 3]
y = [4, 5, 6]
X, Y = tf.meshgrid(x, y)
# X = [[1, 2, 3],
#      [1, 2, 3],
#      [1, 2, 3]]
# Y = [[4, 4, 4],
#      [5, 5, 5],
#      [6, 6, 6]]
# from https://www.tensorflow.org/api_docs/python/tf/meshgrid

为什么要有这样的函数?如果我们将XY叠起来得到一个grid:

grid = tf.stack([X, Y], axis=2)

此时,grid[0, 0]=[1, 4],grid[2, 2]=[3, 6],我们建立了两个grid之间的索引,这极大方便了warp的过程。为了方便光流的计算,这里我们生成一个与图像以及光流大小相同的meshgrid:

grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)

将光流flowmeshgrid叠加,得到query_points

上一步生成的stacked_grid尺寸为width \times height \times 2,是对输入图像坐标的索引。如stacked_grid[0, 0] = [0,0], stacked_grid[1, 2] = [1,2]。现在我们将光流与其相减(或相加,取决于光流为正向还是负向),根据光流的定义,我们就得到了一个新的grid:query_points_on_grid。我们可以将这个grid看作warp结果图像在(i, j)这个坐标的像素要从原图像那个位置来取。即结果图中的(0, 0)这个位置的像素值,需要从原图像的query_points_on_grid[0, 0]的位置来取。

虽然上述方法我们已经得到了warp后的结果,但是query_points_on_grid[0, 0]的结果可能是浮点数,因此我们需要进行插值的操作来得到整数索引所对应的值。

用得到的query_points在图像image上索引对应的值

这一操作在pytorch里被实现为torch.nn.functional.grid_sample,网址:https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

在tf源码里,tensorflow_addons/image/dense_image_warp.py文件中有一个函数interpolate_bilinear实现了类似的功能。

双线性插值(from: https://zhuanlan.zhihu.com/p/112030273)

在warp中的实现思路很简单,需要计算两个变量,然后进行双线性插值。
首先我们需要算出query_points_on_grid中的每个坐标对应到原图上后四个交点所对应的值,如上图所示。我们直接对query_points_on_grid的值向下取整,就得到了每个点对应的原图中位置的左上角坐标floor,然后去原图索引就得到了左上角角点的值top_left。有左上角坐标,右下角坐标无非就是横纵坐标+1。这样四个角点在原图所对应的值也能都找到了。我们用四个数组top_left, top_right, bottom_left, bottom_right存储query_points_on_grid每个点对应原图的四个角点的值。
另一个需要计算的值alphas是每个点离左上角有多远,这决定了四个角点参与插值的权重。我们直接用query_points_on_grid减去每个点对应的左上角坐标floor就可以得到。
在计算完成后,我们使用以下的公式就完成了双线性插值:

interp_top = alphas[1] * (top_right - top_left) + top_left
interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
interp = alphas[0] * (interp_bottom - interp_top) + interp_top

dense_image_warp源码实现

tensorflow_addons.image.dense_image_warp的源代码如下所示,详情参照(https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/dense_image_warp.py)。

如前文所述,该方法主要分为三个步骤,以下代码中添加了对应注释:

def dense_image_warp(
    image: types.TensorLike, flow: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    with tf.name_scope(name or "dense_image_warp"):
        # 类型转化,获取维度
        image = tf.convert_to_tensor(image)
        flow = tf.convert_to_tensor(flow)
        batch_size, height, width, channels = (
            _get_dim(image, 0),
            _get_dim(image, 1),
            _get_dim(image, 2),
            _get_dim(image, 3),
        )
        
        ## 1. 生成meshgrid
        grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
        stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)
        # 增加一个batch维度
        batched_grid = tf.expand_dims(stacked_grid, axis=0)

        ## 2. 将光流flow和meshgrid叠加,得到query_points
        query_points_on_grid = batched_grid - flow
        # 展平数组
        query_points_flattened = tf.reshape(
            query_points_on_grid, [batch_size, height * width, 2]
        )

        ## 3. 用得到的query_points在图像image上索引对应的值
        # 双线性插值
        interpolated = interpolate_bilinear(image, query_points_flattened)
        interpolated = tf.reshape(interpolated, [batch_size, height, width, channels])
        return interpolated

其中interpolate_bilinear部分的源代码主要在_interpolate_bilinear_impl部分:

def _interpolate_bilinear_impl(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str,
    name: Optional[str],
) -> tf.Tensor:
    with tf.name_scope(name or "interpolate_bilinear"):
        # 初始化一些变量
        grid_shape = tf.shape(grid)
        query_shape = tf.shape(query_points)
        batch_size, height, width, channels = (
            grid_shape[0],
            grid_shape[1],
            grid_shape[2],
            grid_shape[3],
        )
        num_queries = query_shape[1]
        query_type = query_points.dtype
        grid_type = grid.dtype

        alphas = [] # 对应于前文的alphas
        floors = [] # 对应于前文的floors
        ceils = []  # 右下角点,等于floor+1
        # 是x,y 还是y,x
        index_order = [0, 1] if indexing == "ij" else [1, 0]
        # 将x,y分开
        unstacked_query_points = tf.unstack(query_points, axis=2, num=2)

        for i, dim in enumerate(index_order):
            with tf.name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]
                size_in_indexing_dimension = grid_shape[i + 1]

                ## 计算每个点对应的左上角和右下角坐标
                max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
                min_floor = tf.constant(0.0, dtype=query_type)
                # 将floor的值限制在一定范围
                floor = tf.math.minimum(
                    tf.math.maximum(min_floor, tf.math.floor(queries)), max_floor
                )
                int_floor = tf.cast(floor, tf.dtypes.int32)
                floors.append(int_floor)
                ceil = int_floor + 1
                ceils.append(ceil)

                ## 计算alphas的值,就是query_points减去floor
                alpha = tf.cast(queries - floor, grid_type)
                min_alpha = tf.constant(0.0, dtype=grid_type)
                max_alpha = tf.constant(1.0, dtype=grid_type)
                # 同样,将值限制在0到1之间
                alpha = tf.math.minimum(tf.math.maximum(min_alpha, alpha), max_alpha)
                # 为了广播计算扩展维度
                alpha = tf.expand_dims(alpha, 2)
                alphas.append(alpha)

            flattened_grid = tf.reshape(grid, [batch_size * height * width, channels])
            batch_offsets = tf.reshape(
                tf.range(batch_size) * height * width, [batch_size, 1]
            )

        # gather函数的作用是从flattened_grid取出(y_coords, x_coords)所对应的值
        # 源代码中的注释:It's possible this code would be made simpler by using tf.gather_nd. 
        # 即可以使用tf.gather_nd进行简化,tf.gather_nd可以不用linear_coordinates,直接用元组索引
        # 关于gather_nd和gather可以参考https://blog.csdn.net/oksupersonic/article/details/104559821
        def gather(y_coords, x_coords, name):
            with tf.name_scope("gather-" + name):
                # batch_offsets保证了不同batch的每一个坐标有唯一的linear_coordinates值
                linear_coordinates = batch_offsets + y_coords * width + x_coords
                # 调用了tf.gather,从flattened_grid中取出linear_coordinates坐标对应的值
                gathered_values = tf.gather(flattened_grid, linear_coordinates)
                return tf.reshape(gathered_values, [batch_size, num_queries, channels])

        # 使用gather取出四个角点在原图中的值
        top_left = gather(floors[0], floors[1], "top_left")
        top_right = gather(floors[0], ceils[1], "top_right")
        bottom_left = gather(ceils[0], floors[1], "bottom_left")
        bottom_right = gather(ceils[0], ceils[1], "bottom_right")

        # 有了四个角点的值以及alphas后就可以双线性插值
        with tf.name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top

        return interp

参考资料

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,826评论 6 506
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,968评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,234评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,562评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,611评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,482评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,271评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,166评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,608评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,814评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,926评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,644评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,249评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,866评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,991评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,063评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,871评论 2 354

推荐阅读更多精彩内容