pytorch加载模型时的魔改

如果缺失了一些层,指定false就可以

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YoloBody()
model_dict = model.state_dict()
pretrained_dict = torch.load("yolo_weights.pth", map_location=device)
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, False)

如果最后的类别数目不同,最后的filters是加载不了的,可以通过判断shape的方式来对dict做筛选

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YoloBody()
model_dict = model.state_dict()
pretrained_dict = torch.load("yolo_weights.pth", map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。