1 、什么是知识?
通常认为,知识是模型学习到的参数(比如卷积的权重)
2 、什么是蒸馏?
将知识从大模型(教师模型)转移到更适合部署的小模型(学生模型)
Distilling the knowledge in a Neural Network
知识蒸馏主要思想:
Student Model 学生模型模仿 Teacher Model 教师模型,二者相互竞争,直到学生模型可以与教师模型持平甚至卓越的表现;(使用的数据是相同的)
知识蒸馏的算法:
主要由:1)知识 Knwoledge,2)蒸馏算法 Distillate,3)师生架构组成(见上图)
通过上图可知Knowledge 知识的形式主要有三种:
1 、 Response-Based Knowledge
主要指Teacher-Model 教师模型的最后一层———输出层的特征。
主要思想是让 Student Model 学生模型直接学习教师模型的预测结果(Knowledge)。
最简单有效的模型压缩方法,即:老师学习好,把结论直接告诉学生就 OK
假设张量为教师模型输出,张量
为学生模型输出,Response-Based Knowledge的蒸馏形式可以被描述为:
(
,
) =
(
,
)
通过学习流程图,可知相同的数据,有个 Teacher模型(上图红色代表老师)和 Student 模型(绿色代表学生),会把 Teacher 输出的特征给到 Student 模型去学习,会拿出最后一层的特征,然后通过Distialltion Loss,让学生最后一层的特征去学习老师输出的特征。老师的输出特征应该是比较固定的而学生是不太固定的需要去学习的,于是通过一个损失函数去模拟、去减少、去学习是的两个的 Logits 越小越好
2、Feature-Based Knowledge
深度神经网络善于学习到不同层级的表征,因此中间层和输出层的都可以被用做知识来训练学生模型,中间层学习知识的Feature-Based Knowledge 对于Response-Based Knowledge 是一个很好的补充,其主要思想是将教师和学生的特征激活进行关联起来,Feature-Based Knowledge 的知识转移的蒸馏损失可表示为:
(
(x),
(x)) =
(
(
(x)),
(
(x)))
通过学习流程图可知:Distillation Loss 是建立在 Teacher Model 和 Student Model 的中间层,通过中间层去建立连接关系。
这种算法的好处 Teacher 网络可以为 Student 网络提供大量的、有用的参考信息。但如何有效的从教师模型中选择提示层,从学生模型中选择引导层,仍有待进一步研究。
缺点:由于提示层和引导层大小存在明显差异,如何正确匹配教师和学生的特征也需要探讨。
3、Relation-Based Knowledge
基于Feature-Based Knowledge 和Response-Based Knowledge 中都使用了教师模型中特定层中特征的输出。基于关系的知识进一步探索了不同层或数据样本之间的关系。一般情况下,基于特征图关系的关系知识的蒸馏损失可以表示为:
(
,
) =
(
(
,
),
(
,
))
通过上图可知:相同的数据,有个 Teacher模型和 Student 模型,Distillation Loss 就不仅仅是学习网络模型中间的特征还有最后一层的特征信息,它还会学习数据样本和网络模型层之间的关系
Knowledge Distillation: A Survey
知识蒸馏可以划分为:1)Offline Distillation 2)Online Distillation 3)Self-Distillation
红色代表:预训练的模型
黄色代表:将要去训练的模型
1)Offline Distillation 通俗讲:指知识渊博的教师向学生传授知识
大多数蒸馏采用Offline Distillation,蒸馏过程被分为两个阶段
1.1、蒸馏前 Teacher 模型预训练
1.2 、蒸馏算法迁移知识
因此Offline Distillation主要侧重于知识迁移部分
通常采用单向知识转移和两阶段训练过程。在步骤 1 中需要 Teacher 模型参数量比较大,训练时间比较长,这种方式对学生模型的蒸馏比较高效。
Tips:这种训练模式下的学生模型往往过度依赖于教师模型
2)Online Distillation 通俗讲:指教师和学生共同学习知识
主要针对参数量大、精度性能好的教师模型不可获得的情况。教师模型和学生模型同时更新,整个知识蒸馏算法是一种有效的端到端可训练方案(教师模型和学生模型一起去学习)
Tips:现有的Online Distillation 往往难以获得在线环境下参数量大、精度性能好的教师模型
3)Self-Distillation 通俗讲:指学生自己学习知识
教师模型和学生模型使用相同的网络结构(自学习),同样采用端到端可训练方案,属于Online Distillation 的一种特例
知识蒸馏的过程
分成 5小步:
1 、把数据喂养到教师网络去训练,通过升温的 Softmax(T=t),得到 soft targets1
2 、把数据喂养到学生网络去训练,通过升温的 Softmax(T=t),得到 soft targets2(与步骤1是同温的)
3、通过 1 、 2 两步之后有两个结果,对这两个结果来运用一下,算一下就可以得到蒸馏损失 Distillation loss
4 、同样把数据喂养到学生网络去训练,通过正常的(未升温) Softmax(T=1),得到 soft targets3
5、通过soft targets3和Ground Truth Label(正确标签) ,这两个值再计算一下,就可以的得到一个学生损失(Student loss)
整个过程会涉及两个损失
蒸馏损失distillation loss 和 学生损失 student loss,这两个损失
蒸馏损失
输入:相同温度下,学生模型和教师模型的 soft targets
常用:KL 散度
作用:让学生网络的类别输出预测分布尽可能拟合教师网络输出预测分布(通俗讲:让学生去学老师的一些行为)
学生损失
输入:T=1 时,学生模型的 soft targets 和正确标签
常用:交叉熵损失
作用:减少教师网络中的错误信息被蒸馏到学生网络中
蒸馏损失和学生损失,两个损失函数是独立的,怎么建立联系?以及知识蒸馏整个过程的关键点有哪些?下一篇将详细介绍经典的知识蒸馏算法