NUMPY 传播机制

1、 broadcast机制

传播机制是numpy在算数计算中处理不同维度数组的方法。

NumPy 操作通常在逐个元素的基础上对数组对进行。 在最简单的情况下,两个数组必须具有完全相同的形状,如下例所示:

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 2.0])
a * b
array([2.,  4.,  6.])

当数组的形状满足某些约束时,NumPy 的广播规则会放宽这个约束。 最简单的广播示例发生在同时操作数组和标量值时:

a = np.array([1.0, 2.0, 3.0])
b = 2.0
a * b
array([2.,  4.,  6.])

这里的结果前一个示例相同。 想象标量 b 在算术运算期间被拉伸成一个与 a 形状相同的数组。 b 中的新元素,如下图所示,只是原始标量的副本。 拉伸只是概念性的, NumPy 可以使用原始标量值而无需实际制作副本,从而使广播操作尽可能地提升内存和计算效率。

b的“拉伸”
2、广播规则

当numpy对两个数组进行操作时,他会自右向左比较两个数组的维度,若当前维度满足以下条件,则继续比较:

  • 两个数组当前维度相同
  • 两数组存在一个当前维度为1

否则报错: ValueError: operands could not be broadcast together

结果数组的大小是沿输入的 每个轴不为 1 的大小。

数组不需要具有相同的维数。
例如,如果您有一个 256x256x3 的 RGB 值数组,并且您希望将图像中的每种颜色缩放不同的值,则可以将图像乘以具有 3 个值的一维数组。 根据广播规则 从尾轴 排列这些数组的大小,可以看出它们是兼容的:

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3

当比较的任一维度是1时,使用另一个。 换句话说,尺寸为 1 的维度被拉伸,以匹配另一个维度。

在以下示例中,A 和 B 数组都有长度为 1 的轴,在广播操作期间扩展为更大的大小:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
3、Broadcastable arrays

如果上述规则产生有效结果,则将一组数组称为“可广播”到相同的形状。
例子:

a = np.array([[ 0.0,  0.0,  0.0],
              [10.0, 10.0, 10.0],
              [20.0, 20.0, 20.0],
              [30.0, 30.0, 30.0]])
b = np.array([1.0, 2.0, 3.0])
a + b
array([[  1.,   2.,   3.],
        [11.,  12.,  13.],
        [21.,  22.,  23.],
        [31.,  32.,  33.]])
image.png
a = np.array([0.0, 10.0, 20.0, 30.0])
b = np.array([1.0, 2.0, 3.0])
a[:, np.newaxis] + b
array([[ 1.,   2.,   3.],
       [11.,  12.,  13.],
       [21.,  22.,  23.],
       [31.,  32.,  33.]])
image.png
4、 实例 Vector Quantization

VQ 中的基本操作是在一组点(在 VQ 术语中称为codes )中找到最接近给定点的点,称为observation。

在二维的场景下:
observation的值描述了要分类的运动员的体重和身高。
codes 代表不同类别的运动员。
找到最近的点需要计算observation和每个codes 之间的距离。
最短距离提供最佳匹配。 在此示例中,codes [0] 是最接近的类别,表明该运动员可能是一名篮球运动员。

from numpy import array, argmin, sqrt, sum
observation = array([111.0, 188.0])
codes = array([[102.0, 203.0],
               [132.0, 193.0],
               [45.0, 155.0],
               [57.0, 173.0]])
diff = codes - observation    # the broadcast happens here
dist = sqrt(sum(diff**2,axis=-1))
argmin(dist)
0

在此示例中,observation 数组被拉伸以匹配codes 数组的形状:

Observation      (1d array):      2
Codes            (2d array):  4 x 2
Diff             (2d array):  4 x 2

通常,将可能从数据库中读取的大量observation 结果与一组codes 进行比较。 考虑这种情况:

Observation      (2d array):      10 x 3
Codes            (2d array):       5 x 3
Diff             (3d array):  5 x 10 x 3

三维数组 diff 是广播的结果,而不是计算的必需品。 大型数据集将生成计算效率低下的大型中间数组。 相反,如果使用 Python 循环围绕上述二维示例中的代码单独计算每个观察值,则使用更小的数组。

reference

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。