2022-02-19

image.png

又是一个年久待修的问题
这个好像是pytorch的一个通病,在不同的Python版本下会有问题。
原代码:

def load_param(self, trained_path):
    param_dict = torch.load(trained_path)
    for i in param_dict:
        if 'classifier' in i:
            continue
        self.state_dict()[i].copy_(param_dict[i])


param_dict = torch.load(trained_path)
改为:
param_dict = torch.load(trained_path).state_dict()

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容