Ref: https://arxiv.org/pdf/2105.08050.pdf
code:https://github.com/lucidrains/g-mlp-pytorch/blob/54209f0fb2a52557a1c64409f26df9ebd8d5c257/g_mlp_pytorch/g_mlp_pytorch.py
背景
Transformers 自横空出世以来,在NLP领域,大规模了取代了LSTM-RNN模型,在CV上,ConvNets也不再是唯一选择。它有2个重要特性:
- recurrent-free结构,可以并行化计算每个token的表达;
- multi-head self-attention blocks, 可以聚合token之间的空间信息。
其中的attention mechanism一直被认为transformers取得优秀成绩的重要因素。和MLP相比,attention可以根据模型输入,调整参数,而MLP的参数是固定的。那么问题来了,transformers效果那么好,是self-attention起的决定性作用吗,self-attention是必要的吗?
本文提出了gMLPs,一种attention-free, 以MLP为基础的由channel projections, spatial projections 和gating组成的网络结构。
实验显示:
- 在CV上,可以达到和vision transformers差不多的准确率;和MLP-Mixer相比,参数减少66%,准确率还提升了3%;
- 在NLP上,将gMLPs应用到BERT的MLM,和transformers一样,在预训练实时能最小化perplexity。同时,实验也显示,perplexity和模型规模有关,而对attention不敏感;
2.1 随着模型的capacity上升,gMLPs的预训练和finetuning指标会快速接近Transformers,这意味着,只要扩大模型规模,那么无需self-attention,gMLPs和Transformers的差距会不断缩小;
2.2 batch-size为256,进过1Mstep,gMLPs相比Bert,在MNLI达到了86.4%的准确率,在SQuAD达到了89.5%的F1;
2.3 在finetuning阶段,模型规模和perplexity接近的情况下, Transformers在cross-sentence alignment任务上比gMLPs效果好[MNLI任务高1.8%]。但是,当gMLPs的参数量是transformers的3倍时,模型效果就很接近;
2.4 同时,文中提出一个trick,在gMLPs后接一个single-head 128d 的attention,在NLP的各项任务上,就超过了transformers。
因此,本文觉得,提高数据量和算力,无需self-attention,gMLPs,就可以和transformers媲美。
Model
输入:序列长度为n,embedding维度为d:
使用L个block,每个block进行如下操作:
其中:
U,V为沿着channel[可理解为hidden维度]的线性投影,同Transformers的FFN;
为空间上的交互,用于获取tokens之间的关系。本文认为可以学习到位置信息,因此,没有使用positional embedding。
Spatial Gating Unit
为了实现token之间的交互,在层,就要包含一个空间维度的交叉操作。
文中主要介绍了2种SGU:
- 比较直观的,就是使用线性投影:
其中:
, n为序列长度;b可以是一个矩阵,也可以是一个常量。
空间交互通过element-wise实现:
为确保训练的稳定性,W初始化值接近于0, b为1。这相当于初始化的FFN,开始每个token相互独立,随着训练逐渐考虑token之间的交互信息。
- 除了线性投影的gate, 文中还将Z沿着channel分解成,借鉴GLUs的思路:
源代码分析
class SpatialGatingUnit(nn.Module):
def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
"""dim: embedding size
dim_seq: sequence length """
super().__init__()
dim_out = dim // 2
self.causal = causal
self.norm = nn.LayerNorm(dim_out)
self.proj = nn.Conv1d(dim_seq, dim_seq, 1)
# 常规卷积,卷积的是词向量的维度。本文是空间上的信息交互,因此输入/输出通道是序列长度,卷积核尺寸为1。
self.act = act
init_eps /= dim_seq
nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
nn.init.constant_(self.proj.bias, 1.)
def forward(self, x, gate_res = None):
device, n = x.device, x.shape[1]
res, gate = x.chunk(2, dim = -1) #沿着词向量维度,分成2个矩阵。
gate = self.norm(gate)
weight, bias = self.proj.weight, self.proj.bias
if self.causal:
weight, bias = weight[:n, :n], bias[:n]
mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
weight = weight.masked_fill(mask[..., None], 0.)
gate = F.conv1d(gate, weight, bias)
if exists(gate_res):
gate = gate + gate_res
return self.act(gate) * res
GLUs(Gated linear units)补充:
由Language model with gated convolutional network提出,使用CNN学习长文本,为缓解梯度消散,并保留非线性能力,使用门控机制。即:
没有经过非线性转换的卷积层输出*经过非线性转换的卷积层输出
其中:
:element-wise product
注意,GLUs是沿着channel维度[per token]的处理,而SGU是沿着空间维度[cross-token]的处理。
Image Classification
在图片分类ImageNet数据集上,无需添加外部数据,训练gMLPs。
模型配置如下,输入和输出沿用的ViT(Vision Transformer)格式,模型的深度和宽度配置也和ViT/DeiT模型相似。
结果:和Transformer一样,gMLPs在训练集上过拟合,因此采用了DeiT的正则化处理(mixup, cutmix);同时,对模型的维度做了调整。
Masked Language Modeling with BERT
DepthWise convolution补充
一个卷积核负责一个通道,卷积核数量要和图片通道数相同。
好比一个宽的depthwise convolution,接收整个句子的信息。但是depthwise convolution面向的是通道的filter,而gMLPs只使用一个W共享交叉通道。
在NLP上,gMLPs进行了多个ablation实验。
1. Ablation:the importance of gating in gMLP for BERT's Pretraining
- 使用Bert的absolute position embeddings;
- Bert框架 + T5-stype的relative position biases;
- 同1,2,但只保留relative positional biases,去掉content-dependent terms inside the softmax。
困惑度:交叉熵的指数形式。语言模型越好,句子概率越大,熵越小,困惑度越低。
使用SGU可以让gMLPs得到与Bert差不多的perplexity。
2. Case Study: The Behavior of gMLPs as Model Size Increases
Transformer中的6+6:self-attention使用6层,FFN使用6层。
finetuning任务用GLUE表示模型效果。
结果显示:
- gMLPs越深,pretraining perplexity越小,和transformer的模型效果越逼近;
- pretraining的perplexity越小,不意味着finetuning结果越好,比如gMLPs的perplexity比transformer小的时候,在SST-2的模型结果更好,但是MNLI-m的模型结果更差;
3. Ablation: The Usefulness of Tiny Attention in BERT's Finetuning
文中还做了个测试,在一些下游任务上,主要是设计到句子对的任务上,gMLPs表现比Transformers差。 那就再加一个tiny attention,来加强模型对cross-sentence alignment的学习。
这种混个gMLPs和attention的模型,称为aMLPs。结果显示,aMLPs的效果比gMLPs和transformer都要好。
4.Main Results for MLM in the BERT Setup
- 以SQuADv2.0任务为例,base模型,Bert模型的f1达到了78.6,gMLPs只有70.1, 差距8.5%;到了large模型,差距只有81.0-78.3=2.7;
- aMLPs使用128d的attention size,在SQuADv2.0任务,比Bert还要高4.4%的F1.
前面做的几个实验的总结:
- 在finetuning阶段,gMLPs不如transformer,但是,随着模型变大,和transformer的差距会不断缩小;
- aMLPs 不同的attention size(64,128),足够使得模型效果优于其他2个。