Pytorch Tips

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
  • 中断时保存参数
try:
    train_net(net=net, epochs=args.epochs, batch_size=args.batchsize,
              lr=args.lr, gpu=args.gpu, img_scale=args.scale)
except KeyboardInterrupt:  # 用户中断执行(通常是输入^C)
    import time
    save_time = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    torch.save(net.state_dict(), '{}_INTERRUPTED.pth'.format(save_time))
    print('Saved interrupt')
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)

将该代码添加至save_model合适的位置,可实现“Early Stopping”

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。