关于numpy mean函数的axis参数

理解多维矩阵的"求和"、"平均"操作确实太恶心了,numpy提供的函数里还有一堆参数,搞得晕头转向的,这里做个笔记,提醒一下自己, 下面是例程

import numpy as np
X = np.array([[1, 2], [4, 5], [7, 8]])
print np.mean(X, axis=0, keepdims=True)
print np.mean(X, axis=1, keepdims=True)

结果是分别是

                 [[ 1.5]
 [[ 4.  5.]]      [ 4.5]    
                  [ 7.5]]

我个人比较raw的认识就是,axis=0,那么输出矩阵是1行,求每一列的平均(按照每一行去求平均);axis=1,输出矩阵是1列,求每一行的平均(按照每一列去求平均)。还可以这么理解,axis是几,那就表明哪一维度被压缩成1。

再举个更复杂点的例子,比如我们输入为batch = [128, 28, 28],可以理解为batch=128,图片大小为28×28像素,我们相求这128个图片的均值,应该这么写

m = np.mean(batch, axis=0)

输出结果m的shape为(28,28),就是这128个图片在每一个像素点平均值。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容