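If some layers are missing, load with strict=False, so that keys present in only one of the two state dicts are ignored instead of raising an error: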
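import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YoloBody()  # YoloBody is the project's own model class
model_dict = model.state_dict()
pretrained_dict = torch.load("yolo_weights.pth", map_location=device)
model_dict.update(pretrained_dict)
# strict=False: ignore missing or unexpected keys instead of raising
model.load_state_dict(model_dict, strict=False)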
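To see what strict=False actually tolerates, here is a minimal self-contained sketch. Small and Bigger are hypothetical toy modules standing in for a pretrained network and a model with an extra layer. load_state_dict returns the missing and unexpected keys rather than raising:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the network the checkpoint was saved from
class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)

# Hypothetical stand-in for a model with a layer the checkpoint lacks
class Bigger(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.extra = nn.Linear(8, 2)  # not present in the checkpoint

pretrained = Small().state_dict()
model = Bigger()

# With strict=False, missing/unexpected keys are reported, not raised
result = model.load_state_dict(pretrained, strict=False)
print(result.missing_keys)     # ['extra.weight', 'extra.bias']
print(result.unexpected_keys)  # []
```

The shared backbone weights are copied; the extra layer keeps its random initialization.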
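If the number of classes differs, the final filters (the output convolutions of the head) cannot be loaded: their shapes no longer match, and strict=False does not skip shape mismatches. Filter the dict by shape first: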
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YoloBody()  # YoloBody is the project's own model class
model_dict = model.state_dict()
pretrained_dict = torch.load("yolo_weights.pth", map_location=device)
# Keep only weights whose name exists in the model and whose shape matches
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and model_dict[k].shape == v.shape}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
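A runnable sketch of the shape filter, assuming a hypothetical TinyDetector in place of YoloBody. Its head has 3 * (5 + num_classes) output channels, mimicking a YOLO head with 3 anchors; that layout is an assumption for illustration. It also shows that strict=False alone still fails on a size mismatch:

```python
import torch
import torch.nn as nn

class TinyDetector(nn.Module):
    """Hypothetical stand-in for YoloBody: shared backbone + class-dependent head."""
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        # Assumed YOLO-style head: 3 anchors * (4 box + 1 objectness + classes)
        self.head = nn.Conv2d(16, 3 * (5 + num_classes), kernel_size=1)

def load_matching(model, pretrained_dict):
    """Copy tensors whose name and shape match; return the skipped keys."""
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and model_dict[k].shape == v.shape}
    skipped = sorted(set(pretrained_dict) - set(matched))
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return skipped

pretrained = TinyDetector(num_classes=80).state_dict()  # 255-channel head
model = TinyDetector(num_classes=20)                    # 75-channel head

# strict=False does not help here: size mismatches still raise
raised = False
try:
    model.load_state_dict(pretrained, strict=False)
except RuntimeError as e:
    raised = "size mismatch" in str(e)
print(raised)  # True

skipped = load_matching(model, pretrained)
print(skipped)  # ['head.bias', 'head.weight']
```

Only the backbone is transferred; the new head starts from its own initialization and is trained on the new classes.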