PyTorch模型保存深入理解

前面写过一篇PyTorch保存模型的文章:Pytorch模型保存与加载,并在加载的模型基础上继续训练 ,简单介绍了PyTorch模型保存和加载函数的用法,足以快速上手,但对相关函数和参数的具体用法和代表的含义没有进行展开介绍,这篇文章用于记录之。

PyTorch保存模型的语句是这样的:
torch.save(model.state_dict(), path)
加载是这样的:
model.load_state_dict(torch.load(path))

下面我们将其拆开逐句介绍,深入理解。

1.torch.save()和torch.load()

顾名思义,save函数是PyTorch的存储函数,load函数则是读取函数。save函数可以将各种对象保存至磁盘,包括张量,列表,ndarray,字典,模型等;而相应地,load函数将保存在磁盘上的对象读取出来。

用法:

torch.save(保存对象, 保存路径)
torch.load(文件路径)

应用举例:

保存张量

In [3]: a = torch.ones(3)                                                       

In [4]: a                                                                       
Out[4]: tensor([1., 1., 1.])

In [5]: torch.save(a, './a.pth')          # 保存Tensor               

In [6]: a_load = torch.load('./a.pth')    # 读取Tensor

In [7]: a_load                                                                  
Out[7]: tensor([1., 1., 1.])

保存字典

In [11]: b = {k:v for v,k in enumerate('abc',1)}                                

In [12]: b                                                                      
Out[12]: {'a': 1, 'b': 2, 'c': 3}

In [13]: torch.save(b, './b.rar')                        

In [14]: torch.load('./b.rar')                           
Out[14]: {'a': 1, 'b': 2, 'c': 3}

可以看出,保存和读取非常方便。这里需要注意的是文件的命名,命名必须要有扩展名,扩展名可以为‘xxx.pt’,‘xxx.pth’,‘xxx.pkl’,‘xxx.rar’等形式。

2.model.state_dict()

在PyTorch中,state_dict是一个从参数名称隐射到参数Tesnor的字典对象

In [15]: class MLP(nn.Module): 
    ...:     def __init__(self): 
    ...:         super(MLP, self).__init__() 
    ...:         self.hidden = nn.Linear(3, 2) 
    ...:         self.act = nn.ReLU() 
    ...:         self.output = nn.Linear(2, 1) 
    ...:  
    ...:     def forward(self, x): 
    ...:         a = self.act(self.hidden(x)) 
    ...:         return self.output(a) 
    ...:                                                                        

In [16]: net = MLP()                                                            

In [17]: net.state_dict()                                                       
Out[17]: 
OrderedDict([('hidden.weight', tensor([[ 0.4839,  0.0254,  0.5642],
                      [-0.5596,  0.2602, -0.5235]])),
             ('hidden.bias', tensor([-0.4986, -0.5426])),
             ('output.weight', tensor([[0.0967, 0.4980]])),
             ('output.bias', tensor([-0.4520]))])

可以看出,state_dict()返回的是一个有序字典,该字典的键即为模型定义中有参数的层的名称+weight或+bias,值则对应相应的权重或偏差,无参数的层则不在其中。

除了模型中有参数的层(卷积层、线性层等)有state_dict,优化器对象:

optimizer = torch.optim.xxxx(...)    # 如SGD,Adam等

也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。以及,学习率调整器对象:

scheduler = torch.optim.lr_scheduler.xxxx(...)    # 如LambdaLR,CosineAnnealingLR等

也有一个state_dict,其中包含当前学习率的值以及迭代次数记录。

如果有程序中断后继续接着训练的需求,最好将这些状态字典都以字典形式保存下来:

check_point = {'lr': scheduler.state_dict(), 'optimizer': optimizer.state_dict(), 'model': model.state_dict()}
torch.save(check_point, path)

恢复时只需要在相应对象实例化之后进行加载即可:

check_point = torch.load(path) 
... ...
model = xxxNet(...)
model.load_state_dict(check_point['model'])
... ... 
optimizer = torch.optim.xxxx(...)
optimizer.load_state_dict(check_point['optimizer'])
... ... 
scheduler = torch.optim.lr_scheduler.xxxxLR(...) 
scheduler.load_state_dict(check_point['lr']) 

3.model.load_state_dict()

这是模型加载state_dict的语句,也就是说,它的输入是一个state_dict,也就是一个字典。模型定义好并且实例化后会自动进行初始化,上面的例子中我们定义的模型MLP在实例化以后显示的模型参数都是自动初始化后的随机数。

在训练模型或者迁移学习中我们会使用已经训练好的参数来加速训练过程,这时候就用load_state_dict()语句加载训练好的参数并将其覆盖在初始化参数上,也就是说执行过此语句后,加载的参数将代替原有的模型参数。

既然加载的是一个字典,那么需要注意的就是字典的键一定要相同才能进行覆盖,比如加载的字典中的'hidden.weight'只能覆盖当前模型的'hidden.weight',如果键不同,则不能实现有效覆盖操作。键相同而值的shape不同,则会将新的键值对覆盖原来的键值对,这样在训练时会报错。所以我们在加载前一般会进行数据筛选,筛选是对字典的键进行对比来操作的:

pretrained_dict = torch.load(log_dir)  # 加载参数字典
model_state_dict = model.state_dict()  # 加载模型当前状态字典
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if k in model_state_dict}  # 过滤出模型当前状态字典中没有的键值对
model_state_dict.update(pretrained_dict_1)  # 用筛选出的参数键值对更新model_state_dict变量
model.load_state_dict(model_state_dict)  # 将筛选出的参数键值对加载到模型当前状态字典中

以上代码简单的对预训练参数进行了过滤和筛选,主要是通过第3条语句粗略的过滤了键值对信息,进行筛选后要用Python更新字典的方法update()来对模型当前字典进行更新,update()方法将pretrained_dict_1中的键值对添加到model_state_dict中,若pretrained_dict_1中的键名和model_state_dict中的键名相同,则覆盖之;若不同,则作为新增键值对添加到model_state_dict中。显然,这里需要的是将pretrained_dict_1中的键值对覆盖model_state_dict的相应键值对,所以对应的键的名称必须相同,所以第3条语句中按键名称进行筛选,过滤出当前模型字典中没有的键值对。否则会报错。

如果想要细粒度过滤或更改某些参数的维度,如进行卷积核参数维度的调整,假如预训练参数里conv1有256个卷积核,而当前模型只需要200个卷积核,那么可以采用类似以下语句直接对字典进行更改:

pretrained_dict['conv1.weight'] = pretrained_dict['conv1.weight'][:200,:,:,:]   # 假设保留前200个卷积核

以上。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,657评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,662评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,143评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,732评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,837评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,036评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,126评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,868评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,315评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,641评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,773评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,859评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,584评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,676评论 2 351

推荐阅读更多精彩内容