Numpy的广播机制

广播

广播(broadcasting)指的是不同形状的数组之间的算术运算的执行方式。它是一种非常强大的功能,但也容易令人误解,即使是经验丰富的老手也是如此。将标量值跟数组合并时就会发生最简单的广播:

import numpy as np
arr = np.arange(5)
arr
array([0, 1, 2, 3, 4])
arr * 4
array([ 0,  4,  8, 12, 16])

这里我们说:在这个乘法运算中,标量值4被广播到了其他所有的元素上。

看一个例子,我们可以通过减去列平均值的方式对数组的每一列进行距平化处理。这个问题解决起来非常简单:

arr = np.random.randn(4, 3)
arr
array([[ 1.14072113e+00, -3.75330408e-01,  1.07997253e+00],
       [ 2.92296713e-01,  5.19115583e-01,  1.29876898e+00],
       [-1.12729644e+00,  1.30713095e+00, -4.75432622e-01],
       [-2.30075456e-01,  2.16281589e+00,  1.92077343e-03]])
arr.mean(0)
array([0.01891149, 0.903433  , 0.47630741])
demeaned = arr - arr.mean(0)
demeaned
array([[ 1.12180965, -1.27876341,  0.60366511],
       [ 0.27338522, -0.38431742,  0.82246156],
       [-1.14620793,  0.40369794, -0.95174004],
       [-0.24898694,  1.25938289, -0.47438664]])
demeaned.mean(0)
array([-6.93889390e-18,  0.00000000e+00, -2.77555756e-17])

下图形象地展示了该过程。用广播的方式对行进行距平化处理会稍微麻烦一些。幸运的是,只要遵循一定的规则,低维度的值是可以被广播到数组的任意维度的(比如对二维数组各列减去行平均值)。


广播原则

  • 让所有输入数组都向其中形状最长的数组看齐,形状中不足的部分都通过在前面加 1 补齐。
  • 输出数组的形状是输入数组形状的各个维度上的最大值。
  • 如果输入数组的某个维度和输出数组的对应维度的长度相同或者其长度为 1 时,这个数组能够用来计算,否则出错。
  • 当输入数组的某个维度的长度为 1 时,沿着此维度运算时都用此维度上的第一组值。

简单理解:对两个数组,分别比较他们的每一个维度(若其中一个数组没有当前维度则忽略),满足:

  • 数组拥有相同形状。
  • 当前维度的值相等。
  • 当前维度的值有一个是 1

画张图并想想广播的原则。再来看一下最后那个例子,假设你希望对各行减去那个平均值。由于arr.mean(0)的长度为3,所以它可以在0轴向上进行广播:因为arr的后缘维度是3,所以它们是兼容的。根据该原则,要在1轴向上做减法(即各行减去行平均值),较小的那个数组的形状必须是(4,1):

arr
array([[ 1.14072113e+00, -3.75330408e-01,  1.07997253e+00],
       [ 2.92296713e-01,  5.19115583e-01,  1.29876898e+00],
       [-1.12729644e+00,  1.30713095e+00, -4.75432622e-01],
       [-2.30075456e-01,  2.16281589e+00,  1.92077343e-03]])
row_means = arr.mean(1)
row_means
array([ 0.61512108,  0.70339376, -0.0985327 ,  0.64488707])
row_means.shape
(4,)
row_means.reshape((4, 1))
array([[ 0.61512108],
       [ 0.70339376],
       [-0.0985327 ],
       [ 0.64488707]])
demeaned = arr - row_means.reshape((4, 1))
demeaned
array([[ 0.52560005, -0.99045149,  0.46485144],
       [-0.41109705, -0.18427818,  0.59537522],
       [-1.02876373,  1.40566365, -0.37689992],
       [-0.87496253,  1.51792882, -0.6429663 ]])
demeaned.mean(1)
array([ 0.00000000e+00, -3.70074342e-17,  5.55111512e-17,  0.00000000e+00])

下图说明了该运算的过程

二维数组在轴1上的广播



下图展示了另外一种情况,这次是在一个三维数组上沿0轴向加上一个二维数组。
三维数组在轴0上的广播


最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 文章转载自:https://www.jianshu.com/p/550c90dfffa0 在使用Tensorflo...
    _白马阅读 853评论 0 2
  • Numpy概述 NumPy(Numerical Python的简称)是Python数值计算最重要的基础包。大多数提...
    __method__阅读 274评论 0 1
  • 基础篇NumPy的主要对象是同种元素的多维数组。这是一个所有的元素都是一种类型、通过一个正整数元组索引的元素表格(...
    oyan99阅读 5,151评论 0 18
  • 简单整理了NumPy的一些特性,参考的是 《利用Python进行数据分析》( Wes McKinney 著),Nu...
    蜘蛛鱼阅读 675评论 0 0
  • 无意间翻QQ相册,看到自己前年拍的紫荆花,然后就想起来那个美好的中午。 我从小就很喜欢紫色,对于紫色的花,更是热爱...
    偷懒的云阅读 526评论 0 0