阅读笔记 - The Devil in Linear Transformer

来源:https://www.researchgate.net/publication/364419868_The_Devil_in_Linear_Transformer
代码:https://github.com/OpenNLPLab/Transnormer


这篇文章的目的是优化线性transformer,线性transformer相对于标准transformer能够将计算复杂度从 O(N^2C) 降到O(NC^2). 但线性transformer 相对于标准transformer 往往存在着较明显的指标gap。作者分析认为原因有两点:

  • unbounded gradients。无边界梯度,会导致模型在训练时不稳定,收敛不好;
  • attention dilution。注意力稀释,transformer在lower level时应该更关注局部特征,而higher level更关注全局特征,但线性transformer中的attention往往weight 更均匀化,不能聚焦在local区域上,因此称为attention稀释。
    针对于上述两点,作者提出了NormAttention和DiagAttention两个模块,形成NormFormer的结构。

1.The devil in linear attention

我们首先来看一下作者分析的线性transformer存在的两点缺陷的结论是怎么来的。

1.1 Unbounded gradients

在标准的attention结构中
O = \text{softmax}(QK^T/\sqrt{D})V, ~~ Q=XW_Q, K=XW_K, V=XW_V
正是这里的QK^T 带来的O(N^C)的计算复杂度。而为了解决这个问题目前主要包含两类: 基于pattern的方法和基于kernel的方法。
基于pattern的方式主要是通过一些先验筛选key或query,降低计算复杂度;而基于kernel的方法则是本文提到的线性transformer,通过核函数去取代softmax,从而能够通过矩阵乘法结合律降低计算复杂度。
那么来看一下计算attention时,vanilla和linear transformer的统一形式:
p_{ij} = \frac{f(s_{ij})}{\sum_{k-1}^n f(s_{ik})}
对于vanilla transformer而言, s_{ij} = q_i^Tk_j/\sqrt{d}, ~~ f(x) = \text{exp}(x), 对于linear transformer可以表示为 s_{ij} = \phi(q_i)\phi(k_j)^T,~~f(x)=x. 于是可以比较一下两者的梯度:
vanilla attention: \frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik})}{f(s_{ik})}\big(1_{j=k}p_{ij} - p_{ij}p_{ik}\big), 这里推理的时候注意凑p_{ij}, p_{ik}
f'(x) = \text{exp}(x) = f(x) \\ \frac{\partial p_{ij}}{\partial s_{ik}} = 1_{j=k}p_{ij} - p_{ij}p_{ik} \\ = \begin{cases} p_{ik} - p_{ij}p_{ik}\in [0, 1/4], &j=k \\ - p_{ij}p_{ik}\in [-1/4, 0],& j\neq k\end{cases}
这里推理的时候只有p_{ik} = p_{ij} 时边界值成立,所以最终
\Big \vert \frac{\partial p_{ij}}{\partial s_{ik}}\Big\vert \le \frac{1}{4}

linear attention: 线性attention的关键在于f'(x) = 1, 因此
f'(x) =1 \\ \frac{\partial p_{ij}}{\partial s_{ik}} = \frac{1}{s_{ik}} \big(1_{j=k}p_{ij} - p_{ij}p_{ik}\big) \\ = \frac{1}{s_{ik}}\begin{cases} p_{ik} - p_{ij}p_{ik}, &j=k \\ - p_{ij}p_{ik},& j\neq k\end{cases} 即,\Big \vert \frac{\partial p_{ij}}{\partial s_{ik}}\Big\vert \le \frac{1}{4|s_{ik}|}.
因为s_{ik} = \phi(q_i)\phi(q_k)^T 大小是不确定的,所以相当于linear attention的梯度是无边界的。这就会导致收敛不稳定,收敛难度大等问题。

1.2 Attention dilution

注意力稀释方面,作者直接评估了不同level上,每一个query在邻域内的其他query上的attention的权重占比,这里需要注意的是,query之间是有序的,即对于NLP或者featmap而言,是有固定结构的,才可以这么评估。l(i, r, N)表示第i个query在其rN个邻域query上的attention之和,可以看下图,a图中transformer和linear transformer相比,显然linear transformer的聚集度要小很多。这就是所谓的注意力稀释。

image.png

2. architecture

针对于1中的两个问题,有针对性的设计了两个模块。

2.1 NormAttention.

作者提出的解决方案
O = Q(K^TV) \\ O_{norm} = \text{XNorm}(Q(K^TV)),
这里的XNorm 可以是Layernorm,也可以是 RMSNorm。注意这里的Q,和K是有激活函数的,公式没写,但图中画了。
\text{RMSNorm}(x) = \frac{x}{\sqrt{\sigma^2 + \epsilon}} \\ \sigma^2 = \sum_{i=1}^d x_i^2 /d , \epsilon > 0,
文章证明这个做法梯度是有上界的。附录的证明过程有点复杂。

2.2 DiagAttention

这个模块其实就是一种基于pattern的attention,将query按距离划分不重叠的window,每个window内进行 attention的计算。奇怪的是 这里的attention使用的都是vanilla attention。

下图是文章方法TransNormer的结构:


image.png

3. 实验

实验都是在NLP上做的,不大了解,因此不做分析,这里只看下消融实验的结论。

image.png

table8. 表明早期的stage应当更关注局部特征,而后期的stage则应该更关注全局信息。
table9. 早期适合使用blockattn,后期适合使用normattn
table10. FFN中作者对比了FFN和GLU的结果,发现GLU效果会更好一些。
image.png

table11.表明diagattn中的window的大小,这个其实有有点说不通,如果DiagAttn使用的linear attention, block size越大不是attention 稀释的越严重吗? 这个地方DiagAttn使用的应该都是vanilla attention,包括softmax attention和ReLA attention.

4. 结论

本文提出的norm attention其实在很多其他方法中都见过,而且所谓的diag attention使用的还是vanilla attention,并没有把linear attention应用到diag block里,感觉不是很充实。值得学习的是本文中提出的梯度分析的方法。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容