浅析DeepSeek多头潜在注意力机制(MLA)
背景:DeepSeek在无损模型效果的同时大幅降低了大模型的训练以及推理成本,引起业界广范关注。所涉及的优化包括不限于:使用低精度计算,知识蒸馏和稀疏计算,软硬件协同优化,模型架构优化【Mutil-head Latent Attention 以及DeepSeekMoE】,分布式训练和优化通信策略,自我学习和高效利用GPU资源等。本文将重点围绕从DeepSeek-V2开始的注意力优化,潜在注意力机制Mutil-head Latent Attention进行浅析。MLA是通过显著减少大模型推理过程中的kv-缓存来实现推理加速的。所以本文将从(1)大模型推理过程中为什么需要KV缓存(2)常用的几种优化推理过程中KV缓存的方法 (3)MLA是什么以及为什么能减少KV缓存的数学原理。 三部分展开。
为什么需要KV缓存:
大模型的常用架构:
常见大模型的基本架构是基于transformer的,仅仅使用transformer的decoder部分,而且将transformer的decoder中的Mutil-HEAD-Self-Attention去掉,仅保留MASK-Mutil-Head-Self-Attention。如下图所示(左传统tranformer-右常用大模型架构)。
所以整个训练以及推理过程中,Mask-Mutil-Head-SelfAttention作为Token学习上下文依赖的部分是比较重要的。
大模型的整体推理过程:
当我们输入一段问题(提示词)给大模型后,大模型是按下述方式完成回答。
1. Prefill:提示词的embedding prefill&&推理出第一个词。
(1) 提示词lookup table 得到每一个词的embedding && 计算Position Embedding。
(2)计算提示词的每一个词的Q,K,V 矩阵(并将KV缓存起来)。 并得到self-attention矩阵 softmax(Q*transpose(K)),如果提示词的长度是N,attention矩阵是N*N的。多个头的话 就是HEAD_NUM *N*N. 【实际上大模型会pdding成输入的max_len比如2048纬度确保长度统一,】HEAD_NUM*MAX_LEN*MAX_LEN.
(3)然后根据Attention矩阵的最后一行以及每一个词的V得到 最后一个词的MHA阶段的输出,将最后一行该输出放入FFN,就会得到1个输出,这个词就是回答阶段的第一个词。
2. Decoding.第一个词
(1)这个词looking up table && 计算 PE。
(2)得到该词的Q,K,V矩阵。计算该词的attention矩阵需要所有输入词的K,V矩阵(并加入缓存)。来得到一个1*(N+1)的attention权重。根据这个权重和所有的V【之前提示词的以及这个词的】得到这个词的MHA的输出。如果我们不缓存所有的KV,那么相当推理每一个词的时候都将之前的KV按照 1里面的(1)(2)步骤重复计算。
(3)将MHA的输出放入FFN,得到1个输出这个词就是大模型回答的第二个答案。
3.Decoding 直到结束。
将第二个token 按照步骤2的操作 执行一遍会得到第三个词。直到出现end_token或者到达MAX_LEN终止。
KVCache-必要性演示:
下面以一个简化版本的LLM推理过程来演示KV-cache必要性。第一个词,有没有kv-cache计算量是一样的。对于当前query token词来说attention矩阵就是1*1.
第二个词,如果没有kv cache 需要重新计算key Token1以及Value Token1。【也就是有kv cache的紫色部分】
第三个词,如果没有kv cache,需要重新计算key Token1以及2, Value Token1&&2.
重复计算的次数随着Token位置的延后 越累积越多。所以KVcache 是必要。实际上将MHA 部分的复杂度从n*2降低到n。
训练过程中为什么不需要KV-Cache:
(1)训练过程中是batch 训练,不像推理那样一个一个token都需要依赖上文信息,每一个batch 实际上是一个完整句子,拆开的多个样本,也就是这个batch的最后一个样本是见到了这个句子全部token的。所有token的注意力矩阵一下就能全部算出来。举一个例子。上图的图三,maxlen就是3个词,那么一次计算就得到了这个batch的3*3的mask-attention矩阵。KVcache没有意义。
常用的几种优化推理过程中KV缓存的方法 :
通过上面的分析可以看出KV-cache就是一个空间换时间的过程,但是GPU相对CPU的单核缓存较小问题是不允许cache过大的。所以有很多研究围绕着缩小kvcache 展开。
普通MHA kv_cache 存储数:
按上文所述。来看一下普通的MHA需要存储多少kv_cache.
按GPT中的一些数字大小。
句子的最大长度 L 2048
一个词embeding的size: DIM 12288
注意力的头数:Head_Num 96
每个头的注意力维数 Head_DIM 128. [128*96=12288]
解码器层数 96
所以一层解码器需要的存储kv-cache是 2*Head_Num*Head_DIM*L.
因为需要存储k&&V,所以是2,不同头的参数不同,所以需要HEAD_NUM, HEAD_DIM是每个K,V的向量维度。L是句子的最大长度。所以解码阶段最大的存储就是2*Head_Num*Head_DIM*L。
所以一层解码器需要的存储kv-cache是 2*Head_Num*Head_DIM*L.
因为需要存储k&&V,所以是2,不同头的参数不同,所以需要HEAD_NUM, HEAD_DIM是每个K,V的向量维度。L是句子的最大长度。所以解码阶段最大的存储就是2*Head_Num*Head_DIM*L。
常用的几种优化KV的方法:
什么是MLA:
MLA旨在进一步缩小KV缓存的大小,同时在性能上超越之前提到的注意力机制(包括MHA)。它通过将KV缓存压缩到低维潜在空间,成功将缓存大小减小了90%+,
MLA不会像传统方式那样在每个头计算和存储每个令牌的键和值,而是使用下投影矩阵DownLinear把它们压缩成潜在向量C。想达到一个C解决所有头的KV_Cache问题,在推理时候,再通过一个UPLinear 升维变相达到kv_cache的目的。这就是MLA的核心思想 - 在保持模型能力的同时,通过降维来减少内存占用
MLA的数学原理:
数学原理上:首先创造一个维度为dc的latent向量,这个向量维度远小于Head_Num*Head_DIM。对于第S个head,我们建立参数矩阵,注意这里的参数矩阵,和传统的MHA比,多了一个Wc ,(传统的只有Wq,Wk,Wv)同时Wq的维度与传统的MHA一样是d*dk 纬度d也就是前面我们常说的DIM(一个词的embedding size),dk 就是我们前面说的每一个头的维度Head_DIM. Wk,Wv 纬度已经变了,从d*dk 变成了dc *dk . 现在的Wc*Wv 才是老的Wv。Wc*Wk 才是老的Wk
首先,将输入的词xi投影到一个低维空间得到ci。xi的维度是1*d。
然后利用ci和其他参数矩阵得到输入词xi在第s个头的q,k,v矩阵:
这里的ki = xi *老的Wk= xi *Wc*Wk = Ci*Wk ,同样vi也是一样的推导过程。
对计算句子中位置t的词qt的attention矩阵。 只需要计算qt*transpose(ki), i<=t即可,这里利用乘法交换律
可以看出对qt 这个词来说,在第S个头上,他的attention矩阵的对i这个词的值只需要自己的输入xt*(Wq*transpose(Wk))*tranpose(ci),那么这里我们只需要存储若干个ci就行了,他就是每一个词对应的潜在向量。 对不同的头,ci都是一样的。而我们知道老的MHA中,每次qt的得到也是需要xt*Wq,这里只不过是变成了xt*Wq*transpose(Wk),可以说运算量几乎没增加。但是所需要存储的cache数大大减少。仅仅需要。
L*dc,直接和MQA一个量级,但是效果确和MHA一样。DeepSeek-v2中 dc是512维度,其存储和group=2的MGA是一样的。
MLA与ROPE的融合:
什么是ROPE:rope是这样一个矩阵,与q,k相乘后,实现位置编码的功效。m就是位置【0,1,2,3。。。】词是第几个位置就是几,theta_i=10000^{-2i/ d}, d就是128, i就是0 到d/2-1.
但当加入RoPE后,这个合并就无法实现了。因为RoPE是一个与位置相关的 dk×dk 分块对角矩阵 Rm,它满足 Rm*tranpose(Rn)=Rm−n。加入RoPE后的注意力计算变为:
多了Rt-i这个是与位置i相关的使得MLA的kv存储失效了。
DeepSeek采用了一个混合方案 - 在每个Attention Head的Q、K中新增 dr 【实战中dr=dk/2=64】个维度用于RoPE,其中K的新增维度在所有Head间共享:(同时v3开始 q也被压缩到低维度
带ROPE的QKV如下图:
对第s个头的token_i计算attention 矩阵的第i个位置 qi与ki的转置 相乘。因为q,k这里都是1行多列的向量【Qpart1,Qpart2】*tranpose(【Kpart1,Kpart2】)= Qpart1*tranpose(Kpart1) + Qpart2*tranpose(Kpart2)
实际上
这部分在上面已经推导过,只缓存ci即可。而Qpart2*tranpose(kPart2)WkrRI 缓存起来,这样R部分就没有任何和位置计算相关的地方在计算attention的过程中。
整个的MLA的流程如下图所示,对于单一token,只需要存储latent向量C,维度512.以及WkrRi 也就是图中的KtR 纬度64,(【512+64】/128 = 4.5)。