tf.concat的官方解释
tf.concat(
values,
axis,
name='concat'
)
其中:
values应该是一个tensor的list或者tuple。
axis则是我们想要连接的维度。
tf.concat返回的是连接后的tensor。
比如,如果list中的tensor的shape都是(2,2,2),如果此时的axis为2,即连接第三个维度,那么连接后的shape是(2,2,4),具体表现为对应维度的堆砌。例子如下:
t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
tf.concat([t1, t2], axis=-1)
输出结果为
<tf.Tensor 'concat_2:0' shape=(2, 2, 4) dtype=int32>
再sess.run()一下拿出具体tensor为:
[[[ 1, 2, 7, 4],
[ 2, 3, 8, 4]],
[[ 4, 4, 2, 10],
[ 5, 3, 15, 11]]]
可见符合(2,2,4)的shape。