元学习(meta learning)之torchmeta

元学习(meta learning)之torchmeta

安装

pip install torchmeta

数据下载

anaconda中打开python
输入 torchmeta.datasets.Omniglot("data", num_classes_per_task=5, meta_train=True,download=True)
这句话表示下载Omniglot数据集到当前目录下的data文件夹下


image.png

运行

-在github上获取源代码后是这样的

image.png

-在examples/maml里面是这样(注意,这里已经把data/omniglot下载了)
image.png

-训练
退出python到anaconda环境
输入python train.py "data"
image.png

这里的python argparse语法参考这里

因为参数接口的设计在train.py里面

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')

    parser.add_argument('folder', type=str,
        help='Path to the folder the data is downloaded to.')
        
    parser.add_argument('--num-shots', type=int, default=5,
        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
        help='Number of classes per task (N in "N-way", default: 5).')

    parser.add_argument('--first-order', action='store_true',
        help='Use the first-order approximation of MAML.')
    parser.add_argument('--step-size', type=float, default=0.4,
        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
        help='Number of channels for each convolutional layer (default: 64).')

    parser.add_argument('--output-folder', type=str, default=None,
        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_true',
        help='Download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
        help='Use CUDA if available.')

    args = parser.parse_args()
    args.device = torch.device('cuda' if args.use_cuda
        and torch.cuda.is_available() else 'cpu')
   
    train(args)

我训练的时候出现这个错误

image.png

解决办法
在这里遇到很大的挫折,这个bug搞了好几天了,头发掉了很多,上面那个链接是改pytorch的,但是train.py里面有一个参数--unm-workers他的默认是1,这个参数是控制python的多线程的,似乎在windows上只能用单线程,所以把这里的1改为0(当然你在anaconda里面在输入参数的时候指定--unm-workers为0也可以的)

image.png

改为0以后再训练就成功啦!!!!

image.png

参考https://github.com/tristandeleu/pytorch-meta

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容