TensorFlow高阶操作1-合并与分割
合并
其中一个维度可以不一致,不创建新维度。
tf.concat(tensors,axis)
维度要求一致,创建新维度。
tf.stack(tensors, axis)
分割
在某一维度上,均分成 n 部分,或者任意分割成不同大小的几部分(输入参数 List )。
tf.split(x, axis, num_or_size_splits)
在某一维度上,拆分成大小为 1 的 tensor。
tf.unstack(x, axis)
重点函数
tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
tf.one_hot(y,depth=10)
原地修改数据:w1.assign_sub(lr * grads[0])
5.1 合并与分割
5.1.1 合并
合并是指将多个张量在某个维度上合并为一个张量。
tf.concat()
拼接 在 TensorFlow 中,可以通过 tf.concat(tensors,axis=)
其中 tensors 保存了所有需要合并的张量 List,axis 指定需要合并的维度。
堆叠 tf.concat()
直接在现有维度上面合并数据,并不会创建新的维度。
tf.stack(tensors, axis)
使用 tf.stack(tensors, axis)
可以合并多个张量 tensors,其中 axis 指定插入新维度的位置,axis 的用法与 tf.expand_dims()
的一致,当 axis ≥ 0时,在 axis 之前插入;当axis < 0时,在 axis 之后插入新维度。
5.1.2 分割
合并操作的逆过程就是分割,将一个张量分拆为多个张量。
tf.split(x, axis, num_or_size_splits)
通过 tf.split(x, axis, num_or_size_splits)
可以完成张量的分割操作,其中
❑ x:待分割张量。
❑ axis:分割的维度索引号
❑ num_or_size_splits:切割方案。
当 num_or_size_splits 为单个数值时,如 10,表示切割为 10 份;
当 num_or_size_splits 为 List 时,每个元素表示每份的长度,如[2,4,2,2]表示切割为 4 份,每份的长度分别为 2,4,2,2 现在我们将总成绩册张量切割为 10 份:
tf.unstack(x, axis)
特别地,如果希望在某个维度上全部按长度为 1 的方式分割,还可以直接使用 tf.unstack(x, axis)
。这种方式是 tf.split
的一种特殊情况,切割长度固定为 1,只需要指定切割维度即可。
可以看到,通过 tf.unstack
切割后,shape 变为[35,8],即班级维度消失了,这也是与 tf.split
区别之处。