Pay Attention to MLPs

Pay Attention to MLPs

Ref: https://arxiv.org/pdf/2105.08050.pdf
code:https://github.com/lucidrains/g-mlp-pytorch/blob/54209f0fb2a52557a1c64409f26df9ebd8d5c257/g_mlp_pytorch/g_mlp_pytorch.py

背景

Transformers 自横空出世以来,在NLP领域,大规模了取代了LSTM-RNN模型,在CV上,ConvNets也不再是唯一选择。它有2个重要特性:

  1. recurrent-free结构,可以并行化计算每个token的表达;
  2. multi-head self-attention blocks, 可以聚合token之间的空间信息。

其中的attention mechanism一直被认为transformers取得优秀成绩的重要因素。和MLP相比,attention可以根据模型输入,调整参数,而MLP的参数是固定的。那么问题来了,transformers效果那么好,是self-attention起的决定性作用吗,self-attention是必要的吗

本文提出了gMLPs,一种attention-free, 以MLP为基础的由channel projections, spatial projections 和gating组成的网络结构。

实验显示:

  1. 在CV上,可以达到和vision transformers差不多的准确率;和MLP-Mixer相比,参数减少66%,准确率还提升了3%;
  2. 在NLP上,将gMLPs应用到BERT的MLM,和transformers一样,在预训练实时能最小化perplexity。同时,实验也显示,perplexity和模型规模有关,而对attention不敏感;
    2.1 随着模型的capacity上升,gMLPs的预训练和finetuning指标会快速接近Transformers,这意味着,只要扩大模型规模,那么无需self-attention,gMLPs和Transformers的差距会不断缩小;
    2.2 batch-size为256,进过1Mstep,gMLPs相比Bert,在MNLI达到了86.4%的准确率,在SQuAD达到了89.5%的F1;
    2.3 在finetuning阶段,模型规模和perplexity接近的情况下, Transformers在cross-sentence alignment任务上比gMLPs效果好[MNLI任务高1.8%]。但是,当gMLPs的参数量是transformers的3倍时,模型效果就很接近;
    2.4 同时,文中提出一个trick,在gMLPs后接一个single-head 128d 的attention,在NLP的各项任务上,就超过了transformers。

因此,本文觉得,提高数据量和算力,无需self-attention,gMLPs,就可以和transformers媲美。

Model

输入:序列长度为n,embedding维度为d:
X\in R^{n\times d}

使用L个block,每个block进行如下操作:

Z = \sigma (XU) = GeLU(XU)
\tilde Z = s(Z)
Y = \tilde ZV

其中:
U,V为沿着channel[可理解为hidden维度]的线性投影,同Transformers的FFN;
s(\cdot)为空间上的交互,用于获取tokens之间的关系。本文认为s(\cdot)可以学习到位置信息,因此,没有使用positional embedding。

gMLPs

Spatial Gating Unit

为了实现token之间的交互,在s(\cdot)层,就要包含一个空间维度的交叉操作。

文中主要介绍了2种SGU:

  1. 比较直观的,就是使用线性投影:
    f_{W,b} (Z) = WZ + b

其中:
W\in R^{n\times n}, n为序列长度;b可以是一个矩阵,也可以是一个常量。
空间交互通过element-wise实现:
s(Z) = Z \odot f_{W,b} (Z)

为确保训练的稳定性,W初始化值接近于0, b为1。这相当于初始化的FFN,开始每个token相互独立,随着训练逐渐考虑token之间的交互信息。

  1. 除了线性投影的gatef_{W,b} (\cdot), 文中还将Z沿着channel分解成(Z_1,Z_2),借鉴GLUs的思路:
    s(Z) = Z_1 \odot f_{W,b} (Z_2)

源代码分析

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        """dim: embedding size 
            dim_seq: sequence length """
        super().__init__()
        dim_out = dim // 2
        self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1d(dim_seq, dim_seq, 1) 
        # 常规卷积,卷积的是词向量的维度。本文是空间上的信息交互,因此输入/输出通道是序列长度,卷积核尺寸为1。

        self.act = act

        init_eps /= dim_seq
        nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x, gate_res = None):
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1) #沿着词向量维度,分成2个矩阵。
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias
        if self.causal:
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
            weight = weight.masked_fill(mask[..., None], 0.)

        gate = F.conv1d(gate, weight, bias)

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res

GLUs(Gated linear units)补充:

由Language model with gated convolutional network提出,使用CNN学习长文本,为缓解梯度消散,并保留非线性能力,使用门控机制。即:
没有经过非线性转换的卷积层输出*经过非线性转换的卷积层输出
h(x) = (X*W+b)\odot \sigma(X*V + b)

其中:
\odot:element-wise product
X\in R^{N \times m}
W,V \in R^{k \times m \times n}

注意,GLUs是沿着channel维度[per token]的处理,而SGU是沿着空间维度[cross-token]的处理。

Image Classification

在图片分类ImageNet数据集上,无需添加外部数据,训练gMLPs。
模型配置如下,输入和输出沿用的ViT(Vision Transformer)格式,模型的深度和宽度配置也和ViT/DeiT模型相似。
结果:和Transformer一样,gMLPs在训练集上过拟合,因此采用了DeiT的正则化处理(mixup, cutmix);同时,对模型的维度做了调整。


CV gMLPs
ImageNet模型结果
图片分类准确率和模型规模关系

Masked Language Modeling with BERT

DepthWise convolution补充

一个卷积核负责一个通道,卷积核数量要和图片通道数相同。
f_{W,b}( \cdot)好比一个宽的depthwise convolution,接收整个句子的信息。但是depthwise convolution面向的是通道的filter,而gMLPs只使用一个W共享交叉通道。

在NLP上,gMLPs进行了多个ablation实验。

1. Ablation:the importance of gating in gMLP for BERT's Pretraining

  1. 使用Bert的absolute position embeddings;
  2. Bert框架 + T5-stype的relative position biases;
  3. 同1,2,但只保留relative positional biases,去掉content-dependent terms inside the softmax。

困惑度:交叉熵的指数形式。语言模型越好,句子概率越大,熵越小,困惑度越低。

各种模型的perplexity比较

使用SGU可以让gMLPs得到与Bert差不多的perplexity。

2. Case Study: The Behavior of gMLPs as Model Size Increases

模型规模和finetuing结果比较

Transformer中的6+6:self-attention使用6层,FFN使用6层。
finetuning任务用GLUE表示模型效果。
结果显示:

  1. gMLPs越深,pretraining perplexity越小,和transformer的模型效果越逼近;
  2. pretraining的perplexity越小,不意味着finetuning结果越好,比如gMLPs的perplexity比transformer小的时候,在SST-2的模型结果更好,但是MNLI-m的模型结果更差;

3. Ablation: The Usefulness of Tiny Attention in BERT's Finetuning

文中还做了个测试,在一些下游任务上,主要是设计到句子对的任务上,gMLPs表现比Transformers差。 那就再加一个tiny attention,来加强模型对cross-sentence alignment的学习。

Hybrid

这种混个gMLPs和attention的模型,称为aMLPs。结果显示,aMLPs的效果比gMLPs和transformer都要好。


模型比较

4.Main Results for MLM in the BERT Setup

模型效果总结
  1. 以SQuADv2.0任务为例,base模型,Bert模型的f1达到了78.6,gMLPs只有70.1, 差距8.5%;到了large模型,差距只有81.0-78.3=2.7;
  2. aMLPs使用128d的attention size,在SQuADv2.0任务,比Bert还要高4.4%的F1.

前面做的几个实验的总结:

  1. 在finetuning阶段,gMLPs不如transformer,但是,随着模型变大,和transformer的差距会不断缩小;
  2. aMLPs 不同的attention size(64,128),足够使得模型效果优于其他2个。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,558评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,002评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,036评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,024评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,144评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,255评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,295评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,068评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,478评论 1 305
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,789评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,965评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,649评论 4 336
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,267评论 3 318
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,982评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,223评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,800评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,847评论 2 351

推荐阅读更多精彩内容