如何获取tensor的维度值

  在TensorFlow的使用过程中,我们常常希望得到一个tensor的维度信息使用,具体的说,也就是现在有了一个tensor值,如何才能得到其shape信息,也就是维度值作为一个整数值使用呢?
  对于一个tensor值,我们很容易利用tensor.get_shape()tf.shape(tensor)来获取其shape。但是,这两种方法返回的shape信息都是Dimension 类型的,并非int32类型的。下面两种方法可以获得tensor shape的具体值。

  • 方法一:利用as_list()方法
      利用tensor.get_shape().as_list() 方法。对于一个2-D的tensor,获得其行列值可以这样做,
 num_rows, num_cols  = X.get_shape().as_list()
  • 方法二:利用Dimension对象的value属性
      利用tensor.get_shape()[0].value 方法。对于一个2-D的tensor,获得其行列值可以这样做,
 num_rows, num_cols  = map(lambda i: i.value, X.get_shape())

参考:https://stackoverflow.com/questions/40666316/how-to-get-tensorflow-tensor-dimensions-shape-as-int-values

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容