Overview
Optical flow and warping.
Reference: https://blog.csdn.net/qq_33757398/article/details/106332814
Function definition
- The function is defined in the source as:
def dense_image_warp(
    image: types.TensorLike, flow: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor
The official documentation gives the computation rule:
Apply a non-linear warp to the image, where the warp is specified by a dense flow field of offset vectors that define the correspondences of pixel values in the output image back to locations in the source image. Specifically, the pixel value at output[b, j, i, c] is images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
- It takes three arguments: image is the image tensor, flow is the corresponding optical flow tensor, and name is used for tf.name_scope; if it is None, tf.name_scope defaults to dense_image_warp. The comments in the source are:
  image: 4-D float Tensor with shape [batch, height, width, channels].
  flow: A 4-D float Tensor with shape [batch, height, width, 2].
  name: A name for the operation (optional).
The image and the flow do not have to be tf tensors of the same dtype:
  Note that image and flow can be of type tf.half, tf.float32, or tf.float64, and do not necessarily have to be the same type.
- The return value is the warped image, with the same shape as the input image (a minimal usage sketch follows below).
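As a quick illustration of the interface, here is a standalone sketch, assuming TensorFlow Addons is installed and imported as tfa; the tensor values are made up:
import tensorflow as tf
import tensorflow_addons as tfa  # assumption: tensorflow-addons is installed

# A 1x4x4x1 test image and a flow that moves every pixel by one row and one column.
image = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
flow = tf.ones([1, 4, 4, 2], dtype=tf.float32)

# Per the docstring: output[b, j, i, c] = image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]
warped = tfa.image.dense_image_warp(image, flow)
print(warped.shape)                # (1, 4, 4, 1)
print(warped[0, 1, 1, 0].numpy())  # 0.0 == image[0, 0, 0, 0]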
The optical flow warp workflow
Before diving into the source, let's walk through the dense warp workflow. The process is fairly simple and can be broken into the following steps:
- Generate a meshgrid
- Combine the optical flow flow with the meshgrid to get query_points
- Use the resulting query_points to look up the corresponding values in the image image
Next, let's look at how each step works:
Generating the meshgrid
A meshgrid is an array of index coordinates. In TensorFlow it can be generated directly with the meshgrid function.
x = [1, 2, 3]
y = [4, 5, 6]
X, Y = tf.meshgrid(x, y)
# X = [[1, 2, 3],
#      [1, 2, 3],
#      [1, 2, 3]]
# Y = [[4, 4, 4],
#      [5, 5, 5],
#      [6, 6, 6]]
# from https://www.tensorflow.org/api_docs/python/tf/meshgrid
Why is such a function useful? If we stack X and Y into a grid:
grid = tf.stack([X, Y], axis=2)
then grid[0, 0] = [1, 4] and grid[2, 2] = [3, 6]: we have built a mapping from each grid position to a coordinate pair, which makes the warp much more convenient. To simplify the optical flow computation, we generate a meshgrid with the same height and width as the image and the flow. Note that stacking [grid_y, grid_x] puts the coordinates in (row, column) order, matching the layout of the flow's last dimension:
grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)
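A tiny check of that ordering (a standalone sketch with a made-up 2×3 image size):
import tensorflow as tf

height, width = 2, 3  # made-up size for illustration
grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), tf.float32)
print(stacked_grid[1, 2].numpy())  # [1. 2.] -> the (row, col) of that pixel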
Combining the flow with the meshgrid to get query_points
The stacked_grid produced in the previous step has shape [height, width, 2] and indexes the coordinates of the input image, e.g. stacked_grid[0, 0] = [0, 0] and stacked_grid[1, 2] = [1, 2]. Now we subtract the optical flow from it (or add it, depending on whether the flow is forward or backward); by the definition of optical flow, this gives a new grid, query_points_on_grid. We can read this grid as telling us, for each coordinate (i, j) of the warped result, where in the source image that pixel should be taken from. In other words, the pixel value at position (0, 0) of the result is fetched from position query_points_on_grid[0, 0] of the source image.
In principle this already gives us the warped result, but the entries of query_points_on_grid[0, 0] can be fractional, so we need interpolation to produce a value from the integer-indexed source pixels.
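A small sketch of this step (the flow values are made up; the names follow the source):
import tensorflow as tf

height, width = 2, 3
grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), tf.float32)

# A made-up flow: every pixel moved by 0.5 rows and 1.0 columns.
flow = tf.broadcast_to(tf.constant([0.5, 1.0]), [height, width, 2])

query_points_on_grid = stacked_grid - flow
print(query_points_on_grid[1, 2].numpy())  # [0.5 1. ] -> fractional, needs interpolation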
Looking up the values in image at the query_points
In PyTorch this operation is implemented as torch.nn.functional.grid_sample (https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). In the TensorFlow Addons source, the file tensorflow_addons/image/dense_image_warp.py contains a function interpolate_bilinear that implements similar functionality.
The implementation idea used in the warp is simple: compute two quantities, then perform bilinear interpolation.
First, for every coordinate in query_points_on_grid, we need the values of the four surrounding grid points in the source image. Taking the floor of query_points_on_grid directly gives, for each point, the coordinate of the top-left corner of its enclosing cell, floor; indexing the source image there gives the top-left corner value top_left. Given the top-left coordinate, the bottom-right one is simply both coordinates plus one, so the source-image values of all four corners can be found. Four arrays top_left, top_right, bottom_left and bottom_right store, for each point of query_points_on_grid, the values of its four corners in the source image.
The other quantity to compute, alphas, measures how far each point is from its top-left corner, which determines the interpolation weights of the four corners. It is obtained by subtracting each point's top-left coordinate floor from query_points_on_grid.
Once these are computed, the following formulas complete the bilinear interpolation:
interp_top = alphas[1] * (top_right - top_left) + top_left
interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
interp = alphas[0] * (interp_bottom - interp_top) + interp_top
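A quick numeric check of these formulas with made-up corner values (alphas[0] is the vertical offset from the top edge, alphas[1] the horizontal offset from the left edge):
import tensorflow as tf

# Made-up corner values of a single-channel cell around one query point.
top_left, top_right = tf.constant(10.0), tf.constant(20.0)
bottom_left, bottom_right = tf.constant(30.0), tf.constant(40.0)
alphas = [tf.constant(0.25), tf.constant(0.5)]  # [vertical, horizontal] offsets

interp_top = alphas[1] * (top_right - top_left) + top_left              # 15.0
interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left  # 35.0
interp = alphas[0] * (interp_bottom - interp_top) + interp_top
print(interp.numpy())  # 20.0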
Source implementation of dense_image_warp
The source of tensorflow_addons.image.dense_image_warp is shown below; for details see https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/dense_image_warp.py.
As described above, the method consists of three main steps, which are annotated in the code below:
def dense_image_warp(
    image: types.TensorLike, flow: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    with tf.name_scope(name or "dense_image_warp"):
        # Convert to tensors and get the dimensions
        image = tf.convert_to_tensor(image)
        flow = tf.convert_to_tensor(flow)
        batch_size, height, width, channels = (
            _get_dim(image, 0),
            _get_dim(image, 1),
            _get_dim(image, 2),
            _get_dim(image, 3),
        )

        ## 1. Generate the meshgrid
        grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
        stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)
        # Add a batch dimension
        batched_grid = tf.expand_dims(stacked_grid, axis=0)

        ## 2. Combine the flow with the meshgrid to get query_points
        query_points_on_grid = batched_grid - flow
        # Flatten the spatial dimensions
        query_points_flattened = tf.reshape(
            query_points_on_grid, [batch_size, height * width, 2]
        )

        ## 3. Look up the corresponding values in image at the query_points
        # Bilinear interpolation
        interpolated = interpolate_bilinear(image, query_points_flattened)
        interpolated = tf.reshape(interpolated, [batch_size, height, width, channels])
        return interpolated
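An easy sanity check of these three steps is the zero-flow case: the query points then coincide with the meshgrid, so the warp should return the input unchanged. A standalone sketch, again assuming TensorFlow Addons is available as tfa:
import tensorflow as tf
import tensorflow_addons as tfa  # assumption: tensorflow-addons is installed

image = tf.random.uniform([1, 5, 5, 3])
zero_flow = tf.zeros([1, 5, 5, 2])

warped = tfa.image.dense_image_warp(image, zero_flow)
print(tf.reduce_max(tf.abs(warped - image)).numpy())  # ~0.0 (up to float rounding)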
The core of interpolate_bilinear lives in _interpolate_bilinear_impl:
def _interpolate_bilinear_impl(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str,
    name: Optional[str],
) -> tf.Tensor:
    with tf.name_scope(name or "interpolate_bilinear"):
        # Initialize some variables
        grid_shape = tf.shape(grid)
        query_shape = tf.shape(query_points)

        batch_size, height, width, channels = (
            grid_shape[0],
            grid_shape[1],
            grid_shape[2],
            grid_shape[3],
        )

        num_queries = query_shape[1]

        query_type = query_points.dtype
        grid_type = grid.dtype

        alphas = []  # the alphas described above
        floors = []  # the floors described above
        ceils = []  # the bottom-right corners, equal to floor + 1
        # Is the indexing (x, y) or (y, x)?
        index_order = [0, 1] if indexing == "ij" else [1, 0]
        # Split x and y apart
        unstacked_query_points = tf.unstack(query_points, axis=2, num=2)

        for i, dim in enumerate(index_order):
            with tf.name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]

                size_in_indexing_dimension = grid_shape[i + 1]

                ## Compute the top-left and bottom-right corner coordinates for each point
                max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
                min_floor = tf.constant(0.0, dtype=query_type)
                # Clamp floor to the valid range
                floor = tf.math.minimum(
                    tf.math.maximum(min_floor, tf.math.floor(queries)), max_floor
                )
                int_floor = tf.cast(floor, tf.dtypes.int32)
                floors.append(int_floor)
                ceil = int_floor + 1
                ceils.append(ceil)

                ## Compute alphas, i.e. query_points minus floor
                alpha = tf.cast(queries - floor, grid_type)
                min_alpha = tf.constant(0.0, dtype=grid_type)
                max_alpha = tf.constant(1.0, dtype=grid_type)
                # Likewise, clamp the value to the range [0, 1]
                alpha = tf.math.minimum(tf.math.maximum(min_alpha, alpha), max_alpha)

                # Expand the dimension for broadcasting
                alpha = tf.expand_dims(alpha, 2)
                alphas.append(alpha)

        flattened_grid = tf.reshape(grid, [batch_size * height * width, channels])
        batch_offsets = tf.reshape(
            tf.range(batch_size) * height * width, [batch_size, 1]
        )

        # gather fetches the values at (y_coords, x_coords) from flattened_grid
        # Comment from the original source: It's possible this code would be made simpler by using tf.gather_nd.
        # That is, tf.gather_nd could index directly with (batch, y, x) tuples, without linear_coordinates
        # On gather_nd vs gather, see https://blog.csdn.net/oksupersonic/article/details/104559821
        def gather(y_coords, x_coords, name):
            with tf.name_scope("gather-" + name):
                # batch_offsets guarantees a unique linear_coordinates value for every coordinate in every batch
                linear_coordinates = batch_offsets + y_coords * width + x_coords
                # tf.gather takes the values of flattened_grid at linear_coordinates
                gathered_values = tf.gather(flattened_grid, linear_coordinates)
                return tf.reshape(gathered_values, [batch_size, num_queries, channels])

        # Use gather to fetch the source-image values of the four corners
        top_left = gather(floors[0], floors[1], "top_left")
        top_right = gather(floors[0], ceils[1], "top_right")
        bottom_left = gather(ceils[0], floors[1], "bottom_left")
        bottom_right = gather(ceils[0], ceils[1], "bottom_right")

        # With the four corner values and alphas we can interpolate bilinearly
        with tf.name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top

        return interp
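To see concretely why the linear-coordinate trick inside gather works, here is a small standalone sketch (the shapes and coordinates are made up) that checks it against the tf.gather_nd alternative mentioned in the source comment:
import tensorflow as tf

batch_size, height, width, channels = 2, 3, 4, 1
grid = tf.reshape(tf.range(batch_size * height * width * channels, dtype=tf.float32),
                  [batch_size, height, width, channels])

# Coordinates to fetch: one (y, x) pair per batch element.
y_coords = tf.constant([[1], [2]])
x_coords = tf.constant([[3], [0]])

# Linear-coordinate trick used in the source.
flattened_grid = tf.reshape(grid, [batch_size * height * width, channels])
batch_offsets = tf.reshape(tf.range(batch_size) * height * width, [batch_size, 1])
linear_coordinates = batch_offsets + y_coords * width + x_coords
via_gather = tf.gather(flattened_grid, linear_coordinates)

# Equivalent lookup with tf.gather_nd, indexing by (batch, y, x) tuples.
batch_idx = tf.reshape(tf.range(batch_size), [batch_size, 1]) * tf.ones_like(y_coords)
indices = tf.stack([batch_idx, y_coords, x_coords], axis=-1)
via_gather_nd = tf.gather_nd(grid, indices)

print(tf.reduce_all(via_gather == via_gather_nd).numpy())  # True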
References
- https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/dense_image_warp.py
- https://www.tensorflow.org/api_docs/python/tf/gather
- https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- 一文搞懂光流 光流的生成,可视化以及映射(warp). https://blog.csdn.net/qq_33757398/article/details/106332814
- 深度理解tf.gather和tf.gather_nd的用法. https://blog.csdn.net/oksupersonic/article/details/104559821
- grid_sample()函数及双线性采样. https://zhuanlan.zhihu.com/p/112030273