本文来源于https://community.bigquant.com/t/python%E3%80%81numpy-%E4%B8%8E-axis/127023
这次和大家分享的是 numpy 中的 axis 这个东西。当初学的时候也没太在意,向来都是感觉差不多就直接过去了,没有去深究背后的一些逻辑。前些天被问起的时候一时懵懂,查了下资料后发现还有点意思,于是就打算写这么一篇专栏来分享一下所得
要想学习 axis,首先要知道的就是 axis 的计数方式。我们在使用 numpy 的各种函数——比如说 np.sum——的时候,有一个参数就叫做 axis。那么这个参数的意思是什么呢?最直白地来说的话,就是“最外面的括号代表着 axis=0,依次往里的括号对应的 axis 的计数就依次加 1”
举个例子,现在我们有一个矩阵:
不管画风怎么变,很丑这一点都无法改变啊……
所以相应的运算就是:
对应的代码实现和运行结果如下:
可以看到,貌似出来的结果比我们推导的结果的括号要少一些。这是因为诸如 np.sum 这种函数中有一个参数叫 keepdims,它的默认值是 False,此时它会把多余的括号给删掉。假如我们把它设为 True 的话,就可以得到和推导中一致的结果了:
下面来看一个更“高维”一点的例子:
对应的代码实现和运行结果如下:
以及
可以看到结果和我们推导的确实一样
现在我们知道哪个 axis 对应于数组中的哪些元素了,接下来还需要知道的就是 transpose 这个函数到底在背后干了什么。从纸面上来看,如果一个高维数组 x 的 shape 是 (2, 3, 4),那么 transpose 的作用就是把这个 shape 中各个数的顺序改一改。比如说:
但是 transpose 返回的结果究竟是如何得到的,可能就比较难理解了。幸运的是,这个回答 2非常好地阐明了这背后的原理。为了方便观众老爷们,我在这里就当一个搬运 and 润色工
首先是对这个 shape 的理解。直观地说,shape 中的各个数就是对应 axis 的元素个数。比如说上图中的 x,它画出来会是这个样子的:
字比画还丑呢……
如果我们换一种思路的话,以 axis=0 为例,由于我们现在整个数组里面一共有 24 个数,而 axis=0 只有两个元素,所以可以理解为在 axis=0 这个 axis 上,每隔 24 / 2 = 12 个数就跳一下。比如说上面这个图中就可以看出,两个橙色矩阵对应的数之间差的都是 12
类似的,由于一个橙色矩阵中只有 24 / 2 = 12 个数,所以我们可以理解为在 axis=1 这个 axis 上,每隔 12 / 3 = 4 个数就跳一下。表现在图中,就是同一个橙色矩阵的两个相邻的蓝色向量对应的数之间差的都是 4
再次类似的,由于一个蓝色向量中只有 12 / 3 = 4 个数,我们可以理解为在 axis=2 这个 axis 上,每隔 4 / 4 = 1 个数就跳一下。表现在图中……观众老爷们想必也知道是怎样的了 ( σ’ω’)σ
所以我们现在可以定义一个新的东西,比如说叫做 strides 吧,它记录着每个 axis 上跳过的数。比如说上图对应的三维数组,它的 strides 就是 (12, 4, 1)
那么接下来激动人心的时刻到了:transpose 的本质,其实就是对 strides 中各个数的顺序进行调换。举个例子:
在 transpose(1, 0, 2) 后,相应的 strides 会变成 (4, 12, 1)。而从上图可以看出,transpose 的结果确实满足:
- axis=0 的 axis 上,每隔 4 个数跳一下
- axis=1 的 axis 上,每隔 12 个数跳一下
- axis=2 的 axis 上,每隔 1 个数跳一下
至此,transpose 背后的逻辑就理顺啦!撒花!★,°:.☆( ̄▽ ̄)/$:.°★ 。