Microsoft在2020年提出了TwinBERT: Distilling Knowledge to Twin-Structured Compressed BERT Models for Large-Scale Retrieval这篇论文。今天有幸看了看,简单的跟大家分享下。
解决问题
论文主要解决的问题是:性能~ 性能~ 性能~~~
Online Server需要快速处理,尤其是在召回阶段,面对上亿级Doc,为此减少在线计算大势所趋。
架构原理
TwinBert就是在这种背景下应运而生的,如下图结构:
主要讲下上面这张图:
- 整体:
- 两个对称的Bert, 左边的Bert用于Query建模,右边的Bert用于Title keyword建模(或者Doc Context keyword建模)。
- 两个Bert走完后,再各自经过一个Pooling Layer,池化层,听起来很高大上,其实很简单,主要是将序列中每个token的向量搞在一起,做成一个向量。 Query做成一个向量, keyword做成一个向量,以方便进行后面的Cross Layer的交互。 池化层有两个操作二选一,【用CLS】 或者 【所有tokens向量平均加权起来】,其中后者权重是学出来的。
- 输入 : 均为Word Embeding + Position Embeding。 因为两边都是一句话,所以就没有了Segment Embeding了。
值得提一下是,论文中是训练的英文的模型,所对输入进行了Word Hashing,具体说是使用了Tri-letter, 至于什么是Word Hashing ,见本人的另外一文章Word Hashing。
*Transformer Encoder
这里不多说,其中L用的是6层。
- 池化层
见整体部分,已说明。
*Cross Layer
Query做成一个向量q, keyword做成一个向量k,二者进行距离计算,有两种方式,一种是余弦相似度,如下图:
另一种是Residual network, 这里不多讲,有兴趣,自身翻阅。
如何训练?
蒸馏方法训练。
teacher model
所以要搞一个teacher model,文章用的12层的 query和title关键词的训练的。二分类,分为相关和不相关。最后输出一个概率。student model
有了teacher model, 现在就开始teach学生把,将上面讲的Cross layer做的输出通过LR压缩到区间(0,1), 因为余弦的值域是[-1,1].
然后做一个做交叉熵 cross entropy。如下面公式:
优点
节省性能,Query在线用Bert预测, Doc提前离线算好刷到索引。在线只需要做一次Query Bert预测,以及与Doc的向量计算。