知识蒸馏-简单

参考文献:

https://github.com/DA-southampton/NLP_ability/blob/master/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86/%E6%A8%A1%E5%9E%8B%E8%92%B8%E9%A6%8F/bert2textcnn%E6%A8%A1%E5%9E%8B%E8%92%B8%E9%A6%8F.md

https://zhuanlan.zhihu.com/p/82129871

这里的蒸馏其实有点意思,什么是蒸馏,这词起的很有意思,比如如何从海水中炼得海盐。就是搞个火炉,去把海水蒸发掉,留下来的就是最精华的海盐。

那知识蒸馏就是把大模型的精华蒸馏出来,给小模型用。详细可以参考tinybert的实现。

这里简单讲讲如何将bert的精华蒸馏到textcnn等速度较快精度也较高的模型中。

简单来讲,就是同一个语句,输入bert得到一个logit,然后输入textcnn得到一个logit,两个logit之间做mse损失就可以了。其实就是想让我的textcnn学习到比较精准logit的结果,而不是简单的0-1,因为logit里面其实有很多的隐含知识,并不是最后简单的label信息。至于损失函数,可以如下:



对应再加上一些数据增强的措施(类bert操作),增强数据,防止过拟合,如:

Masking 使用[mask]标签来随机替换一个单词,例如“I love the comedy"替换为” I [mask] the comedy"。

POS-guided word replacement 将一个单词替换为相同POS标签的随机单词。例如,“What do pigs eat?"替换为"How do pigs eat?"。

n-gram sampling 随机采用n-gram,n从1到5,并丢弃其它单词。


具体的训练猜测:

同时训练两个模型,一种是bert的fineturn,一种是textcnn的学习,训练阶段的时候,可以同时获取两个模型的logit,然后计算mse损失,然后进行回传。这里的回传为了保证bert不受影响,可以不用回传到bert那侧,bert那边正常fineturn即可。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容