LLM面面观之MoE

1. 背景

根据本qiang~最新的趋势观察,基于MoE架构的开源大模型越来越多,比如马斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分细节。

此文是本qiang~针对大语言模型的MoE的整理,包括原理、流程及部分源码。

2. MoE原理

MoE的流行源于”欧洲的OpenAI”

Mistral AI发布的论文及模型《Mixtral of Experts》,评测集上的效果吊打众多开源模型,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基础模型使用的是Mistral

AI自研的Mistral 7B,该模型的特点包括:滑窗注意力(Sliding Window Aattention), 滚动缓冲区缓存(Rolling

Buffer Cache)以及预填充-分块(Pre-fill and Chunking),具体细节可以查阅文末的论文地址。

本文以《Mixtral of

Experts》为引子,探究MoE的相关细节,MoE的原理如下图所示:


图2.1 MoE的原理

(1) Transformers架构中的每一层中的FFN网络均替换为了8个FFN(专家),且由一个网关路由(gate

router)进行控制

(2) 针对每一个token,每一层的网关路由仅选择其中的2个FFN(专家)来处理当前状态并进行加权输出

(3) 结果就是,每一个token访问了47B参数,但是在推理阶段仅仅使用了13B的激活参数(即,只使用2个专家,冻结其他6个专家)。

(4) 与Dropout机制对比,Dropout让部分神经元失活,而MoE是让部分专家失活。

3. 源码

本qiang~研读并尝试执行了Mistral官网的github推理代码,该代码框架非常适合新手,无他,只因其几乎只是在torch上层做的封装,很少引擎其他第三方库,不像transformers,功能强大,但不适合新手研读代码…

为了普适性,下面的代码截取了transformers框架中的代码。

首先看下通用Transformers中FFN中的代码模块,代码位置在transformers.models.mistral.modeling_mistral,主要流程是:

(1) 先经过gate_proj和up_proj的2个[hidden_size,

intermediate_size]的线性转换

(2) 使用激活函数对gate_proj进行激活

(3) 二者的内积再经过down_proj线性转换。


class MistralMLP(nn.Module):

    def __init__(self,  config):

        super().__init__()

        self.config = config

        self.hidden_size =  config.hidden_size


  self.intermediate_size = config.intermediate_size

        self.gate_proj =  nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

        self.up_proj =  nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

        self.down_proj =  nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]


    def forward(self, x):

        return  self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


再来看下MoE中的专家模块,代码位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先经过网关路由self.gate

(2) 然后选择其中2个专家,并归一化

(3) 之后遍历每个专家网络,并按照expert_mask进行筛选

(4) 如果expert_mask有值,则选择指定部分的隐藏层进行FFN操作,且输出结果进行加权

(5) 最后原地增加先前初始化的最终结果变量final_hidden_states


class MixtralSparseMoeBlock(nn.Module):


    def __init__(self,  config):

        super().__init__()

        self.hidden_dim =  config.hidden_size

        self.ffn_dim =  config.intermediate_size

        self.num_experts =  config.num_local_experts

        self.top_k =  config.num_experts_per_tok


        # gating

        self.gate =  nn.Linear(self.hidden_dim, self.num_experts, bias=False)


        self.experts =  nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in  range(self.num_experts)])


    def forward(self,  hidden_states: torch.Tensor) -> torch.Tensor:

        """  """

        batch_size,  sequence_length, hidden_dim = hidden_states.shape

        hidden_states =  hidden_states.view(-1, hidden_dim)

        # router_logits:  (batch * sequence_length, n_experts)

        router_logits =  self.gate(hidden_states)


        routing_weights =  F.softmax(router_logits, dim=1, dtype=torch.float)

        routing_weights,  selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

        routing_weights /=  routing_weights.sum(dim=-1, keepdim=True)

        # we cast back to  the input dtype

        routing_weights =  routing_weights.to(hidden_states.dtype)


        final_hidden_states  = torch.zeros(

            (batch_size *  sequence_length, hidden_dim), dtype=hidden_states.dtype,  device=hidden_states.device

        )


        # One hot encode the  selected experts to create an expert mask

        # this will be used  to easily index which expert is going to be sollicitated

        expert_mask =  torch.nn.functional.one_hot(selected_experts,  num_classes=self.num_experts).permute(2, 1, 0)


        # Loop over all  available experts in the model and perform the computation on each expert

        for expert_idx in  range(self.num_experts):

            expert_layer =  self.experts[expert_idx]

            idx, top_x =  torch.where(expert_mask[expert_idx])


            if  top_x.shape[0] == 0:

                continue


            # in torch it is  faster to index using lists than torch tensors

            top_x_list =  top_x.tolist()

            idx_list =  idx.tolist()


            # Index the  correct hidden states and compute the expert hidden state for

            # the current  expert. We need to make sure to multiply the output hidden

            # states by  `routing_weights` on the corresponding tokens (top-1 and top-2)

            current_state =  hidden_states[None, top_x_list].reshape(-1, hidden_dim)


  current_hidden_states = expert_layer(current_state) *  routing_weights[top_x_list, idx_list, None]


            # However  `index_add_` only support torch tensors for indexing so we'll use

            # the `top_x`  tensor here.


  final_hidden_states.index_add_(0, top_x,  current_hidden_states.to(hidden_states.dtype))

        final_hidden_states  = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        return final_hidden_states, router_logits


其中MixtralBlockSparseTop2MLP代码如下,可以看到和传统MistralMLP内容完全一致。


class MixtralBlockSparseTop2MLP(nn.Module):

    def __init__(self,  config: MixtralConfig):

        super().__init__()

        self.ffn_dim =  config.intermediate_size

        self.hidden_dim =  config.hidden_size


        self.w1 =  nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.w2 =  nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)

        self.w3 = nn.Linear(self.hidden_dim,  self.ffn_dim, bias=False)


        self.act_fn =  ACT2FN[config.hidden_act]


    def forward(self,  hidden_states):


  current_hidden_states = self.act_fn(self.w1(hidden_states)) *  self.w3(hidden_states)

        current_hidden_states  = self.w2(current_hidden_states)

        return  current_hidden_states


4. MoE微调

由于MoE只是将每一层的FFN改变为了每一层的gate网关路由+8个FFN专家,且gate网关路由和8个专家内部均为线性运算,所以可以无缝地结合LoRA、QLoRA进行指令微调。

可以参考开源项目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 问:MoE 8*7B的模型是56B参数?

答:MoE 8*7B的参数量是47B,而不是56B,原因是每一层除了8个专家网络外,其他层均是复用的。

(2) 问:MoE的基础模型是Mistral7B?

答:不是,MoE的模型架构与Mistral

7B相同,但其中的FFN替换为了8个FFN,且MoE是基于多语言数据集预训练而来的。

(3) MoE的稀疏性(sparse)体现在哪里?

答:在训练和推理时,同时只有两个专家网络会被激活,进行前向计算,其它专家网络处于失活状态。

6. 总结

一句话足矣~

本文主要针对大语言模型的MoE,包括原理及部分源码。

此外,建议大家可以针对源码进行运行,关于源码,欢迎大家一块交流。

7. 参考

(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE:https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE开源指令微调框架Firefly:https://github.com/yangjianxin1/Firefly

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,539评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,911评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,337评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,723评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,795评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,762评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,742评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,508评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,954评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,247评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,404评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,104评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,736评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,352评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,557评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,371评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,292评论 2 352

推荐阅读更多精彩内容