#01 原理
前文中我们已经描述过MLP列并行与行并行的基础原理。
简单回顾一下:
- 权重按列切:输入不能切分,得到的是结果,shape是不完整的,最终的结果需要all-gather
-
权重按行切:输入需要按列切,得到的结果中,shape是完整的,最终的结果需要all-reduce
我们以qwen3 dense模型的MLP来进行后续的分析。
先看模型结构:
模型结构如结构一所示,由于gate与up都属于线形变换,所以可以进行融合,同时,silu属于Point Wise操作,所以可以继续融合,最终融合后的结构如结构四所示。
02 tp实现
下面我们来分析如何对MLP层进行tp切分。
整体思路:
- 在gate-up阶段进行列切,使得每个tp最终得到的结果仅包含部分列,
- 在down阶段进行行切,使得每个tp仅需处理一部分的计算,最后通过all-reduce得到最终的结果
下面以tp=2来进行分析,对于其他的tp情况类似。
各个rank中的权重分配:
说明:
- wd行切,每个rank处理切分后的一块
- wg与wu列切,每个rank处理切分后的一块,因为wg与wu的计算没有依赖关系,将两个权重矩阵合并为一个矩阵,但是合并后的矩阵,仅有列的前一半进行silu的处理
计算原理参考下图
对计算过程进行分析:
- Block1:原始权重矩阵,wg与wu按照tp size进行切分,切分后均分到不同的rank
- Block2:在rank内执行矩阵乘法,实际上这里得到的结果,按列包含了gate与up的一部分
- Block3:Block3并不是真实的rank内的计算过程,而是为了展示数学上的原理,在Block2中,每个rank计算出各自的结果之后,将不同的rank进行合并,得到完整的gate与up计算结果
- Block4:Block4也不是真实的rank计算过程,Block3中得到了合并后的计算计算结果,然后执行gate与up结果的乘积,这个乘积将用于计算down。mul的结果包含两部分,分别对应于gate与up不同分片的乘积,这里可以看出与Block2中rank内的结果直接相乘是一样的,Block2中的相乘结果可以视为是Block4结果按列切。
从上面Block4中的计算结果可以看出,不需要执行Block3中的合并,而是直接在Block2中执行乘积,并将结果看成是对gate与up乘积的列切,最后再与按行切的down进行计算即可。
上图说明如下:
- Block2:在执行gate-up操作之后,同时得到了gate与up的结果,然后对结果的一半列执行silu,最后再执行按位乘
- Block3:Block2每个rank的结果相当于是对down的输入进行了列切分,而down权重经过了行切分,所以执行matmul之后再进行all-reduce就可以得到MLP的最终输出。