2025-05-29 [cosyvoice] Converting the Models to TorchScript

To port CosyVoice2's three models (llm, flow, hift) into C++ libtorch, here is a record of the key steps.
First, load the full model:

from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)

1. llm

import torch

llm_script = torch.jit.script(cosyvoice.model.llm)
llm_script.save('my_model/llm_script.pt')
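Before moving on, it is worth confirming the archive round-trips; a minimal sketch, assuming nothing beyond the save above:

# reload the scripted module in Python to sanity-check the archive before going to C++
reloaded = torch.jit.load('my_model/llm_script.pt')
reloaded.eval()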
  • Error 1
    Running this raises:
torch.jit.frontend.NotSupportedError: Comprehension ifs are not supported yet:
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\transformers\models\qwen2\modeling_qwen2.py", line 1082

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,

JIT compiles the Python code structure and the model into a static computation graph, so some of Python's dynamic syntax is unsupported and the source has to be changed.
Change it to:

# original code
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# /original code
# modified: build the tuple with an explicit loop, which TorchScript accepts
output = []
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]:
    if v is not None:
        output.append(v)
return tuple(output)
# /modified
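The limitation is easy to reproduce outside transformers. A minimal sketch (the collect function here is hypothetical, written only to trigger the same frontend error):

import torch
from typing import List, Optional

def collect(xs: List[Optional[torch.Tensor]]):
    # same pattern as the transformers line: a comprehension with an `if` filter
    return [x for x in xs if x is not None]

try:
    torch.jit.script(collect)
except torch.jit.frontend.NotSupportedError as e:
    print(e)  # "Comprehension ifs are not supported yet"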
  • Error 2
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\transformers\models\qwen2\modeling_qwen2.py", line 742
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
         ~~~~~~~ <--- HERE
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if "padding_mask" in kwargs:

Remove **kwargs from the forward signature; the before/after is sketched below.
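A sketch of the edit, reconstructed from the traceback above (fragments of Qwen2Attention.forward in modeling_qwen2.py, not standalone code; the padding_mask branch reads kwargs, so it has to go as well):

# before
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    if "padding_mask" in kwargs:
        ...

# after: drop **kwargs and the branch that reads it
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: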
  • Error 3
Wrong type for attribute assignment. Expected int but got Tensor (inferred):
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\transformers\models\qwen2\modeling_qwen2.py", line 110
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

'Qwen2RotaryEmbedding._set_cos_sin_cache' is being compiled since it was called from 'Qwen2RotaryEmbedding.forward'
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\transformers\models\qwen2\modeling_qwen2.py", line 122
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        return (

The message says TorchScript inferred the assigned value as Tensor while the attribute is int: unannotated arguments such as seq_len default to Tensor under scripting. The before/after is sketched below.
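One workable fix (an assumption on my part; equivalent annotations would also do) is to type the argument explicitly so the assignment stays int:

# before
def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_seq_len_cached = seq_len

# after: annotate seq_len as int so TorchScript no longer infers it as Tensor
# (forward's seq_len parameter needs a matching annotation)
def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
    self.max_seq_len_cached = seq_len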

  • Error 4
RuntimeError:
Unknown type name 'torch.jit.ScriptModule':
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\torch\nn\modules\module.py", line 499

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
                                                    ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

'Qwen2RotaryEmbedding.register_buffer' is being compiled since it was called from 'Qwen2RotaryEmbedding._set_cos_sin_cache'
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\transformers\models\qwen2\modeling_qwen2.py", line 121
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
'Qwen2RotaryEmbedding._set_cos_sin_cache' is being compiled since it was called from 'Qwen2RotaryEmbedding.forward'
  File "C:\Users\31458\.conda\envs\cosyvoice\lib\site-packages\transformers\models\qwen2\modeling_qwen2.py", line 127
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        return (

register_buffer cannot be called from scripted code (its isinstance check against torch.jit.ScriptModule does not compile), so the cache updates are changed to plain attribute assignments; the before/after is sketched below.
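A sketch of the edit (fragment of Qwen2RotaryEmbedding._set_cos_sin_cache from the traceback; the cached tensors then live as ordinary attributes rather than non-persistent buffers):

# before
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

# after: plain attribute assignment, which TorchScript can compile
self.cos_cached = emb.cos().to(dtype)
self.sin_cached = emb.sin().to(dtype)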

Using a torch.export model in libtorch

How should a model exported with torch.export be saved, and can it be read from C++ libtorch? @FCV9

Answer

Yes. A model exported via torch.export can be saved and loaded in C++ libtorch, but extra steps are needed to ensure compatibility. A concrete guide follows:


1. Saving the exported model correctly (Python side)

# the current code already uses torch.export.save correctly
torch.export.save(llm_export, 'my_model/llm_export.pt')  # save as a PT file
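For context, the llm_export above comes from a torch.export.export call along these lines (a sketch with placeholder example inputs; the real args must match llm.forward's signature):

import torch

example_args = (torch.randint(0, 100, (1, 32)),)  # hypothetical shapes, for illustration only
llm_export = torch.export.export(cosyvoice.model.llm, example_args)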

2. Preparing to load in C++ libtorch

Requirements:

  • libtorch version ≥ 2.2 (must match the Python PyTorch version exactly)
  • build with -DUSE_EXPORT_MODEL=ON enabled

Example loading code (C++):

#include <torch/script.h>

// 1. load the exported model
auto model = torch::jit::load("my_model/llm_export.pt");
model.eval();

// 2. prepare inputs (shapes/types must match those used at export time in Python)
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::randn({...}));  // corresponds to text
// ... add the remaining inputs (same order as forward)

// 3. run inference
auto output = model.forward(inputs);

3. Key caveats

  1. Input/output contract

    • C++ inputs must exactly match the order and types of the args/kwargs used when exporting from Python
    • Even if the Python side passed kwargs, C++ still supplies arguments positionally in forward order
  2. Dynamic shape support

    • If dynamic_shapes was configured at export time, the C++ input sizes must stay within the declared dynamic ranges
    • e.g. with {0: "batch"} set in Python, the corresponding C++ dimension may vary (see the sketch after this list)
  3. Operator compatibility

    • Check whether the model contains Python operations that libtorch does not support (e.g. custom operators)
    • All operators can be inspected by printing the exported program's graph, e.g. print(exported_program.graph)
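A sketch of declaring a dynamic batch dimension at export time and then inspecting the graph (the Toy module and shapes are hypothetical placeholders):

import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

batch = Dim("batch")                      # dimension 0 may vary at run time
ep = export(Toy(), (torch.randn(4, 8),),  # the example input pins the other dims
            dynamic_shapes={"x": {0: batch}})
print(ep.graph)                           # lists every operator in the exported graph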

4. Verification steps (recommended)

  1. Test serialization/deserialization in Python

    loaded = torch.export.load('my_model/llm_export.pt')
    out1 = llm_export.module()(*args, **kwargs)  # the original exported program
    out2 = loaded.module()(*args, **kwargs)      # the reloaded one
    assert torch.allclose(out1, out2)            # check the two agree

  2. Check C++ usability

    # export to TorchScript (a fallback with broader compatibility)
    torch.jit.save(torch.jit.script(model), "llm_backup.pt")

5. Common problems and fixes

  • Unsupported operator errors
    Add strict=False when exporting:
    llm_export = torch.export.export(llm, llm_args, llm_kwargs, strict=False)

  • Memory leaks
    Wrap the C++ inference code in a torch::NoGradGuard no_grad; scope
  • Performance issues
    Call torch::jit::optimize_for_inference(model) in C++

With the steps above, the exported model should load and run in libtorch. It is best to test a simple static-shape model first, then verify the dynamic features step by step.
