Numpy中stack()函数的理解

np.stack(array,axis,out=None),函数原型。
其中最重要是的这个axis怎么理解的。
举例说明:
arrays = [np.random.randn(3, 4) for _ in range(10)]
会生成一个 10 *( 3 * 4 )的矩阵列表。十个矩阵,每个矩阵是(3 * 4)大小。
首先说明一下axis的映射。在这个例子中,10->axis=0 ,3->axis=1

>>>np.stack(arrays,axis=1)
array([[[-0.42233185, -0.13270788, -0.47724388, -1.48881134],
        [ 0.2284937 , -0.30139984,  0.15633374,  0.04428078],
        [ 2.0193316 ,  0.1098357 , -0.32044757, -1.24868601],
        [ 0.9859909 , -0.42781564,  0.57524126,  0.58154297],
        [-0.13059124,  2.15207301,  0.36007904, -0.71344781],
        [-1.68010975,  1.25350273,  0.11073033, -0.28531604],
        [ 0.60021096, -0.18691447,  1.49261775,  0.47628294],
        [-0.18268831, -0.32463742, -0.89726008,  0.19245843],
        [-0.27384598,  0.56068318,  1.57096001,  1.11169077],
        [ 0.27035354, -0.54258351, -0.69891459,  1.84282464]],

       [[ 1.44874184, -1.6645958 ,  1.14128754, -2.26945958],
        [ 0.28754711, -1.59591539, -0.92798468, -0.05021877],
        [ 1.09050239, -0.86881164, -0.59820951, -0.39628311],
        [-1.09540304, -0.33438594, -0.71075442, -1.48691938],
        [ 0.7155825 ,  0.24710929, -0.65019501, -1.24407802],
        [-0.11059045, -1.57851632,  1.34142995, -0.44438407],
        [ 0.9258746 ,  1.62418684, -0.25380587, -1.1423341 ],
        [-1.76337136,  0.55031978,  1.25834475,  0.53257722],
        [ 0.05755626,  1.16156935, -1.84999546,  1.57175386],
        [ 0.48836813, -0.21907532, -0.78655392,  0.51705705]],

       [[-0.24451876, -0.09881284,  1.17611246,  0.81276037],
        [ 0.89510841,  0.9106155 ,  0.4923826 , -0.07364133],
        [-0.0670429 ,  0.72968107, -1.31473173, -0.31313322],
        [ 0.62314248,  0.97792175,  0.0840199 , -0.38035465],
        [ 0.70222737,  0.53761069,  0.50546661, -2.02777762],
        [-0.85454667, -0.76359383, -0.25280887, -0.94252057],
        [ 0.38294622, -0.38729216,  0.03757319, -0.48955485],
        [ 1.52718003,  1.14814816,  1.33147053, -0.50341043],
        [-0.38600834,  0.19781327, -0.35596671,  1.59331045],
        [-0.07073478, -1.4710414 ,  1.95192939, -0.83379204]]])
>>> np.stack(arrays, axis=1).shape
(3, 10, 4)

为什么会变成 3 * 10 * 4了呢。首先我们的函数是对 10 * 3 * 4 中的3,也就是axis=1,进行了堆叠。
那么这个 axis = 1,在十个矩阵中代表什么呢?代表 每个矩阵中的一行。所以这个函数的操作就是,把10矩阵中的第i行拿出来拼成一个矩阵。因为一个矩阵有三行,所以堆叠后的矩阵就是,3 * 10 * 4,这个10 * 4,就是原来矩阵中,十个矩阵的第一行,第二行,第三行,拼接而成的。所以是 3 * 10 * 4。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • import numpy as np 创建ndarray data1 = [6,7.5, 8, 0, 1]arr1...
    陆文斌阅读 4,023评论 0 1
  • 基础篇NumPy的主要对象是同种元素的多维数组。这是一个所有的元素都是一种类型、通过一个正整数元组索引的元素表格(...
    oyan99阅读 10,540评论 0 18
  • Numpy是Python的第第三方模块,用于科学计算。 1.属性 列表转化为数组: 2. array的创建 指定数...
    井底蛙蛙呱呱呱阅读 8,655评论 0 10
  • Numpy的组成与功能 Numpy(Numeric Python)可以被理解为一个用python实现的科学计算包,...
    不做大哥好多年阅读 9,870评论 0 10
  • 直立姿势虽然是人类有别于其它动物的一个显著标志。 但人类直立以后,由于地心引力的作用,造成了三个弊病: 一是血液的...
    四月七月阅读 2,206评论 0 0