TensorFlow基本操作3-索引与切片
重点函数
索引与切片:
x[idx0][idx1][idx2]
x[𝑠𝑡𝑎𝑟𝑡: 𝑒𝑛𝑑]
x[𝑠𝑡𝑎𝑟𝑡: 𝑒𝑛𝑑: 𝑠𝑡𝑒𝑝]
x[A, ... , B, 𝑠𝑡𝑒𝑝]
选择索引:
tf.gather()
:在某一维度按索引取值
tf.gather(a ,axis= ,indices= )
tf.gather(a,axis=2,indices=[2,1,7,0,6])
tf.gather_nd()
:在各个维度按不同的索引取值
tf.gather_nd(a,[[0,1,2]])
推荐 indices format :
[[0], [1], ...]
[[0,0], [1,1], ...]
[0,0,0], [1,1,1], ...]
tf.boolean_mask()
: 按True,False决定是否取值
tf.boolean_mask(a, mask=[True,False] ,axis=)
6 索引与切片
6.1 索引
在 TensorFlow 中,支持基本的[𝑖][𝑗]…
标准索引方式,也支持通过逗号分隔索引号的索引方式。
考虑输入 X 为 4 张 32x32 大小的彩色图片(为了方便演示,大部分张量都使用随机分布模拟产生,后文同),shape 为[4,32,32,3]。
basic indexing
//首先创建张量
x = tf.random.normal([4,32,32,3])
//取第 1 张图片的数据:
x[0]
//取第 1 张图片,第 2 行,第 3 列的像素:
x[0][1][2]
//取第 3 张图片,第 2 行,第 1 列的像素,B 通道(第 2 个通道)颜色强度值:
x[2][1][0][1]
numpy-style indexing
当张量的维度数较高时,使用[𝑖][𝑗]. . .[𝑘]的方式书写不方便,可以采用[𝑖,𝑗, … , 𝑘]的方式索引,它们是等价的。
//如取第 2 张图片,第 10 行,第 3 列:
x[1,9,2]
6.2 切片
通过𝑠𝑡𝑎𝑟𝑡: 𝑒𝑛𝑑: 𝑠𝑡𝑒𝑝
切片方式可以方便地提取一段数据。
[start, end)
默认 start = 0; end = -1;
其中 start 为开始读取位置的索引,end 为结束读取位置的索引(不包含 end 位),step 为读取步长。
//以 shape 为[4,32,32,3]的图片张量为例:
//读取第 2,3 张图片:
x[1:3]
切片总是返回一个 vector
索引可能返回一个 scalar
start: end: step
切片方式有很多简写方式,其中 start、end、step 3 个参数可以根据需要选择性地省略,全部省略时即::,表示从最开始读取到最末尾,步长为 1,即不跳过任何元素。
//如 x[0,::] 表示读取第 1 张图片的所有行,其中::表示在行维度上读取所有行,它等于x[0]的写法:
x[0] == x[0,::]
//为了更加简洁,::可以简写为单个冒号:,如
x[:,0:28:2,0:28:2,:]
//表示取所有图片,隔行采样,隔列采样,所有通道信息,相当于在图片的高宽各缩放至原来的 50%。
我们来总结start: end: step切片的简写方式,其中从第一个元素读取时 start 可以省略, 即 start=0 是可以省略,取到最后一个元素时 end 可以省略,步长为 1 时 step 可以省略,简写方式总结如表格 4.1:
切片方式 | 意义 |
---|---|
start:end:step | 从 start 开始读取到 end(不包含 end),步长为 step |
start:end | 从 start 开始读取到 end(不包含 end),步长为 1 |
start: | 从 start 开始读取完后续所有元素,步长为 1 |
start::step | 从 start 开始读取完后续所有元素,步长为 step |
:end:step | 从 0 开始读取到 end(不包含 end),步长为 step |
:end | 从 0 开始读取到 end(不包含 end),步长为 1 |
::step | 每隔 step-1 个元素采样所有 |
:: | 读取所有元素 |
ps: step 可以为负数。
考虑最特殊的一种例子,step = −1 时(逆序),start: end: −1表示从 start 开始,逆序读取至 end 结束(不包含 end),索引号𝑒𝑛𝑑 ≤ 𝑠𝑡𝑎𝑟𝑡。考虑一 0~9 简单序列,逆序取到第 1 号元素,不包含第 1 号:
x = tf.range(9)
x[8:0:-1]
output: <tf.Tensor: id=466, shape=(8,), dtype=int32, numpy=array([8, 7, 6, 5, 4, 3,
2, 1])>
//逆序取全部元素:
x[::-1]
//逆序间隔采样:
x[::-2]
当张量的维度数量较多时,不需要采样的维度一般用单冒号:表示采样所有元素,此时有可能出现大量的:出现。继续考虑[4,32,32,3]的图片张量,当需要读取 G 通道上的数据时,前面所有维度全部提取,此时需要写为: x[:,:,:,1]
为了避免出现像 𝑥[: , : , : ,1]
这样出现过多冒号的情况,可以使用 ⋯
符号表示取多个维度上所有的数据,其中维度的数量需根据规则自动推断:当切片方式出现 ⋯
符号时, ⋯
符号左边的维度将自动对齐到最左边,⋯符号右边的维度将自动对齐到最右边,此时系统再自动推断⋯符号代表的维度数量,它的切片方式总结如表格所示:
切片方式 | 意义 |
---|---|
a,⋯,b | a 维度对齐到最左边,b 维度对齐到最右边,中间的维度全部读取,其他维度按 a/b 的方式读取 |
a,⋯ | a 维度对齐到最左边,a 维度后的所有维度全部读取,a 维度按 a 方式 |
读取。这种情况等同于 a 索引/切片方式 | |
⋯,b | b 维度对齐到最右边,b 之前的所有维度全部读取,b 维度按 b 方式读取 |
⋯ | 读取张量所有数据 |
考虑如下例子:
//读取第 1-2 张图片的 G/B 通道数据:
x[0:2,...,1:]
//读取最后 2 张图片:
x[2:,...]
//读取 R/G 通道数据:
x[...,:2]
6.3 选择索引
tf.gather()
:在某一维度按索引取值
tf.gather(a ,axis= ,indices= )
a=tf.random.normal([4,35,8])
a[2:4].shape
//TensorShape([2, 35, 8])
tf.gather(a,axis=2,indices=[2,1,7,0,6]).shape
//TensorShape([4, 35, 5])
tf.gather_nd()
:在各个维度按不同的索引取值
In [9]: a.shape
Out[9]: TensorShape([4, 35, 8])
In [10]: tf.gather_nd(a,[0]).shape
Out[10]: TensorShape([35, 8])
In [11]: tf.gather_nd(a,[0,1]).shape
Out[11]: TensorShape([8])
In [12]: tf.gather_nd(a,[0,1,2]).shape
Out[12]: TensorShape([])
In [13]: tf.gather_nd(a,[[0,1,2]]).shape
Out[13]: TensorShape([1])
In [14]: tf.gather_nd(a,[[0,0],[1,1]]).shape
Out[14]: TensorShape([2, 8])
In [15]: tf.gather_nd(a,[[0,0],[1,1],[2,2]]).shape
Out[15]: TensorShape([3, 8])
In [16]: tf.gather_nd(a,[[0,0,0],[1,1,1],[2,2,2]]).shape
Out[16]: TensorShape([3])
推荐 indices format :
[[0], [1], ...]
[[0,0], [1,1], ...]
[0,0,0], [1,1,1], ...]
tf.boolean_mask()
: 按True,False决定是否取值
tf.boolean_mask(a,mask=[True,False],axis=)
In [20]: a=tf.random.normal([4,28,28,3])
In [21]: a.shape
Out[21]: TensorShape([4, 28, 28, 3])
In [23]: tf.boolean_mask(a,mask=[True,True,False,False]).shape
Out[23]: TensorShape([2, 28, 28, 3])
In [24]: tf.boolean_mask(a,mask=[True,True,False],axis=3).shape
Out[24]: TensorShape([4, 28, 28, 2])
In [25]: a=tf.ones([2,3,4])
In [26]: tf.boolean_mask(a,mask=[[True,True,False],[True,True,False]])
Out[26]:
<tf.Tensor: id=136, shape=(4, 4), dtype=float32, numpy=
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)>