原文地址:《Channel-wise Knowledge Distillation for Dense Prediction》
代码地址: https://git.io/Distille(由原文提供,好像打不开了)
该文发表在ICCV2021上。文章针对密集性预测任务(dense prediction)提出一种简单而有效的蒸馏方式,之前的知识蒸馏方式之前对于密集预测任务的蒸馏方法都是通过在空间维度上对齐老师和学生网络的activation maps,文章认为直接在空间维度上蒸馏可能会将老师网络中的多余信息带入学生网络,所以文章采用通道维度上进行蒸馏。
这里在说明一下dense prediction任务,dense prediction是一种将输入图片映射为复杂输出的一类任务,例如语义分割、深度估计、物体检测等。(参考文献《Structured Knowledge Distillation for Dense Prediction》)
一、KL散度
在解释文章思想之前,想先介绍一下KL散度。
根据维基百科中的解释,KL散度是用来度量使用基于Q的分布来编码服从P的分布的样本所需的额外的平均比特数。通俗来说KL散度就是用来给出P分布和Q分布之间的差异值。
对于离散随机变量,P和Q两个分布的KL散度可以用下列公式表示:
还要注意一点的是KL散度是不对称的,所以存在公式
二、方法介绍
2.1 空间维度的蒸馏(spatial distillatioin)
在介绍本文的蒸馏方法前,先介绍一下空间维度的蒸馏方式。
空间维度的蒸馏方式可以用下式表示:
上式中表示预测值与gt之间的loss,那语义分割来说通常使用交叉熵loss。和表示logits输出或者网络中的激活层,T和S上角标分别表示老师模型和学生模型。 为平衡loss的超参数。表示loss计算函数,表示对输入的预测或者进行变换,不同算法这两个符号计算方式不一样,具体如下表所示。
从上述公式和表中可以看出,该类方法是针对输出所有通道作为空间上的一点来计算的。
2.2 通道维度的蒸馏
如下图c所示,网络输出特征不同的通道关注的重点是不一样的。(a,b表示空间维度和通道维度的蒸馏方式)
为了更好的对网络输出的每个通道进行知识提取,在计算学生和老师网络之间差异之前,先在通道维度上,将输出的激活值转换成一个概率分布图,然后再去计算老师和学生之间的差异。
先进行一些符号的定义,老师网络用T表示,学生网络用S表示,表示老师输出的特征,表示学生输出的特征。
通道维度的蒸馏loss也可以用通用的公式形式表示
在文中提出的方法,表示将激活通道转换为概率分布:
其中表示通道的索引,i表示每个通道上的空间位置索引。表示超参,该值越大,表示概率越分散(softer),也即关注的空间位置越大。
当存在老师的输出通道数与学生输出通道数不一致的情况,这里使用的卷积对学生的输出进行处理,使其与老师的输出通道数一致。
上面是用来判断老师的输出通道分布与学生的输出通道分布的差异,这里采用KL散度来计算
文字基本原理就是这些,具体实验可以查看原文,但有一点说明的是,文章中在做feature的蒸馏,好像并没有说明用的是哪些层输出的feature。