推理引擎TP并行-MLP流程

#01 原理

前文中我们已经描述过MLP列并行与行并行的基础原理。
简单回顾一下:

  • 权重按列切:输入不能切分,得到的是结果,shape是不完整的,最终的结果需要all-gather
  • 权重按行切:输入需要按列切,得到的结果中,shape是完整的,最终的结果需要all-reduce
    我们以qwen3 dense模型的MLP来进行后续的分析。
    先看模型结构:


模型结构如结构一所示,由于gate与up都属于线形变换,所以可以进行融合,同时,silu属于Point Wise操作,所以可以继续融合,最终融合后的结构如结构四所示。

02 tp实现

下面我们来分析如何对MLP层进行tp切分。

整体思路:

  • 在gate-up阶段进行列切,使得每个tp最终得到的结果仅包含部分列,
  • 在down阶段进行行切,使得每个tp仅需处理一部分的计算,最后通过all-reduce得到最终的结果

下面以tp=2来进行分析,对于其他的tp情况类似。


各个rank中的权重分配:



说明:

  • wd行切,每个rank处理切分后的一块
  • wg与wu列切,每个rank处理切分后的一块,因为wg与wu的计算没有依赖关系,将两个权重矩阵合并为一个矩阵,但是合并后的矩阵,仅有列的前一半进行silu的处理

计算原理参考下图


对计算过程进行分析:

  • Block1:原始权重矩阵,wg与wu按照tp size进行切分,切分后均分到不同的rank
  • Block2:在rank内执行矩阵乘法,实际上这里得到的结果,按列包含了gate与up的一部分
  • Block3:Block3并不是真实的rank内的计算过程,而是为了展示数学上的原理,在Block2中,每个rank计算出各自的结果之后,将不同的rank进行合并,得到完整的gate与up计算结果
  • Block4:Block4也不是真实的rank计算过程,Block3中得到了合并后的计算计算结果,然后执行gate与up结果的乘积,这个乘积将用于计算down。mul的结果包含两部分,分别对应于gate与up不同分片的乘积,这里可以看出与Block2中rank内的结果直接相乘是一样的,Block2中的相乘结果可以视为是Block4结果按列切。

从上面Block4中的计算结果可以看出,不需要执行Block3中的合并,而是直接在Block2中执行乘积,并将结果看成是对gate与up乘积的列切,最后再与按行切的down进行计算即可。

上图说明如下:

  • Block2:在执行gate-up操作之后,同时得到了gate与up的结果,然后对结果的一半列执行silu,最后再执行按位乘
  • Block3:Block2每个rank的结果相当于是对down的输入进行了列切分,而down权重经过了行切分,所以执行matmul之后再进行all-reduce就可以得到MLP的最终输出。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容