来源:https://www.researchgate.net/publication/364419868_The_Devil_in_Linear_Transformer
代码:https://github.com/OpenNLPLab/Transnormer
这篇文章的目的是优化线性transformer,线性transformer相对于标准transformer能够将计算复杂度从 降到. 但线性transformer 相对于标准transformer 往往存在着较明显的指标gap。作者分析认为原因有两点:
- unbounded gradients。无边界梯度,会导致模型在训练时不稳定,收敛不好;
- attention dilution。注意力稀释,transformer在lower level时应该更关注局部特征,而higher level更关注全局特征,但线性transformer中的attention往往weight 更均匀化,不能聚焦在local区域上,因此称为attention稀释。
针对于上述两点,作者提出了NormAttention和DiagAttention两个模块,形成NormFormer的结构。
1.The devil in linear attention
我们首先来看一下作者分析的线性transformer存在的两点缺陷的结论是怎么来的。
1.1 Unbounded gradients
在标准的attention结构中
正是这里的 带来的的计算复杂度。而为了解决这个问题目前主要包含两类: 基于pattern的方法和基于kernel的方法。
基于pattern的方式主要是通过一些先验筛选key或query,降低计算复杂度;而基于kernel的方法则是本文提到的线性transformer,通过核函数去取代softmax,从而能够通过矩阵乘法结合律降低计算复杂度。
那么来看一下计算attention时,vanilla和linear transformer的统一形式:
对于vanilla transformer而言, , 对于linear transformer可以表示为 . 于是可以比较一下两者的梯度:
vanilla attention: , 这里推理的时候注意凑
这里推理的时候只有 时边界值成立,所以最终
linear attention: 线性attention的关键在于, 因此
即,.
因为 大小是不确定的,所以相当于linear attention的梯度是无边界的。这就会导致收敛不稳定,收敛难度大等问题。
1.2 Attention dilution
注意力稀释方面,作者直接评估了不同level上,每一个query在邻域内的其他query上的attention的权重占比,这里需要注意的是,query之间是有序的,即对于NLP或者featmap而言,是有固定结构的,才可以这么评估。表示第i个query在其个邻域query上的attention之和,可以看下图,a图中transformer和linear transformer相比,显然linear transformer的聚集度要小很多。这就是所谓的注意力稀释。
2. architecture
针对于1中的两个问题,有针对性的设计了两个模块。
2.1 NormAttention.
作者提出的解决方案
,
这里的XNorm 可以是Layernorm,也可以是 RMSNorm。注意这里的Q,和K是有激活函数的,公式没写,但图中画了。
文章证明这个做法梯度是有上界的。附录的证明过程有点复杂。
2.2 DiagAttention
这个模块其实就是一种基于pattern的attention,将query按距离划分不重叠的window,每个window内进行 attention的计算。奇怪的是 这里的attention使用的都是vanilla attention。
下图是文章方法TransNormer的结构:
3. 实验
实验都是在NLP上做的,不大了解,因此不做分析,这里只看下消融实验的结论。
table8. 表明早期的stage应当更关注局部特征,而后期的stage则应该更关注全局信息。
table9. 早期适合使用blockattn,后期适合使用normattn
table10. FFN中作者对比了FFN和GLU的结果,发现GLU效果会更好一些。
table11.表明diagattn中的window的大小,这个其实有有点说不通,如果DiagAttn使用的linear attention, block size越大不是attention 稀释的越严重吗? 这个地方DiagAttn使用的应该都是vanilla attention,包括softmax attention和ReLA attention.
4. 结论
本文提出的norm attention其实在很多其他方法中都见过,而且所谓的diag attention使用的还是vanilla attention,并没有把linear attention应用到diag block里,感觉不是很充实。值得学习的是本文中提出的梯度分析的方法。