知识蒸馏(Knowledge Distilling),让你的模型轻装上阵——keras 实战

深度学习在这两年的发展可谓是突飞猛进,为了提升模型性能,模型的参数量变得越来越多,模型自身也变得越来越大。在图像领域中基于Resnet的卷积神经网络模型,不断延伸着网络深度。而在自然语言处理领域(NLP)领域,BERT,GPT等超大模型的诞生也紧随其后。这些巨型模型在准确性上大部分时候都吊打其他一众小参数量模型,可是它们在部署阶段,往往需要占用巨大内存资源,同时运行起来也极其耗时,这与工业界对模型吃资源少,低延时的要求完全背道而驰。所以很多在学术界呼风唤雨的强大模型在企业的运用过程中却没有那么顺风顺水。

知识蒸馏

为解决上述问题,我们需要将参数量巨大的模型,压缩成小参数量模型,这样就可以在不失精度的情况下,使得模型占用资源少,运行快,所以如何将这些大模型压缩,同时保持住顶尖的准确率,成了学术界一个专门的研究领域。2015年Geoffrey Hinton 发表的Distilling the Knowledge in a Neural Network的论文中提出了知识蒸馏技术,就是为了解决模型压而生的。至于文章的细节这里笔者不做过多介绍,想了解的同学们可以点击上方链接好好研读原文。不过这篇文章的主要思想就如下方图片所示:用一个老师模型(大参数模型)去教一个学生模型(小参数模型),在实做上就是用让学生模型去学习已经在目标数据集上训练过的老师模型。尽管学生模型最终依然达不到老师模型的准确性,但是被老师教过的学生模型会比自己单独训练的学生模型更加强大

这里大家可能会产生疑惑,为什么让学生模型去学习目标数据集会比被老师模型教出来的差。产生这种结果可能原因是因为老师模型的输出提供了比目标数据集更加丰富的信息,如下图所示,老师模型的输出,不仅提供了输入图片上的数字是数字1的信息,而且还附带着数字1和数字7和9比较像等额外信息。

知识蒸馏

知识蒸馏具体流程

接下来笔者介绍一下知识蒸馏在实做上的具体流程。

  • (1)定义一个参数量较大(强大的)的老师模型,和一个参数量较小(弱小的)的学生模型,
  • (2)让老师模型在目标数据集上训练到最佳,
  • (3)将目标数据的label替换成老师模型最后一个全连接层的输出,让学生模型学习老师模型的输出,希望学生模型的输出和老师模型输出之间的交叉熵越小越好。

了解到知识蒸馏的具体步骤之后,我们采用keras在mnist数据集上进行一次简单的实验。

知识蒸馏实战

导入一下必要的python 包,同时载入数据。

from keras.datasets import mnist
from keras.layers import *
from keras import Model
from sklearn.metrics import accuracy_score
import numpy as np
(data_train,label_train),(data_test,label_test )= mnist.load_data()
data_train = np.expand_dims(data_train,axis=3)
data_test = np.expand_dims(data_test,axis=3)
定义老师模型和学生模型

在下方代码中,笔者定义了一个包含3层卷积层的CNN模型作为老师模型(参数量6万),定义了一个包含512个神经元的全连接层作为学生模型(参数量4万,比老师模型少了2万)。

#####定义老师模型——包含三层卷积层的CNN模型
def teacher_model():
    input_ = Input(shape=(28,28,1))
    x = Conv2D(32,(3,3),padding = "same")(input_)
    x = Activation("relu")(x)
    print(x)
    x = MaxPool2D((2,2))(x)
    x = Conv2D(64,(3,3),padding= "same")(x)
    x = Activation("relu")(x)
    x = MaxPool2D((2,2))(x)
    x = Conv2D(64,(3,3),padding= "same")(x)
    x = Activation("relu")(x)
    x = MaxPool2D((2,2))(x)
    x = Flatten()(x)
    out = Dense(10,activation = "softmax")(x)
    model = Model(inputs=input_,outputs=out)
    model.compile(loss="sparse_categorical_crossentropy",
                 optimizer="adam",
                 metrics=["accuracy"])
    model.summary()
    return model

###定义学生模型——— 一层含512个神经元的全连接层
def student_model():
    input_ = Input(shape=(28,28,1))
    x = Flatten()(input_)
    x = Dense(512,activation="sigmoid")(x)
    out = Dense(10,activation = "softmax")(x)
    model = Model(inputs=input_,outputs=out)
    model.compile(loss="sparse_categorical_crossentropy",
                 optimizer="adam",
                 metrics=["accuracy"])
    model.summary()
    return model
训练老师模型

接下来开始训练老师模型,由于mnist数据集较为简单,在三层的CNN模型上,我设定只训练2个epoch。这里需要注意的是,如下图所示:三层卷积的CNN的有6万多个参数

t_model  = teacher_model()
t_model.fit(data_train,label_train,batch_size=64,epochs=2,validation_data=(data_test,label_test))
teacher model

训练结果如下图所示:两个epoch,CNN模型就在测试集上做到了98%的准确性。


teacher result
训练学生模型

在512个神经元的全连接层上训练mnist数据集,学生模型的参数量如下图所示:参数量只有4万个,参数量比老师模型少了2万个

s_model = student_model()
s_model.fit(data_train,label_train,batch_size=64,epochs=10,validation_data=(data_test,label_test))
student model

在学生模型上训练了10个epoch之后,测试机准确率最高也才达到0.9460,远低于CNN老师模型的0.98


student result
老师模型教学生模型

最后我们用老师模型教学生模型,进行知识蒸馏。
首先我们采用下方代码将目标数据集的label替换成老师模型的输出。

t_out = t_model.predict(data_train)

然后用学生模型去学习老师模型的输出。

def teach_student(teacher_out, student_model,data_train,data_test,label_test):
    t_out = teacher_out

    s_model = student_model
    for l in s_model.layers:
        l.trainable = True     
    
    label_test = keras.utils.to_categorical(label_test)
    
    model = Model(s_model.input,s_model.output)
    model.compile(loss="categorical_crossentropy",
                 optimizer="adam")
    model.fit(data_train,t_out,batch_size= 64,epochs = 5)
    
    s_predict = np.argmax(model.predict(data_test),axis=1)
    s_label =  np.argmax(label_test,axis=1)
    print(accuracy_score(s_predict,s_label))

最终得到的实验结果如下图所示:学生模型的性能提升到了0.9511,相比于学生模型在目标数据集上的最好成绩0.9460提升了千分之6个点。这也证明我们知识蒸馏确实起作用了。


result of student model after being taught

结语

当然我们也发现,我们的实验提升的幅度并不大,离老师模型的准确度还有巨大的差距,而要想优化知识蒸馏的性能,我们可以采取升温技术,升温技术的原理图如下图所示:将老师模型的输出在softmax激活函数之前初上一个数值大于1的数字T,这样会使得老师模型输出的个类别概率值变得较为接近。

升温技术

确实升温技术的主要目的就是将老师模型输出的各类型的概率,变得较为接近,这样老师模型的输出信息将变得更加丰富,得学生模型学会分辨出个类别之间细微的区别。当然知识蒸馏的优化方法并不只上述的升温技术这一种,这里笔者只是抛砖引玉,知识蒸馏还有更多的奥秘等着大家去探索,去学习。希望读者能够有所收获的同时,心中的好奇心也能够被激发,主动的学习知识蒸馏这门技术。

参考

https://arxiv.org/pdf/1503.02531.pdf
https://github.com/johnkorn/distillation
https://www.bilibili.com/video/av46561029/?p=54

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,332评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,508评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,812评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,607评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,728评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,919评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,071评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,802评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,256评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,576评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,712评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,389评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,032评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,026评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,473评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,606评论 2 350

推荐阅读更多精彩内容