2025-06-04【cosyvoice】移植模型到C++

1 如何移植

将模型导出oonx格式,然后在C++端加载,虽然oonx支持动态维度输入,但不支持python控制流,如if else for等操作,所以需分割推理代码。
将所有的控制流代码使用C++重写,将控制流内涉及的模型推理用oonx导出。

注意区分哪些是控制流代码,哪些是模型推理

区分控制流代码和必须导出为oonx的代码

# 5. step by step decode
out_tokens = []
cache = None
for i in range(max_len):
    y_pred, cache = self.llm.forward_one_step(lm_input,
                                              masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                              cache=cache)
    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
    top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
    # 测试
    input_ = torch.tensor([1.88, 2.55], dtype=float, device=device)
    ouput_ = self.sampling_ids(input_, out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
    print(ouput_)
    # 测试
    if top_ids == self.speech_token_size:
        break
    if top_ids > self.speech_token_size:
        continue
    # in export mode, return tokens in list
    out_tokens.append(top_ids)
    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
return out_tokens

以上是cosyvoice中的一段源码,其中的for循环就是无法在oonx中导出的控制流代码。for循环中,如
forward_one_stepllm_decoder涉及一些torch::nn::XXX参数化模块,他们需要正确加载训练后的权重才能正确工作,所以必须导出为oonx格式,而sampling_ids只涉及torch::softmax, torch::matmul不继承nn.Module,是函数式操作,直接使用C++的libtorch重写即可

2 oonx输出数量的确定规则

return语句

  • 返回单个Tensor (return x),则视为1个输出
  • 返回的是 Tuple(如 return (x, y, z)),ONNX 会将 Tuple 中的每个元素视为独立的输出,数量等于 Tuple 的长度。
  • 如果返回的是 嵌套 Tuple(如 return (a, (b, c))),ONNX 会递归展开,最终输出数量为所有叶子节点的 Tensor 数量(此例为 3 个:a, b, c)

3 一些基本的oonx验证操作

  1. 检查模型的输出数量
import onnx
model = onnx.load("model.onnx")
print("输出数量:", len(model.graph.output)) 
  1. 检查模型格式合法性
import onnx
onnx_path = 'my_model/llm_encoder_export.onnx'
model_onnx = onnx.load("my_model/llm_encoder_export.onnx")
onnx.checker.check_model(model_onnx)  # 检查模型格式合法性
rprint(f'[green]{onnx_path}的输入格式: [/green]', [input.name for input in model_onnx.graph.input])
  1. 运行一个模型,实现输入输出

注意输入必须全为Tensor,然后再转为平台必须格式

python: numpy数组

import onnxruntime as ort
input_dict = {
    'text': llm_model_input['text'], # 变长 1*x
    'prompt_text': llm_model_input['prompt_text'], # 变长 1*x
}
ort_session = ort.InferenceSession("my_model/llm_export.onnx")
onnx_out = ort_session.run(None, {k: v.to('cpu').numpy() for k,v in input_dict.items()})

第一个参数指定输出节点名称,也就是导出模型时指定的output_names中的节点名。
如果指定了输出节点名称,则只输出指定节点对应的输出。

C++: 将torch::Tensor转为Ort::Value

  • 提取 Tensor 数据(转为float*int64_t*等)。
  • 获取 Tensor 的形状信息(std::vector<int64_t>)。
  • 构造 Ort::Value 用于 ONNX Runtime 推理。
#include <onnxruntime_cxx_api.h> // ONNX Runtime C++ API
#include <torch/torch.h>

// 假设已经加载了 ONNX 模型,并初始化了 Ort::Session
Ort::Session session(env, "model.onnx", session_options);

// 1. 准备一个 PyTorch Tensor(示例)
torch::Tensor torch_tensor = torch::rand({1, 3, 224, 224}); // 假设是输入张量

// 2. 转换为 ONNX Runtime 可用的格式
// 2.1 获取 Tensor 数据指针(必须是连续内存)
float* tensor_data = torch_tensor.contiguous().data_ptr<float>();

// 2.2 获取 Tensor 的形状
std::vector<int64_t> input_shape;
for (auto dim : torch_tensor.sizes()) {
    input_shape.push_back(dim);
}

// 3. 构造 Ort::Value
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(
    OrtAllocatorType::OrtArenaAllocator, 
    OrtMemType::OrtMemTypeDefault
);

// 创建 Ort::Value
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
    memory_info,
    tensor_data,                // 数据指针
    torch_tensor.numel(),       // 元素总数
    input_shape.data(),         // 形状指针
    input_shape.size()          // 形状维度数
);

// 4. 运行推理
std::vector<Ort::Value> outputs = session.Run(
    Ort::RunOptions{nullptr}, 
    input_names.data(),         // 输入节点名称(需提前定义)
    &input_tensor,             // 输入 Ort::Value
    1,                         // 输入数量
    output_names.data(),       // 输出节点名称(需提前定义)
    1                          // 输出数量
);

4. 修复llm decoder1

  • /model/model/trilu节点的两个输入(bool int),其中bool类型不支持tensorRT转换,需更改源码




    构造输入时,cache_input的第2轴不能为0,否则丢失cat节点,不支持TensorRT导出

  • /model/model/where_5节点对opt尺寸的广播失败,因为aixs 3不相等或不为1,修改代码

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。