PyTorch 之模型的保存与加载

1. torch.save

主要参数

  • obj: 对象
  • f:输出路径
2. torch.load

主要参数

  • f: 文件路径
  • map_location: 指定存放位置,cpu or gpu
方法1:保存整个module (耗时,占内存)

保存:

torch.save(net.path)

加载:

path_model = './model.pkl'
net_load = torch.load(path_model)
方法2:保存模型参数(官方推荐)

保存:

state_dict = net.state_dict()
torch.save(state_dict, path)

加载:

path_state_dict = './model_state_dict.pkl'
state_dict_load = torch.load(path_state_dict)
net.load_state_dict(state_dict_load)
3. 断点续存训练

保存断点(在epoch循环中):

if (epoch + 1) % checkpoint_interval == 0:  # 每隔checkpoint_interval保存一次
    checkpoint = {"model_state_dict": net.state_dict()  # 模型数据
                  "optimizer_state_dict": optimizer.state_dict()  # 优化器数据
                  "epoch": epoch  # 迭代次数
                  }
    path_checkpoint = './checkpoint_{}_epoch.pkl'.format(epoch)
    torch.save(checkpoint, path_checkpoint)

断点恢复:

path_checkpoint = './checkpoint_4_epoch.pkl'
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。