一、基本信息
论文名称:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
出版社:arXiv
发表时间:2017
二、研究目的
训练出合适的模型初始参数,使得在样本量较少的情况下迅速收敛。
三、训练方式
假设实际需要面对的问题是5 way 1 shot,也就是说有5个类别,每个类别只有一个样例,目标是模型可以利用这5个样例训练到收敛而又不至于过拟合。假设数据集有100个种类(和上述5个种类不重合),我们需要利用数据集训练出一个合适的模型,使得该模型可以在5 way 1 shot问题上快速收敛。
将数据集分成meta-train和meta-test两部分,meta-test测试模型的收敛速度(用Dtrain训练,用Dtest测试分类效果),meta-train用于训练模型(Dtrain和Dtest一起训练模型),下图中每一个横条称为一个task。
对于meta-train数据集,首先整理成上图中的形式(从100个类中随机抽取5个类,从5个类的数据集中分别随机抽1个样例作为Dtrain,从5个类数据集中随机抽2个样例作为Dtest),这样一来就会形成很多个task,这些task就是训练集(一个task相当于传统机器学习中的一个样例),多个task构成一个batch。算法如下:
- 初始化分类器参数,提取batch(在每一个batch内计算梯度,然后利用该梯度更新分类器参数,每个batch更新一次)
- 每个batch内有多个task,每个task内有Dtrain和Dtest。对于每个task,用Dtrain计算梯度并更新参数(暂时更新,仅仅用于计算对应Dtest的梯度,其实他是整个batch过完后再更新),用参数更新后的分类器计算Dtest的损失。整个batch中有多个task,也就是有多个Dtest的损失,把这些损失加和然后计算梯度,进而更新分类器参数。这就完成了一次参数更新。下个batch用更新后的参数重复进行下去。
在一个batch中,算法在每个task上的Dtrain计算梯度,在Dtest上计算损失,这就相当于试探步,每个task试探不同的方向,最后将整个batch的试探结果(Dtest的loss)综合起来更新参数。
- 每个batch内有多个task,每个task内有Dtrain和Dtest。对于每个task,用Dtrain计算梯度并更新参数(暂时更新,仅仅用于计算对应Dtest的梯度,其实他是整个batch过完后再更新),用参数更新后的分类器计算Dtest的损失。整个batch中有多个task,也就是有多个Dtest的损失,把这些损失加和然后计算梯度,进而更新分类器参数。这就完成了一次参数更新。下个batch用更新后的参数重复进行下去。