我们看到torchvision提供的detection训练代码中
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.output_dir:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'args': args,
'epoch': epoch},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
都是保存和加载了optimizer和lr_scheduler,为什么不直接保存model呢,因为考虑到adam和sgd两种常用的优化器,adam的原理 可以看 https://stats.stackexchange.com/questions/220494/how-does-the-adam-method-of-stochastic-gradient-descent-work
adam是动态调整的,和当前parameter有关,所以resume时需要加载optimizer.state_dict()
。
sgd的learning rate一般都是epoch相关,所以需要lr_scheduler.state_dict()