pytroch学习(二十一)—C++(libTorch)调用pytroch预训练模型

前言

当我们训练好一个CNN模型之后,可能要集成到项目工程中,或者移植到到不同的开发平台(比如Android, IOS), 一般项目工程或者App大多数采用C/C++, Java等语言,但是采用pytroch训练的模型用的是python语言,这样就存在一个问题,如何使用C/C++、Java调用预训练好的模型, 如果解决了这个问题,那么训练好的模型才可以走出实验室,在App中得到广泛应用。

本章内容将完整介绍pytroch预训练模型的C++调用,关于Java的调用,其实也不难,如果掌握了C++调用 pytorch模型的方法, Java可以通过JNI调用C++。


开发环境

  • Ubuntu 18.04
  • Clion
  • CMake
  • opencv
  • libTorch

配置libTorch

首先,在pytorch官网下载libtroch, 官网提供了win/linux/Mac系统编译好的库,省去了编译库的过程。

https://pytorch.org/

image.png

下载好之后,解压到一个路径即可。


image.png
image.png
image.png

测试一个简单demo

创建一个目录,example-app

image.png

新建2个文件


image.png
  • example-app.cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor tensor = torch::rand({2, 3});
  std::cout << tensor << std::endl;
}

  • CMakeLists.txt
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)

然后打开终端,输入:mkdir build

image.png

继续: cd build

开始编译:


image.png
image.png

执行:


image.png

libTorch调用预训练好的性别分类模型

上面的例子是pytorch官网的demo, 下面本人模仿官方的demo, 将使用libTorch C++ API调用自己预训练好的.pth模型。

第一步

将预训练好的模型转换

import torch
import torchvision.models as models
import torch.nn as nn


# 载入预训练的型
model = models.squeezenet1_1(pretrained=False)
model.classifier[1] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1, 1), stride=(1, 1))
model.num_classes = 2
model.load_state_dict(torch.load('./model/model_squeezenet_utk_face_20.pth', map_location='cpu'))
model.eval()
print(model)

# model = models.resnet18(pretrained=True)
# print(model)

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))
print(output)


# ----------------------------------
traced_script_module.save("./model/model_squeezenet_utkface.pt")


第二步

编写C++代码

  • 准备好上一步骤转换的模型文件
  • 准备几张测试图像

由于涉及到图像的加载与处理,本人使用opencv进行读取和处理。

Tips:
训练过程中,采用PIL.Image加载图像(3通道 RGB),然后Resize到224 x 224大小, 之后再进行ToTensor。因此使用C++ libTorch时候也需要按照上述过程对图像进行预处理。

  1. cv::imread() 默认读取为三通道BGR,需要进行B/R通道交换,这里采用 cv::cvtColor()实现。

  2. 缩放 cv::resize() 实现。

  3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是np.transpose()操作,这里使用tensor.permut()实现,效果是一样的。

  4. 数据归一化,采用tensor.div(255) 实现。

#include <torch/script.h> // One-stop header.
#include <opencv2/opencv.hpp>
#include <iostream>
#include <memory>

//https://pytorch.org/tutorials/advanced/cpp_export.html

string image_path = "/home/weipenghui/Project-dev/Cpp/testLibTorch2/image";

int main(int argc, const char* argv[]) {

    // Deserialize the ScriptModule from a file using torch::jit::load().
    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("/home/weipenghui/Project-dev/Cpp/testLibTorch2/model/model_squeezenet_utkface.pt");

    assert(module != nullptr);
    std::cout << "ok\n";


    //输入图像
    auto image = cv::imread(image_path +"/"+ "7.jpg",cv::ImreadModes::IMREAD_COLOR);
    cv::Mat image_transfomed;
    cv::resize(image, image_transfomed, cv::Size(224, 224));
    cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);

    // 图像转换为Tensor
    torch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols,3},torch::kByte);
    tensor_image = tensor_image.permute({2,0,1});
    tensor_image = tensor_image.toType(torch::kFloat);
    tensor_image = tensor_image.div(255);

    tensor_image = tensor_image.unsqueeze(0);


    // 网络前向计算
    // Execute the model and turn its output into a tensor.
    at::Tensor output = module->forward({tensor_image}).toTensor();

    auto max_result = output.max(1,true);
    auto max_index = std::get<1>(max_result).item<float>();

    if (max_index == 0){
        cv::putText(image, "male", cv::Point(50, 50), 1, 1,cv::Scalar(0, 255, 255));
    }else{
        cv::putText(image, "female", cv::Point(50, 50), 1, 1,cv::Scalar(0, 255, 255));
    }

    cv::imwrite("./result7.jpg", image);

    //cv::imshow("image", image);
    //cv::waitKey(0);


//    at::Tensor prob = torch::softmax(output,1);
//    auto prediction = torch::argmax(output, 1);
//
//    auto aa = prediction.slice(/*dim=*/0, /*start=*/0, /*end=*/2).item();
//
//    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/2) << '\n';
//    std::cout << prob.slice(/*dim=*/1, /*start=*/0, /*end=*/2) << '\n';
//    std::cout <<prediction.slice(/*dim=*/0, /*start=*/0, /*end=*/2)<<"\n";

}


编写CMakeLists.txt文件
目的是将libTorch, opencv配置好,确保程序可以正常编译链接。

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

#include_directories(/home/weipenghui/Lib-dev/opencv_3.4.3_contrib/opencv-3.4.3/build/install_cv/include)
# set(CMAKE_PREFIX_PATH "/home/weipenghui/Lib-dev/opencv_3.4.3_contrib/opencv-3.4.3/build/install_cv")
find_package(OpenCV REQUIRED)

set(CMAKE_PREFIX_PATH
        /home/weipenghui/Lib-dev/libtorch-shared-with-deps-latest/libtorch
        /home/weipenghui/Lib-dev/opencv_3.4.3_contrib/opencv-3.4.3/build/install_cv)

find_package(Torch REQUIRED)



add_executable(example-app main.cpp)
target_link_libraries(example-app ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)

识别结果

result.jpg
result2.jpg
result3.jpg
result4.jpg
result5.jpg
result6.jpg
result7.jpg

End

参考:
https://pytorch.org/cppdocs/installing.html
https://pytorch.org/tutorials/advanced/cpp_export.html

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,287评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,346评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,277评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,132评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,147评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,106评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,019评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,862评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,301评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,521评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,682评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,405评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,996评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,651评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,803评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,674评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,563评论 2 352

推荐阅读更多精彩内容