2023-07-25 pytorch模型转ONNX模型

参考:

一、pytorch 模型保存、加载

有两种方式保存、加载pytorch模型:1)保存模型结构和参数;2)只保留模型参数。

同时保存模型结构和参数

import torch
model = ModelNet()
torch.save(model, "save.pt")
model = torch.load("save.pt")

只保存模型参数

import torch
model = ModelNet()
torch.save(model.state_dict(), "save.pt")
model.load_state_dict(torch.load("save.pt"))

二、pytorch模型转ONNX模型

根据pytorch模型可以保存模型结构和参数、只保留参数。

import torch

only_save_param = True

if only_save_param:
    model = ModelNet()                    # onnx文件中仅保留参数
else:
    model = torch.load("save.pt")     # onnx文件中同时保存模型结构和权重参数

batch_size = 1  #批处理大小
input_shape = (3,244,244)   #输入数据

# set the model to inference mode
torch_model.eval()

x = torch.randn(batch_size,*input_shape)                # 生成张量
export_onnx_file = "test.onnx"                  # 目的ONNX文件名
torch.onnx.export(torch_model,
                             x,
                             export_onnx_file,
                             opset_version=10,
                             do_constant_folding=True,         # 是否执行常量折叠优化
                             input_names=["input"],            # 输入名
                             output_names=["output"],                  # 输出名
                             dynamic_axes={"input":{0:"batch_size"},        # 批处理变量
                                                       "output":{0:"batch_size"}})

onnx_model = onnx.load(output_path)
try:
    onnx.checker.check_model(onnx_model)
except Exception:
    print("Model incorrect")
else:
    print("Model correct")

注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

三、torch.onnx.export() 详细介绍

export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
           input_names=None, output_names=None, operator_export_type=None,
           opset_version=None, do_constant_folding=True, dynamic_axes=None,
           keep_initializers_as_inputs=None, custom_opsets=None,
           export_modules_as_functions=False)

1. model

需要转换的模型,支持的模型类型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction

2. args (tuple or torch.Tensor)

  1. 一个tuple
args = (x, y, z)

这个tuple应该与模型的输入相对应,任何非Tensor的输入都会被硬编码入onnx模型,所有Tensor类型的参数会被当做onnx模型的输入。

  1. 一个Tensor
args = torch.Tensor([1, 2, 3])

一般这种情况下模型只有一个输入

  1. 一个带有字典的tuple
args = (x, {'y': input_y, 'z': input_z})

这种情况下,所有字典之前的参数会被当做“非关键字”参数传入网络,字典种的键值对会被当做关键字参数传入网络。如果网络中的关键字参数未出现在此字典中,将会使用默认值,如果没有设定默认值,则会被指定为None。

一个特殊情况,当网络本身最后一个参数为字典时,直接在tuple最后写一个字典则会被误认为关键字传参。所以,可以通过在tuple最后添加一个空字典来解决。

#错误写法:
 
torch.onnx.export(
    model,
    (x,
     # WRONG: will be interpreted as named arguments
     {y: z}),
    "test.onnx.pb")
 
# 纠正
 
torch.onnx.export(
    model,
    (x,
     {y: z},
     {}),
    "test.onnx.pb")

f

一个文件类对象或一个路径字符串,二进制的protocol buffer将被写入此文件.

export_params (bool, default True)

如果为True则导出模型的参数。如果想导出一个未训练的模型,则设为False.

verbose (bool, default False)

如果为True,则打印一些转换日志,并且onnx模型中会包含doc_string信息。

training (enum, default TrainingMode.EVAL)

枚举类型包括:
TrainingMode.EVAL - 以推理模式导出模型。
TrainingMode.PRESERVE - 如果model.training为False,则以推理模式导出;否则以训练模式导出。
TrainingMode.TRAINING - 以训练模式导出,此模式将禁止一些影响训练的优化操作。

input_names (list of str, default empty list)

按顺序分配给onnx图的输入节点的名称列表。

output_names (list of str, default empty list)

按顺序分配给onnx图的输出节点的名称列表。

operator_export_type (enum, default None)

默认为OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,则默认为OperatorExportTypes.ONNX_ATEN_FALLBACK。

枚举类型包括:

OperatorExportTypes.ONNX - 将所有操作导出为ONNX操作。

OperatorExportTypes.ONNX_FALLTHROUGH - 试图将所有操作导出为ONNX操作,但碰到无法转换的操作(如onnx未实现的操作),则将操作导出为“自定义操作”,为了使导出的模型可用,运行时必须支持这些自定义操作。支持自定义操作方法见链接

OperatorExportTypes.ONNX_ATEN - 所有ATen操作导出为ATen操作,ATen是Pytorch的内建tensor库,所以这将使得模型直接使用Pytorch实现。(此方法转换的模型只能被Caffe2直接使用)

OperatorExportTypes.ONNX_ATEN_FALLBACK - 试图将所有的ATen操作也转换为ONNX操作,如果无法转换则转换为ATen操作(此方法转换的模型只能被Caffe2直接使用)。例如:

opset_version (int, default 9)

默认是9。值必须等于_onnx_main_opset或在_onnx_stable_opsets之内。具体可在torch/onnx/symbolic_helper.py中找到。例如:

_default_onnx_opset_version = 9
 
_onnx_main_opset = 13
 
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12]
 
_export_onnx_opset_version = _default_onnx_opset_version

do_constant_folding (bool, default False)

是否使用“常量折叠”优化。常量折叠将使用一些算好的常量来优化一些输入全为常量的节点。

example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)

当需输入模型为ScriptModule 或 ScriptFunction时必须提供。此参数用于确定输出的类型和形状,而不跟踪(tracing )模型的执行。

dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)

通过以下规则设置动态的维度:

KEY(str) - 必须是input_names或output_names指定的名称,用来指定哪个变量需要使用到动态尺寸。

VALUE(dict or list) - 如果是一个dict,dict中的key是变量的某个维度,dict中的value是我们给这个维度取的名称。如果是一个list,则list中的元素都表示此变量的某个维度。

具体可参考如下示例:

class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)
 
 
 
# 以动态尺寸模式导出模型
 
torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                  input_names=["x"], output_names=["sum"],
                  dynamic_axes={
                      # dict value: manually named axes
                      "x": {0: "my_custom_axis_name"},
                      # list value: automatic names
                      "sum": [0],
                  })
 
### 导出后的节点信息
 
##input
 
input {
  name: "x"
  ...
      shape {
        dim {
          dim_param: "my_custom_axis_name"  # axis 0
        }
        dim {
          dim_value: 2  # axis 1
...
 
 
##output
output {
  name: "sum"
  ...
      shape {
        dim {
          dim_param: "sum_dynamic_axes_1"  # axis 0
...
 

keep_initializers_as_inputs (bool, default None)

custom_opsets (dict<str, int>, default empty dict)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容