0. XLNet简介
2018年Bert横空出世,刷新了很多NLP任务的SOTA。之后人们开始研究对Bert的改进,本文介绍的XLNet就是比较成功的另一个模型。不同于Bert的AutoEncoder模式,XLNet用的是AutoRegressive模式。据说,XLNet在20个任务上比BERT做得更好,的确吸人眼球。
Bert类(AE)模型的的不足之处在于,在训练的时候引入了[MASK],而在fine-tune阶段并不会出线[MASK],导致了预训练与finetune不一致。[MASK]的另一个问题是假设MASK掉的token是相互独立的,然而并不是这样,譬如New York。
XLNet提供了一种新的方法,让AR语言模型从双向的上下文学习,避免了AE语言模型中MASK带来的弊端。
1. Permutation Language Modeling
AR语言模型只能使用前向的上下文或后向的上下文,那么如何使用双向的上下文呢,XLNet提出了一个新的目标,叫做重排序语言建模(Permutation Language Modeling)。
理论上对于长度为T的序列X,存在T!中排列方式,但实际上由于计算复杂度的限制,不可能计算所有的序列排列,因此对于每个序列输入只采样一个排列方式。而且在实际训练时,不会打乱序列,而是通过mask矩阵
实现permutation。作者特意强调,这样可以保持与finetune输入顺序的一致,不会存在pretrain-finetune
差异。
2. Two-Stream Self-Attention
将序列X打乱顺序后有一个很大的问题,就是如何加入位置信息。在预测的时候,我们应该知道的位置编码,而不是上下文编码,同时还要知道之前的上下文编码。
Query stream:只能看到当前的位置信息,不能看到当前token的编码,如图b
Content stream:传统self-attention
,像GPT一样对当前token进行编码,如图a
在预训练阶段的最终预测只是用Query stream,在fine-tune阶段使用Content stream。
3. Transformer-XL
由于内存和算力的限制,Bert只能支持到512个字符,过长的文本都要截断,丢弃部分信息。Transformer-XL参考了RNN,将前边的信息以隐藏单元的形式记录下来。
从图b可以看出,右上角的token获取到了远远大于之前截断方式的信息,类似CNN的感受野。