import tensorflow as tf
input = [[[1,1,1],[2,2,2]],
[[3,3,3],[4,4,4]],
[[5,5,5],[6,6,6]]]
x = tf.slice(input,[0,0,0],[1,2,3])
sess = tf.InteractiveSession()
print(sess.run(x))
>>> [[[1 1 1]
[[2 2 2]]]
首先来看tf.slice里的几个参数,
- input代表输入的tensor,
- [0,0,0]代表begin,起始值
- [1,2,3]代表切的大小size。
要明白tf.slice是一个切片函数,那应该怎么切呢?
注意到tf.slice从begin开始切,
例如上面就是从[0,0,0],也就是第0行第0列第0维开始切,
然后size[1,2,3]表示切出1行2列3维的大小。
所以切出来了:
[[[1 1 1]
[[2 2 2]]]
倘若是下面的代码:
import tensorflow as tf
input = [[[1,1,1],[2,2,2]],
[[3,3,3],[4,4,4]],
[[5,5,5],[6,6,6]]]
x = tf.slice(input,[1,0,0],[2,1,3])
sess = tf.InteractiveSession()
print(sess.run(x))
>>> [[[3 3 3]]
[[5 5 5]]]
tf.slice会从第一行第0列第0维开始切,
并切出2行1列3维的大小。