image.png
又是一个年久待修的问题
这个好像是pytorch的一个通病,在不同的Python版本下会有问题。
原代码:
def load_param(self, trained_path):
param_dict = torch.load(trained_path)
for i in param_dict:
if 'classifier' in i:
continue
self.state_dict()[i].copy_(param_dict[i])
将
param_dict = torch.load(trained_path)
改为:
param_dict = torch.load(trained_path).state_dict()