高维矩阵乘法的本质 2026-03-23

一、高维矩阵乘法的本质

高维矩阵乘法,其实就是在最后2维相乘,前面的维度类似for loop的次数

  • 批量矩阵乘法:对于 A @ B,只有最后两个维度参与标准的矩阵乘法 (M × K) × (K × N) → (M × N)
  • 前面的所有维度都是“批量维度”,它们只负责指定有多少个独立的矩阵乘法要执行,类似于隐式的 for 循环。
  • 因此,形状解读:将 A 看作 (..., M, K)B 看作 (..., K, N),结果形状为 (..., M, N)

在 NumPy、PyTorch、TensorFlow 等主流框架中,高维数组的矩阵乘法(如 np.matmul@ 运算符)确实遵循这个规则:

  • 最后两维执行标准的矩阵乘法 ( (M \times K) \times (K \times N) = (M \times N) )。
  • 前面的所有维度(称为“批量维度”)只是作为循环的嵌套,要求它们完全一致(或者符合广播规则)。

举个例子

假设有两个张量:
设 ( A ) 形状为 (2, 3, 4),( B ) 形状为 (2, 4, 5)

  • 批量维度是 (2,),它们一致。
  • 最后两维:(3, 4)(4, 5),得到 (3, 5)
  • 最终结果形状为 (2, 3, 5)

这等价于:

import numpy as np
A = np.random.rand(2, 3, 4)
B = np.random.rand(2, 4, 5)
result = A @ B  # shape (2, 3, 5)

# 手动循环验证
manual = np.stack([A[i] @ B[i] for i in range(2)])

print(result)
print("----")
print(manual)

assert np.allclose(result, manual)

二、广播规则的第一性原理

广播是为了让不同形状的数组能够进行元素级运算,其规则可归结为:

  1. 从后向前对齐维度(即从最内层维度开始)。
  2. 对于每个维度:
    • 若两维度相等 → 保留。
    • 若某一维度为 1 或缺失(视为 1) → 扩展为另一维度的大小。
    • 否则 → 广播失败。
  3. 广播是逻辑上的扩展,底层通过将相应维度的步长设为 0 实现零拷贝,兼顾语义与效率。

三、广播规则的例子

场景一:形状相同(无需广播)

  • 数组A.shape = (2, 3)B.shape = (2, 3)
  • 运算A + B
  • 广播过程
    1. 从后向前对齐:两形状维度数相同,每个维度一一比较。
    2. 所有维度相等,无需扩展。
  • 结果形状(2, 3)

场景二:维度数不同,从后向前对齐

  • 数组A.shape = (3, 4)B.shape = (4,)
  • 运算A + B
  • 广播过程
    1. 对齐:
      • 最后维度:A 有 4,B 有 4 → 相等
      • 倒数第二维度:A 有 3,B 无此维度 → 视为 1
    2. 扩展:B(4,) 逻辑上扩展为 (1, 4),再扩展为 (3, 4)(实际上通过步长 0 实现)
  • 结果形状(3, 4)
  • 等价理解:将 B 的每个元素沿行方向复制 3 次。

场景三:维度数相同,但某维度为 1

  • 数组A.shape = (2, 1, 3)B.shape = (4, 3)
  • 运算A + B(注意这里维度数不同,需先对齐)
  • 对齐步骤
    1. A 形状为 (2, 1, 3)B 形状为 (4, 3)。从后向前:
      • 最后维度:3 与 3 → 相等
      • 倒数第二维度:A 有 1,B 有 4 → 一个为 1,允许广播,结果维度取 4
      • 倒数第三维度:A 有 2,B 无此维度 → 视为 1,结果维度取 2
    2. 最终批量形状 (2, 4)B 逻辑上扩展为 (1, 4, 3) 再扩展为 (2, 4, 3)A 的倒数第二维度从 1 扩展为 4。
  • 结果形状(2, 4, 3)

场景四:广播失败案例

  • 数组A.shape = (3, 2)B.shape = (4, 2)
  • 运算A + B
  • 对齐:从后向前:
    • 最后维度:2 与 2 → 相等
    • 倒数第二维度:3 与 4 → 既不相等,也没有一方为 1
  • 结果:抛出 ValueError: operands could not be broadcast together with shapes (3,2) (4,2)

场景五:矩阵乘法中的广播(特殊规则)

矩阵乘法 @ 的广播只针对最后两维之前的维度。

  • 数组A.shape = (3, 1, 4, 2)(批量 (3,1),矩阵 (4,2)
    B.shape = (4, 2, 5)(批量 (4,),矩阵 (2,5)
  • 批量对齐
    1. A 批量 (3,1)B 批量 (4,) → 将 B 批量补为 (1,4)
    2. 从后向前对齐:
      • 维度 1:1 与 4 → 广播为 4
      • 维度 0:3 与 1 → 广播为 3
    3. 最终批量形状 (3,4)
  • 结果形状(3, 4, 4, 5)(最后两维为矩阵乘结果 (4,5)

总结:广播规则的三个关键点

  1. 从后向前对齐:保证最内层维度优先匹配。
  2. 维度匹配条件:相等或至少一个为 1(缺失视为 1)。
  3. 扩展方式:逻辑复制,底层步长为 0,无实际内存拷贝。

四、矩阵乘法中的广播应用

  • matmul 中,广播仅作用于批量维度(最后两维之前的维度),最后两维不参与广播,必须严格满足矩阵乘法的维度要求(即 A 的列数等于 B 的行数)。
  • 正确拆分形状
    例如 A.shape = (3, 4, 2) 应拆分为批量维度 (3,) 和矩阵维度 (4, 2)
    B.shape = (4, 2, 5) 拆分为批量维度 (4,) 和矩阵维度 (2, 5)
    然后对批量形状 (3,)(4,) 应用广播规则——二者均为 1 维,且值不相等且都不为 1,因此无法广播,会直接报错。
  • 只有在批量维度满足“相等或一方为 1”时,广播才能成功,最终结果的批量维度取两者对应维度的最大值。

关键理解

  • 高维矩阵乘法 = 批量维度的广播对齐 + 最后两维的矩阵乘法
  • 广播规则的核心是“从后向前,等或一”,保证了数学直觉与计算效率的统一。
  • 处理高维矩阵乘法时,务必先分离出批量维度与矩阵维度,再判断广播可行性。
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

  • """1.个性化消息: 将用户的姓名存到一个变量中,并向该用户显示一条消息。显示的消息应非常简单,如“Hello ...
    她即我命阅读 7,118评论 0 6
  • 1、expected an indented block 冒号后面是要写上一定的内容的(新手容易遗忘这一点); 缩...
    庵下桃花仙阅读 1,571评论 1 2
  • 一、工具箱(多种工具共用一个快捷键的可同时按【Shift】加此快捷键选取)矩形、椭圆选框工具 【M】移动工具 【V...
    墨雅丫阅读 2,239评论 0 0
  • 跟随樊老师和伙伴们一起学习心理知识提升自已,已经有三个月有余了,这一段时间因为天气的原因休课,顺便整理一下之前学习...
    学习思考行动阅读 1,687评论 0 2
  • 一脸愤怒的她躺在了床上,好几次甩开了他抱过来的双手,到最后还坚决的翻了个身,只留给他一个冷漠的背影。 多次尝试抱她...
    海边的蓝兔子阅读 1,319评论 1 4

友情链接更多精彩内容