Approach
训练和部署阶段采用不同的模型形态,可以类比于自然界中很多昆虫有多种形态以适应不同阶段的需求。具体地,如蝴蝶在幼虫以蛹的形式存储能量和营养来更好的发育,但是到了后期就为了更好的繁殖和移动它就呈现了另外一种完全不一样的形态。这种方法可以称为知识蒸馏,又叫孪生网络。
所谓模型蒸馏就是将训练好的复杂模型推广能力“知识”迁移到一个结构更为简单的网络中。或者通过简单的网络去学习复杂模型中“知识”。大致可以分为两个阶段:
- 原始模型训练:
- 根据提出的目标问题,设计一个或多个复杂网络(N1,N2,…,Nt)。
- 收集足够的训练数据,按照常规CNN模型训练流程,并行的训练1中的多个网络得到。得到(M1,M2,…,Mt)
- 精简模型训练:
- 根据(N1,N2,…,Nt)设计一个简单网络N0。
- 收集简单模型训练数据,此处的训练数据可以是训练原始网络的有标签数据,也可以是额外的无标签数据。
- 将2中收集到的样本输入原始模型(M1,M2,…,Mt),修改原始模型softmax层中温度参数T为一个较大值如T=20。每一个样本在每个原始模型可以得到其最终的分类概率向量,选取其中概率至最大即为该模型对于当前样本的判定结果。对于t个原始模型就可以t概率向量。然后对t概率向量求取均值作为当前样本最后的概率输出向量,记为soft_target,保存。
- 标签融合2中收集到的数据定义为hard_target,有标签数据的hard_target取值为其标签值1,无标签数据hard_taret取值为0。Target = ahard_target + bsoft_target(a+b=1)。Target最终作为训练数据的标签去训练精简模型。参数a,b是用于控制标签融合权重的。
- 设置精简模型softmax层温度参数与原始复杂模型产生Soft-target时所采用的温度,按照常规模型训练精简网络模型。
- 部署时将精简模型中的softmax温度参数重置为1,即采用最原始的softmax
Experiment
- Mnist
- Speech Recognition
References:
https://zhuanlan.zhihu.com/p/24337627