一 写在前面
未经允许,不得转载,谢谢~~~
这篇文章属于knowledge distillation,但是与之前Hiton大佬提出的从复杂模型迁移到小模型在整体的思路上有很大的不同,一个是从model的角度,一个是从dataset的角度,观点挺新颖的。
放上原文链接及最早提出知识蒸馏的文章链接供大家参考~
二 主要内容
2.1 文章工作
这篇文章的最核心的idea就在于之前的工作model distillation,即将知识从复杂模型迁移到更简单的小模型,而文章从不同的角度提出dataset distillation,保持模型结构不变,将知识从大的训练集凝练到小的训练集。
for example, 对于包含60,000张训练图像的minst数据集,文章可以为其生成10张synthetic distilled images,即每个类仅为其生成一张,这样一个生成的小数据集就能达到与原有数据集类似的训练效果。
围绕这个核心创新点,总结来说文章完成了以下几件事情:
- idea: network distillation ---》dataset distillation
- 可以将几千几万张训练图像 ---》 几张distilled images, 甚至做到1张/1类。
- 网络学习的任务从面向任务优化权重 ---》 面向任务优化distilled images的各个像素,可以理解为网络的目标为如何合成这些图像。
- 尝试了4种不同的网络初始化方法。
- fixed initialization;
- random initialization;
- fixed pre-trained weights;
- random pre-trained weights。
- 在image classfication和poisoning attack两个任务上都进行实验并取得不错的结果。
2.2 实验结果图
- 展示了用distilled images也能用于将网络模型训练的很好(image classification);
- 展示了用distilled imaged可以很快将在一个数据集上训练过的模型在另一个数据集上进行fine-tune (image classification);
- 展示了用distilled images可以用于攻击已经训练好的网络模型 (poisoning attack)。
2.3 相关工作
在这里简单列一下相关的方向,具体的关系与差别就不说啦~
- Knowledge distillation.(知识蒸馏)
- Dataset pruning,core-set construction,and instance selection.(数据集修剪,数据集子集构建,样本选择)
- Gradient-based hyperparameter optimization(基于梯度下降的参数优化)
- Understanding datasets.(数据集理解)
三 方法
3.1 setup
- 给定训练集x={xi}, i 属于[1,N]
- 用θ表示网络参数,l(xi,θ) 表示数据点xi的损失函数。
-
对应的训练目标即为:(其中l(x,θ) 表示在整个训练集x上的平均损失)
3.2 optimizing distilled data (fixed init, single GD step)
- 标准的训练方法是对minibatch进行梯度下降,每一步t,都对网络参数从θ:t优化到θ(t+1);
- 文章的目标是要学习一个很小的数据集x~,在初始化参数θ0已知的情况下,只优化一步得到的θ1即为式1所示;
- 那么此时的优化目标即为寻找合适的x~和学习率 η~,具体如式2所示。
3.3 distillation for random initializations(random init, single GD step)
- 与3.2中fixed init中不同的就是不再受限于固定的初始化参数θ。
- 单步的参数优化过程其实就在上式3中将θ0改成服从p()分布,然后随机采样得到,作为初始化参数,具体见下式4。
-
实验验证这样随机初始化训练出来的distilled images会看起来更make sense。
- 网络整体的优化过程:
这里最重要的是要理解第6,7行是先在distilled images上更新网络参数,然后用这个网络参数在真实的数据集上去做loss评价,我觉得可以理解作者希望网络在distilled images和真实的训练集上的loss都低,以达到distilled images的生成。 最后第9行是用第7行得到的真实数据集上的loss来指导distilled images像素和学习率的更新。
3.4 analysis of a simple linear case
- 该部分主要是以一个简单的线性函数为例子进行理论推导,证明M最小是多少才能保证经过一次GD step之后可以达到与全训练集一样的训练效果。
- 整体推导过程就不写了,感兴趣的同学可以去看一下公式~
- 最终给出的结论是:dTd必须满秩,且M>=D,其中d表示原始训练集在各个类上的分布矩阵[N*D], N是原始数据集大小,D是类别数据。
3.5 multiple gradient descent steps and multiple epochs
- 前面介绍的fixed init 和random init,以及整个dataset distillation的算法图都是基于single GD step的情况的。
- 当多步的时候就将algorithm1中第6行从原来的单步改成:
- 相应的第9行也用反向传播算法逐步传播梯度。
- 文章中还用了优化算法来加快梯度回传的过程。
3.6 distillation with different initializations
不同的参数初始化方法。
- 其中pre-trained weights的初始化方法还可以用于度量dataset distillation方法在减小两个数据集之间gap上的效果。(开头实验结果图中的第二栏所表现的)
3.7 distillation with different objectives
- 不同的训练目标函数可以使的distilled data表现出不同的行为。
- 之前提到的都是image classification。
- 该方法还可以用于攻击网络。
-
简单来说,为了让网络将类别原来为K的图像,错分成类别T,那么对应的目标函数即为:
四 写在最后
实验部分就不再放了。
主要内容就是这些,整篇文章还是挺有意思的,创新点鲜明突出,实验支持也比较完整~
希望自己也能做出一篇这样的work。
路漫漫其修远兮,吾将上下而求索,加油,给所有正在努力科研的人儿~~
φ(๑˃∀˂๑)♪