一、高维矩阵乘法的本质
高维矩阵乘法,其实就是在最后2维相乘,前面的维度类似for loop的次数
-
批量矩阵乘法:对于
A @ B,只有最后两个维度参与标准的矩阵乘法(M × K) × (K × N) → (M × N)。 -
前面的所有维度都是“批量维度”,它们只负责指定有多少个独立的矩阵乘法要执行,类似于隐式的
for循环。 - 因此,形状解读:将
A看作(..., M, K),B看作(..., K, N),结果形状为(..., M, N)。
在 NumPy、PyTorch、TensorFlow 等主流框架中,高维数组的矩阵乘法(如 np.matmul 或 @ 运算符)确实遵循这个规则:
- 最后两维执行标准的矩阵乘法 ( (M \times K) \times (K \times N) = (M \times N) )。
- 前面的所有维度(称为“批量维度”)只是作为循环的嵌套,要求它们完全一致(或者符合广播规则)。
举个例子
假设有两个张量:
设 ( A ) 形状为 (2, 3, 4),( B ) 形状为 (2, 4, 5)。
- 批量维度是
(2,),它们一致。 - 最后两维:
(3, 4)乘(4, 5),得到(3, 5)。 - 最终结果形状为
(2, 3, 5)。
这等价于:
import numpy as np
A = np.random.rand(2, 3, 4)
B = np.random.rand(2, 4, 5)
result = A @ B # shape (2, 3, 5)
# 手动循环验证
manual = np.stack([A[i] @ B[i] for i in range(2)])
print(result)
print("----")
print(manual)
assert np.allclose(result, manual)
二、广播规则的第一性原理
广播是为了让不同形状的数组能够进行元素级运算,其规则可归结为:
- 从后向前对齐维度(即从最内层维度开始)。
- 对于每个维度:
- 若两维度相等 → 保留。
- 若某一维度为 1 或缺失(视为 1) → 扩展为另一维度的大小。
- 否则 → 广播失败。
- 广播是逻辑上的扩展,底层通过将相应维度的步长设为 0 实现零拷贝,兼顾语义与效率。
三、广播规则的例子
场景一:形状相同(无需广播)
-
数组:
A.shape = (2, 3),B.shape = (2, 3) -
运算:
A + B -
广播过程:
- 从后向前对齐:两形状维度数相同,每个维度一一比较。
- 所有维度相等,无需扩展。
-
结果形状:
(2, 3)
场景二:维度数不同,从后向前对齐
-
数组:
A.shape = (3, 4),B.shape = (4,) -
运算:
A + B -
广播过程:
- 对齐:
- 最后维度:
A有 4,B有 4 → 相等 - 倒数第二维度:
A有 3,B无此维度 → 视为 1
- 最后维度:
- 扩展:
B的(4,)逻辑上扩展为(1, 4),再扩展为(3, 4)(实际上通过步长 0 实现)
- 对齐:
-
结果形状:
(3, 4) -
等价理解:将
B的每个元素沿行方向复制 3 次。
场景三:维度数相同,但某维度为 1
-
数组:
A.shape = (2, 1, 3),B.shape = (4, 3) -
运算:
A + B(注意这里维度数不同,需先对齐) -
对齐步骤:
-
A形状为(2, 1, 3),B形状为(4, 3)。从后向前:- 最后维度:3 与 3 → 相等
- 倒数第二维度:
A有 1,B有 4 → 一个为 1,允许广播,结果维度取 4 - 倒数第三维度:
A有 2,B无此维度 → 视为 1,结果维度取 2
- 最终批量形状
(2, 4),B逻辑上扩展为(1, 4, 3)再扩展为(2, 4, 3),A的倒数第二维度从 1 扩展为 4。
-
-
结果形状:
(2, 4, 3)
场景四:广播失败案例
-
数组:
A.shape = (3, 2),B.shape = (4, 2) -
运算:
A + B -
对齐:从后向前:
- 最后维度:2 与 2 → 相等
- 倒数第二维度:3 与 4 → 既不相等,也没有一方为 1
-
结果:抛出
ValueError: operands could not be broadcast together with shapes (3,2) (4,2)
场景五:矩阵乘法中的广播(特殊规则)
矩阵乘法 @ 的广播只针对最后两维之前的维度。
-
数组:
A.shape = (3, 1, 4, 2)(批量(3,1),矩阵(4,2))
B.shape = (4, 2, 5)(批量(4,),矩阵(2,5)) -
批量对齐:
-
A批量(3,1),B批量(4,)→ 将B批量补为(1,4) - 从后向前对齐:
- 维度 1:1 与 4 → 广播为 4
- 维度 0:3 与 1 → 广播为 3
- 最终批量形状
(3,4)
-
-
结果形状:
(3, 4, 4, 5)(最后两维为矩阵乘结果(4,5))
总结:广播规则的三个关键点
- 从后向前对齐:保证最内层维度优先匹配。
- 维度匹配条件:相等或至少一个为 1(缺失视为 1)。
- 扩展方式:逻辑复制,底层步长为 0,无实际内存拷贝。
四、矩阵乘法中的广播应用
- 在
matmul中,广播仅作用于批量维度(最后两维之前的维度),最后两维不参与广播,必须严格满足矩阵乘法的维度要求(即A的列数等于B的行数)。 -
正确拆分形状:
例如A.shape = (3, 4, 2)应拆分为批量维度(3,)和矩阵维度(4, 2);
B.shape = (4, 2, 5)拆分为批量维度(4,)和矩阵维度(2, 5)。
然后对批量形状(3,)与(4,)应用广播规则——二者均为 1 维,且值不相等且都不为 1,因此无法广播,会直接报错。 - 只有在批量维度满足“相等或一方为 1”时,广播才能成功,最终结果的批量维度取两者对应维度的最大值。
关键理解
- 高维矩阵乘法 = 批量维度的广播对齐 + 最后两维的矩阵乘法。
- 广播规则的核心是“从后向前,等或一”,保证了数学直觉与计算效率的统一。
- 处理高维矩阵乘法时,务必先分离出批量维度与矩阵维度,再判断广播可行性。