从应用落地的角度来说,bert虽然效果好,但有一个短板就是预训练模型太大,预测时间在平均在300ms以上(一条数据),无法满足业务需求。知识蒸馏是在较低成本下有效提升预测速度的方法。最近在看知识蒸馏方面的内容,对《DistillBert》做个简单的介绍。
提纲
1. Bert后演化的趋势
2.知识蒸馏基本原理
3.《DistillBert》详解
4. 后话
一、Bert后演化的趋势
Bert后,语义表示的基本框架已确定,后续大多模型以提升精度、提升速度来做。基本以知识蒸馏、提升算力、多任务学习、网络结构优化四个方向来做。
如何提升速度?
invida发布transformer op,底层算子做fuse。
知识蒸馏,以distillBert和tinyBert为代表。
神经网络优化技巧。prune来裁剪多余的网络节点,混合精度(fp32和fp16混合来降低计算精度从而实现速度的提升)
如何提升精度?
增强算力。roberta
改进网络。xlnet,利用transformer-xl。
多任务学习(ensemble)。微软发布的mk-dnn
二、知识蒸馏的基本原理
知识蒸馏是从算法层面提速的有效方式,是趋势之一。知识蒸馏从hinton大神14年《Distilling the Knowledge in a Neural Network》这篇paper而来。
定义两个网络,一个teacher model,一个student model。teacher model是预训练出来的大模型,teacher model eval结果出来的softlabel作为student model学习的一部分。student model的学习目标由soft label和hard label组成。
其中有个核心的问题,为什么要用soft label呢?因为作者认为softlabel中包含有hard label中没有信息,也就是样本的概率信息,可以达到泛化的效果。
细节参考这篇博文:https://blog.csdn.net/nature553863/article/details/80568658
三、DistillBert
DistillBert的网络结构:
student model的网络结果与teacher model也就是bert的网络结构基本一致。主要包含如下改动:
每2层中去掉一层。。作者调研后结果是隐藏层维度的变化比层数的变化对计算性能的影响较小,所以只改变了层数,把计算层数减小到原来的一半。
去掉了token type embedding和pooler。
每一层加了初始化,每一层的初始化为teacher model的参数。
2. 三个损失函数:
(1)Lce损失函数
Lce损失函数为Teacher model的soft label的损失函数,Teacher model的logits ti/T(T 为温度),通过softmax计算输出得到teacher的概率分布,与student model logits si/T(T为温度),通过softmax计算输出得到student的概率分布,最后计算两个概率分布的KL散度。
(2)Lmlm损失函数
Lmlm损失函数为hard label的损失函数,是bert 的masked language model的损失函数。
(3)Lcos损失函数
计算teacher hidden state和student hidden state的余弦相似度。官方代码用的是:nn.CosineEmbeddingLoss。
整体计算公式为:
Loss= 5.0*Lce+2.0* Lmlm+1.0* Lcos
3. 参数配置
training阶段:计算8个卡,16GB,V100的GPU机器,90个小时
性能: DistilBERT 比Bert快71%,训练参数为207 MB 。
四、实验结果
DistillBert在GLUE数据集上的表现
下图为Ablation test的结果,可以看出Lce、Lcos、参数初始化为结果影响较大。
五、后话
知识蒸馏本质是什么? 个人理解,其实知识蒸馏实际相当于引入先验概率(prior knowledge), soft label即是网络输入的先验概率,soft label与真实世界的事物类似,呈各种概率分布。