扫Switch Transformer论文中,作者提到了Product Key Networks。出自2019年末Facebook的工作
论文题目:《Large Memory Layers with Product Keys》
论文地址:https://arxiv.org/abs/1907.05242
摘要:
介绍了一种结构化的容易嵌入到神经网络的memory模块。这种memory能够显著提升模型容量,在数十亿参数的增加基础上同时拥有可忽略的计算支出。memory模块基于product key的模式,能够进行快速和准确的最近邻搜索,从而使得系统在训练和推理的时候在预测准确率和计算消耗的trade-off上表现更好。实验中,在30亿单词的数据集上,将memory layer注入sota模型中,12层transfomer效果优于24层,并且在推理性能2X。
总的来看,是对于大规模memory network的探索,在提升效果的同时拥有良好的性能指标。
背景介绍:
依旧是经典的CV&NLP领域中,大规模模型带来高收益和泛化表现同时饱受高计算复杂度的困扰。因此对于可接受的计算复杂度的研究也很多,比如“On-device Visual Intelligence Challenge”,关注图像分类的complexity/accuracy trade-off。
一些研究尝试增加模型容量而不带来计算复杂度增加。比如Rae等人工作提出在神经网路结构的快速最近邻搜索从而利用大量的稀疏读写操作的key-value层。但是他们的方法依赖于外在的索引结构,这是近似的并且在训练中需要时常更新从而避免catastrophic drift。
不同于先前工作,收到乘积量化的启发,本文的key-value memory layer定义了一种两层sub-keys结合。具体区别如下图,传统k-v网络中,通过query network生成的query vector寻找最近邻的k个keys。product key网络将key分为两组sub-key的组合,这样在每一组sub-key中只需要O()次比较。并且memory参数更新中,在每个batch只需要更新很少量的memory slots,这种稀疏的模式使得训练和推理都十分高效。
三点贡献:
1.一种大参数量和轻量级计算增加的layer
2.检索方法不需要在训练中额外re-learned
3.1层memory + 12层transformer效果优于24层transformer,并且推理性能2X
相关工作介绍:
条件计算:对于每个输入只分发部分的连接和层处理
记忆增强网络:一种表征复杂问题的任意长度输入的方法,这些memorys在特征空间进行不同的读写操作。但是计算量随着memory大小线性增长,因此在推理上会采用近似的lookup技术。
离散化技术:低维二分类编码&局部敏感哈希 缺点:在训练时候,近似索引在高维空间中往往不够准确
借助product quantization思想,不是直接建立近似索引,通过更少的反传学习的向量来表征大量的key vectors。
借鉴无监督学习中的系数模型思想,比如k-sparse自编码器在隐空间中只保留topk value。winner take all自编码器将memory分为稀疏子集,从而根据数据结构进行content-based读操作。
与self-attention层对比:key value和input无关&value的个数很大
结构:
high-level结构上来说:query network + key选择模块 + value lookup table
流程:query network生成query tensor,选取topk最接近的product keys。通过value table加权求和得到最后的输出,因此每个输入只更新k个memory slots。
query生成:x->q(x)。q是线性映射或者多层感知机,将维度减小到512。这里作者提出一个tips,在query network前增加BN能够帮助负载均衡(后续实验也会提到)
标准的key分配和加权:
如公式所示,很简单:
复杂度为内积操作。
product key设置:
两个vector codebooks c1和c2,因此总共keys组合为=。c1和c2的维度为dq/2,将输入的query tensor分成两个子queries q1和q2,然后计算k sub-keys。
复杂度:
对头记忆机制:
每一头拥有独立的query network和sub-keys,所有头共享相同的values。类比于multi-head attention,但是没有将query切分到每个头中,而是创造不同的query tensor。实际观察中,不同头的key选择很少有重复,因此这个方法能够提升key的利用率并且提升最终表现。
实验:
说了这么多,实验才是王道,看看效果吧。
如图用memory代替一些transformer层中的FFN,同时加入残差连接。用x<-x + PKM(x)代替x<-x + FFN(x)。实际中,也可以直接在transformer中插入PKM,而不是替代FFN。
数据集:
最大公共的语言模型数据集是One Billion Word。但是对于标准的模型结构来说还是太小,因此需要乏味的正则化。用16层维度为1024的transformer结构,可以观察到明显的过拟合。因此决定在30倍大小的Common Crawl(CC)数据集上验证。训练集有280亿单词,140GB,来自于4000万英文新闻文章。验证和测试集是从训练集中划出的5000篇新闻。
细节:没有shuffle sentence,模型学习长范围的依赖性;Moses toolkit分词;Byte Pair Encoding减少vocabulary size
评估指标:
评估usage:memory usage&均匀分布的KL散度
细节:
稀疏更新memory values,较大的学习率e-3。
H=4 memory heads,key=32per head,512**2 memory slots。
结果:
如上图所示,加入1层memory layer之后,12层layers效果优于24层,并且2-3层memory layer效果更佳。bert-base加入memory layer之后效果媲美bert-large,并且推理速度为2倍。这里有一点需要提一下,正常情况下bert-base性能是bert-large的两倍以上,也只能说加入了memory layer之后,没有增加太多的计算量。
消融实验:
6层transformer 8头,512**2slots & 4 memory slots & 32 selected keys & layer 5 注入
·memory size变大,效果更好。同时在推理阶段,memory size并不影响性能
·加入BN层之后,usage显著提高
·模型效果与利用的keys数目相关
·插入中间层效果更好,memory需要在更抽象的特征空间中运算,并且需要顶部有一些处理和收集信息的层
·不同头数和k的影响,总的来说越大效果越好
·product keys vs. flat keys. product keys相比于full match少了倍的参数量,在效果和性能方面均优于flat keys。
结论:
两个关键因素:keys的因式分解为product set & 稀疏的read/write memory values更新
这篇文章是memory网络的一次性能上的探索,最近看的FB文章都很注重模型的性能,偏重于实践工程。之前也有文章分析过,在FFN结构中,两层的linear本质就是key-value结构,本文能取得效果提升也是可以预见的。接下来会分析一些transformer内部结构的文章,希望能够对于MOE工作带来一些指导与启发。