我用 PyTorch 训练了一个安全帽检测的二分类模型，在 Python 中效果还不错。现在想在 C++ 中通过 LibTorch 调用该模型，参照网上的教程，已经可以加载模型并完成预测，但预测效果与 Python 相差很多，非常不准。有经验的人说首先要保证两边的输入（即预处理结果）完全一致。我尝试了很久仍未解决，对网络推理也不太熟悉，下面把用到的代码都贴出来，希望能帮助定位问题。
（1）首先是 Python 中的推理代码：
def resnet50(num_classes=2, pretrained=True):
    """Build a torchvision ResNet-50 whose final FC layer outputs *num_classes* logits.

    Args:
        num_classes: number of outputs of the replacement ``fc`` layer
            (default 2, matching the vest / novest classifier).
        pretrained: forwarded to ``torchvision.models.resnet50``. ``True``
            downloads ImageNet weights, which is useful as a training
            initialisation but is wasted work when a fine-tuned state_dict is
            loaded right afterwards (pass ``False`` in that case).

    Returns:
        The model with its ``fc`` layer replaced by ``nn.Linear(in_features, num_classes)``.
    """
    model = torchvision.models.resnet50(pretrained=pretrained)
    # Replace the 1000-class ImageNet head with our own classification head.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
class HatRecognize(object):
    """Two-class (vest / novest) image classifier wrapping a fine-tuned ResNet-50.

    The preprocessing in :meth:`transform` is the reference that any other
    runtime (for example a libtorch C++ port) must reproduce exactly,
    otherwise the network sees a different input and predictions diverge.
    """

    def __init__(self):
        # Prefer the GPU, but stay usable on machines without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = resnet50()
        self.model.to(self.device)
        # map_location lets a checkpoint that was saved on GPU load on a CPU-only host.
        self.model.load_state_dict(
            torch.load('/data/mode/50/vest.pt', map_location=self.device))
        self.model.eval()
        # Maps the argmax index of the model output to a label string.
        self.label_list = {0: 'vest', 1: 'novest'}
        # Built once here instead of on every call to transform().
        self._test_transform = transforms.Compose([
            # interpolation=3 is PIL.Image.BICUBIC.
            transforms.Resize((224, 224), interpolation=3),
            # HWC uint8 in [0, 255] -> CHW float32 in [0, 1].
            transforms.ToTensor(),
            # Per-channel (x - mean) / std with ImageNet statistics (RGB order).
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @torch.no_grad()
    def func(self, img):
        """Classify one image.

        Args:
            img: BGR image as returned by ``cv2.imread``.

        Returns:
            ``(label, score)``. ``label`` is ``'vest'`` or ``'novest'``.
            ``score`` is the largest *raw logit* of the model output, not a
            probability; apply ``torch.softmax`` to the logits if a value in
            [0, 1] is needed.
        """
        img_tensor = self.transform(img)
        logits = self.model(img_tensor)
        label = int(logits.argmax(dim=1).item())
        score = float(logits[0].max())
        return self.label_list[label], score

    def transform(self, part):
        """Turn a BGR ndarray into a normalized ``1x3x224x224`` tensor on ``self.device``.

        Steps (the C++ side must do exactly the same):
          1. BGR -> RGB, because OpenCV delivers BGR and the mean/std below are in RGB order;
          2. resize to 224x224 with PIL bicubic interpolation;
          3. scale to float in [0, 1];
          4. subtract the ImageNet mean and divide by the std, per channel.
        """
        pil_image = Image.fromarray(cv2.cvtColor(part, cv2.COLOR_BGR2RGB))
        tensor = self._test_transform(pil_image)
        # Add the batch dimension expected by the model.
        return torch.unsqueeze(tensor, dim=0).to(self.device)
if __name__ == "__main__":
    import sys

    # The image to classify is taken from the command line.
    if len(sys.argv) != 2:
        sys.exit("usage: python {} <image_path>".format(sys.argv[0]))
    path = sys.argv[1]
    img = cv2.imread(path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        sys.exit("cannot read image: {}".format(path))
    vest_recognize = HatRecognize()
    result, acc = vest_recognize.func(img)
    print(result)
（2）下面是将模型转换（导出）为 TorchScript 的代码：
# Rebuild the training architecture. ImageNet weights are not needed because
# load_state_dict below overwrites every parameter.
model = torchvision.models.resnet50(pretrained=False)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 2)
# Load onto the CPU: tracing runs on the CPU, and this also works without a GPU.
model.load_state_dict(torch.load('/data/mode/50/vest.pt', map_location='cpu'))
model.eval()
example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, example)
    # Sanity check on an input the trace was not built from: the traced module
    # must reproduce the eager model. Feeding torch::ones({1, 3, 224, 224}) to
    # the module in C++ should print the same logits, which separates
    # "model export problem" from "C++ preprocessing problem".
    check_input = torch.ones(1, 3, 224, 224)
    output = traced_script_module(check_input)
    expected = model(check_input)
    if not torch.allclose(output, expected, atol=1e-5):
        raise RuntimeError('traced model output differs from the eager model')
    print('logits for all-ones input:', output[0].tolist())
traced_script_module.save("/data/testcode/ptc++/modelfile.pt")
（3）下面是 C++ 中的调用代码：
int main() {
std::shared_ptr<torch::jit::script::Module> ptModule = torch::jit::load("/data/testcode/modelfile.pt");
assert(ptModule != nullptr);
ptModule->to(torch::kCUDA);
cv::Mat srcImage,resImage,traImage,norImage,detImage;
srcImage = cv::imread(strFile, cv::ImreadModes::IMREAD_COLOR);
cv::cvtColor(srcImage, traImage, cv::COLOR_BGR2RGB);
cv::resize(traImage, resImage, cv::Size(224, 224), 0, 0, CV_INTER_AREA);//CV_INTER_LINEAR CV_INTER_AREA
cv::normalize(resImage, norImage, 1, 0, cv::NORM_MINMAX);
cv::convertScaleAbs(norImage, detImage);
at::Tensor tensorImage = torch::from_blob(detImage.data, {1, detImage.rows, detImage.cols, 3}, torch::kByte);
tensorImage = tensorImage.permute({0, 3, 1, 2});
tensorImage = tensorImage.toType(torch::kFloat);
tensorImage = tensorImage.to(torch::kCUDA);
torch::Tensor result = ptModule->forward({tensorImage}).toTensor();
auto max_result = result.max(1,true);
auto max_index = std::get<1>(max_result).item<float>();
printf("detres id=%d\n", max_index);
return 0;
}