问题描述
因为在实际的深度学习中,可能在自己的base网络基础之上对网络进行一些增删操作,比如说有些attention模块可以说是即插即用的,在这样的情况之下,我们修改一小部分网络结构后希望在训练的初期将之前的没有改变的网络层训练好的参数进行加载,以节约自己的模型训练时间:先放一段代码:
import torch
from network.res_unet import ResUNet
net = ResUNet(in_ch=3, out_ch=3)
old_net = torch.load('runs/resUnet_3class/checkpoint/cp_030.pth', map_location='cpu')
print(type(old_net))
i = 0
for key, v in old_net['net'].items():
if i < 2:
i += 1
print(key)
print(v)
print('--------------------------------------------------------------------------------------------')
i = 0
for key, v in net.state_dict().items():
if i < 2:
i+=1
print(key)
print(v)
net.load_state_dict(old_net['net'],strict=False)
print('--------------------------------------------------------------------------------------------')
i = 0
for key, v in net.state_dict().items():
if i < 2:
i += 1
print(key)
print(v)
print('end_signal')
此时net就是resnet其实和加载的cp_030.pth的网络是一模一样的,先以其举例进行一下说明(其中只打印了网络的前2层)。
说明:
我的torch.load出来的old_net出来的type是一个dict,其中old_net['net']是具体的orderdict网络参数,因而在load_state_dict注意第一个入参是old_net['net']。(有时候torch.load下来的是一个orderdict的类,此时net.load_state_dict(old_net,strict=False))即可。
输出的结果展示:
1、old_net['net']第一层参数:
2、未经load_state_dict的net第一个参数:
3、经过load_state_dict的net第一个参数:
由此可见,经过load_state_dict的net已经将原始old_net的参数载入。
接下来,对Net结构进行修改:
此时打印了一下网络修改先后的前四组参数名称,发现多了两个参数:
然后我们再进行一次load_state_dict对比新添加的inc.0.bias和旧网络有的参数inc.1.bias:
1、旧网络参数:
2、新网络未经load_state_dict的参数:
3、新网络经load_state_dict的参数:
由图可知:改了网络以后再利用Load_state_dict进行网络参数装载会自适应赋值,有参数就覆盖,没有就不。这主要是通过Load_state_dict的API中strict=False字段决定的。