tf.concat

tf.concat的官方解释

tf.concat(
    values,
    axis,
    name='concat'
)

其中:
values应该是一个tensor的list或者tuple。
axis则是我们想要连接的维度。
tf.concat返回的是连接后的tensor。
比如,如果list中的tensor的shape都是(2,2,2),如果此时的axis为2,即连接第三个维度,那么连接后的shape是(2,2,4),具体表现为对应维度的堆砌。例子如下:

t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
tf.concat([t1, t2], axis=-1)

输出结果为

<tf.Tensor 'concat_2:0' shape=(2, 2, 4) dtype=int32>

再sess.run()一下拿出具体tensor为:

[[[ 1,  2,  7,  4],
  [ 2,  3,  8,  4]],

 [[ 4,  4,  2, 10],
  [ 5,  3, 15, 11]]]

可见符合(2,2,4)的shape。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容