元学习(meta learning)之torchmeta
安装
pip install torchmeta
数据下载
anaconda中打开python
输入 torchmeta.datasets.Omniglot("data", num_classes_per_task=5, meta_train=True,download=True)
这句话表示下载Omniglot数据集到当前目录下的data文件夹下
运行
-在github上获取源代码后是这样的
-在examples/maml里面是这样(注意,这里已经把data/omniglot下载了)
-训练
退出python到anaconda环境
输入python train.py "data"
这里的python argparse语法参考这里
因为参数接口的设计在train.py里面
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')
parser.add_argument('folder', type=str,
help='Path to the folder the data is downloaded to.')
parser.add_argument('--num-shots', type=int, default=5,
help='Number of examples per class (k in "k-shot", default: 5).')
parser.add_argument('--num-ways', type=int, default=5,
help='Number of classes per task (N in "N-way", default: 5).')
parser.add_argument('--first-order', action='store_true',
help='Use the first-order approximation of MAML.')
parser.add_argument('--step-size', type=float, default=0.4,
help='Step-size for the gradient step for adaptation (default: 0.4).')
parser.add_argument('--hidden-size', type=int, default=64,
help='Number of channels for each convolutional layer (default: 64).')
parser.add_argument('--output-folder', type=str, default=None,
help='Path to the output folder for saving the model (optional).')
parser.add_argument('--batch-size', type=int, default=16,
help='Number of tasks in a mini-batch of tasks (default: 16).')
parser.add_argument('--num-batches', type=int, default=100,
help='Number of batches the model is trained over (default: 100).')
parser.add_argument('--num-workers', type=int, default=1,
help='Number of workers for data loading (default: 1).')
parser.add_argument('--download', action='store_true',
help='Download the Omniglot dataset in the data folder.')
parser.add_argument('--use-cuda', action='store_true',
help='Use CUDA if available.')
args = parser.parse_args()
args.device = torch.device('cuda' if args.use_cuda
and torch.cuda.is_available() else 'cpu')
train(args)
我训练的时候出现这个错误
解决办法
在这里遇到很大的挫折,这个bug搞了好几天了,头发掉了很多,上面那个链接是改pytorch的,但是train.py里面有一个参数--unm-workers他的默认是1,这个参数是控制python的多线程的,似乎在windows上只能用单线程,所以把这里的1改为0(当然你在anaconda里面在输入参数的时候指定--unm-workers为0也可以的)
改为0以后再训练就成功啦!!!!
参考https://github.com/tristandeleu/pytorch-meta