STN(Spatial Transformer Networks)网络学习(附代码)

参考资料:
[1]. spatial transformer network 李宏毅教学视频
[2]. 知乎 Spatial Transformer Networks
[3]. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了
[4]. kevinzakka/spatial-transformer-network

代码:

from scipy import ndimage
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2


def gen_grid(o_dims):
    height,width = o_dims

    x = np.linspace(0, 1.0, width, endpoint=False)
    y = np.linspace(0, 1.0, height, endpoint=False)
    # x = np.linspace(0, width, width, endpoint=False)
    # y = np.linspace(0, height, height, endpoint=False)

    x_table,y_table = np.meshgrid(x,y)
    ones_table = np.ones(shape=o_dims)

    grid = np.concatenate((np.expand_dims(y_table,0), np.expand_dims(x_table,0), np.expand_dims(ones_table,0)))
    flatten_grid = np.reshape(grid,(3,-1))

    flatten_grid = tf.convert_to_tensor(flatten_grid, dtype='float32')
    return flatten_grid


def get_pixel_value(imgs, x, y):
    num_batches = x.shape[0]
    height = x.shape[1]
    width = x.shape[2]

    b = tf.range(num_batches)
    b = tf.reshape(b, shape=(num_batches, 1, 1))
    b = tf.tile(b, [1, height, width])

    indices = tf.stack([b,y,x], axis=3)

    return tf.gather_nd(imgs, indices)


def test_get_pixel_value():
    N = 2
    H = 4
    W = 5
    C = 3
    imgs = tf.range(N*C*H*W)
    imgs = tf.reshape(imgs,shape=[N, C, H, W])
    imgs = tf.transpose(imgs, [0,2,3,1])


    # print(imgs.eval()[0,:,:,0])

    x = tf.zeros(shape=(N, H, W), dtype='int32')
    y = tf.zeros(shape=(N, H, W), dtype='int32')

    ret = get_pixel_value(imgs, x, y)

    tf.InteractiveSession()
    print(ret.eval().shape)


def bilinear_interpolation(imgs, x, y):
    num_batches, height, width = imgs.shape[0], imgs.shape[1], imgs.shape[2]

    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    x0 = tf.clip_by_value(x0, 0, width-1)
    x1 = tf.clip_by_value(x1, 0, width-1)
    y0 = tf.clip_by_value(y0, 0, height-1)
    y1 = tf.clip_by_value(y1, 0, height-1)

    a = get_pixel_value(imgs, x0, y1)
    b = get_pixel_value(imgs, x1, y1)
    c = get_pixel_value(imgs, x0, y0)
    d = get_pixel_value(imgs, x1, y0)

    x0 = tf.cast(x0, 'float32')
    x1 = tf.cast(x1, 'float32')
    y0 = tf.cast(y0, 'float32')
    y1 = tf.cast(y1, 'float32')

    wa = (x1 - x) * (y - y0)
    wb = (x - x0) * (y - y0)
    wc = (x1 - x) * (y1 - y)
    wd = (x - x0) * (y1 - y)

    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    inter_img = a * wa + b * wb + c * wc + d * wd

    # 保证图片色彩正常显示
    inter_img = tf.clip_by_value(inter_img, 0, 255)
    inter_img = tf.cast(inter_img, 'uint8')

    return inter_img


def STN(input, thetas, o_shape=None):
    if o_shape is None:
        o_shape = input.shape[1:3]

    num_batches = thetas.shape[0]

    # expand_dims是为了后面的相乘
    in_shape = input.get_shape()[1:3]
    in_shape = tf.expand_dims(in_shape,0)
    in_shape = tf.expand_dims(in_shape,2)

    flatten_grid = gen_grid(o_shape)

    # (B*2*3)*(3,O_H*O_W) = B * 2 * (O_H*O_W)
    locations_in_input = tf.matmul(thetas, flatten_grid)
    locations_in_input = locations_in_input * tf.cast(in_shape, 'float32')
    y, x = tf.split(locations_in_input, 2, axis=1)
    x = tf.reshape(x, shape=(num_batches, o_shape[0], o_shape[1]))
    y = tf.reshape(y, shape=(num_batches, o_shape[0], o_shape[1]))

    output = bilinear_interpolation(input, x, y)

    tf.InteractiveSession()
    output = output.eval()

    return output


if __name__ == '__main__':
    img = cv2.imread('./cat.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    height = img.shape[0]
    width = img.shape[1]

    imgs = np.concatenate((np.expand_dims(img,0), np.expand_dims(img,0)))
    imgs = tf.convert_to_tensor(imgs, dtype='float32')

    thetas = [
        [[1., 0., .5],
         [0., 1., 0.]],
        [[1., 0., 0],
         [0., 1., 0]],
    ]
    thetas = tf.convert_to_tensor(thetas, dtype='float32')

    output = STN(imgs, thetas, (height//2, width//2))

    plt.figure()
    plt.subplot(131)
    plt.imshow(img)
    plt.subplot(132)
    plt.imshow(output[0])
    plt.subplot(133)
    plt.imshow(output[1])

    plt.show()

    # test_get_pixel_value()


image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 220,976评论 6 513
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 94,249评论 3 396
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 167,449评论 0 360
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 59,433评论 1 296
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 68,460评论 6 397
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 52,132评论 1 308
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,721评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,641评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 46,180评论 1 319
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 38,267评论 3 339
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,408评论 1 352
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 36,076评论 5 347
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,767评论 3 332
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,255评论 0 23
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,386评论 1 271
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,764评论 3 375
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,413评论 2 358

推荐阅读更多精彩内容