Python和tensorflow编程中经常见这三种shape的用法,容易混淆,特写一篇文章来总结以备遗忘。这三个函数都是用来获取维度信息的,但用法和使用对象各有不同,下面进行一一介绍。
(1) np.shape()
这个函数是numpy中的一个函数(函数要加括号!!!),其功能是获取括号内数据的长度或维度信息,其使用对象既可以是一个数,也可以是数组或矩阵。如下例所示:
In [1]: import numpy as np
In [2]: np.shape(0)
Out[2]: ()
In [3]: np.shape([0])
Out[3]: (1,)
In [4]: np.shape([1, 2, 3])
Out[4]: (3,)
In [5]: np.shape([[1], [2]])
Out[5]: (2, 1)
In [12]: a = np.zeros([2,3])
In [13]: a
Out[13]:
array([[0., 0., 0.],
[0., 0., 0.]])
In [14]: np.shape(a)
Out[14]: (2, 3)
In [15]: np.shape(a)[1]
Out[15]: 3
(2) array.shape
array.shape是numpy中ndarray数据类型的一个属性。我们先来理解一下几个问题:
1.什么是ndarray数据类型?
ndarray是numpy库中的一种数据类型,凡是以np.array()定义的数据都是ndarray类型,就跟pytorch中的张量tensor类似。
2.什么是属性?
属性就是python类中初始化的时候,self.xx代表的变量,是该类特有的信息。比如我们定义一个学生类:
class Student:
def __init__(self, height, weight, number):
self.height = height # 身高
self.weight = weight # 体重
self.number = number # 学号
Student类中的self.height,self.weight ,self.number就是属性。如果ZhangSan是一个Student类,我们想要获知张三的身高体重学号等信息,就采用ZhangSan.height,ZhanSan.weight,ZhangSan.number即可获得,并且可以看到这里的属性是不带括号的。
在Python中,一切数据对象都是一个类,包括ndarray类型。shape就是ndarray数据的一个属性,shape表示这个ndarray实例的形状,即各维度的数值。dtype也是其属性之一,即datatype得缩写,表示这个ndarray实例的数据类型。
需要注意的就是注意属性不加括号!!!使用方法如下:
In [16]: b = np.array([[1,2,3],[4,5,6],[7,8,9]])
In [17]: b
Out[17]:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
In [18]: b.shape
Out[18]: (3, 3)
In [19]: b.shape[0]
Out[19]: 3
In [20]: c = [1, 2, 3] # c不是ndarray类型
In [21]: c.shape
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-21-d6049491b182> in <module>
----> 1 c.shape
AttributeError: 'list' object has no attribute 'shape'
从类的角度看,把ndarray和numpy都当做Python的一个类,ndarray.shape表示ndarray的属性,自然可知,np.shape()其实就是numpy类的方法。
在numpy中,一般可直接用于ndarray类型数据上的方法也有与之对应的numpy函数可执行相同操作,如:
In [52]: a = np.arange(5)
In [53]: a
Out[53]: array([0, 1, 2, 3, 4])
In [54]: np.sum(a)
Out[54]: 10
In [55]: a.sum()
Out[55]: 10
记住,函数或方法很像,都要带括号!!!属性不带括号!!!
In [56]: a = np.random.randn(5,3)
In [57]: a
Out[57]:
array([[-0.47169257, -1.33625595, 1.09450799],
[ 0.68097098, -0.77349608, -0.13462524],
[ 1.01122524, -0.72573122, -2.80145914],
[ 0.32187105, 0.66012558, -0.80316889],
[-0.79434656, 0.33565231, -0.51083857]])
In [58]: a.shape #获取矩阵大小
Out[58]: (5, 3)
In [59]: a.ndim #获取矩阵维度
Out[59]: 2
In [60]: a.dtype #获取矩阵数据类型
Out[60]: dtype('float64')
(3) Tensor.get_shape().as_list()
这是tensorflow中常用于获取tensor维度信息的函数,注意该函数只能用于tensor对象。Tensor.get_shape()本身获取tensor的维度信息并以元组的形式返回,由于元组内容不可更改,故该函数常常跟.as_list()连用,返回一个tensor维度信息的列表,以供后续操作使用。