ONNX删除节点、更换输出节点

背景

最近在AXERA M55H工具链做一个语义分割模型量化。
AXERA官方文档表示argmax只能接在conv算子后,而我的deeplabv3+模型最后两个节点是resize上采样接argmax。
试了一下模型转换(编译、量化成M55H支持的.joint模型),果然在执行argmax相关操作时报错。
于是只能手动将onnx文件的argmax节点删除,在后处理来做argmax了。

ONNX删除节点

由于模型只有结尾处有一个argmax节点,所以直接找到op_type == "ArgMax"的节点将其删除即可。

node_to_rm = next(node for node in model.graph.node if node.op_type == "ArgMax")
model.graph.node.remove(node_to_rm)
onnx.save(model, dst_model)

此时用生成的新model推理会报错,大概错误信息是output节点不在graph中。
查了一些资料发现model.graph.output和model.graph.node是平行的存在,也就是说输出节点是区别于中间节点独立存储在model.graph.output中的。(输入节点也类似)
上述的操作删除了最后一个argmax节点,但是没有删除输出节点。并且,一个graph必须包含1个以上的输入和输出节点。所以我们需要删除原有的输出节点并创建新的,即更换输出节点。

ONNX更换输出节点

model.graph.output是一个list,包含所有输出节点。
目前包含一个输出节点,就是之前的经过argmax的分割特征图。
输出节点跟普通节点的数据结构不同,它包含了节点名、输出的数据结构等信息,
因此只需要在现有节点基础上进行如下修改即可: (也可以通过onnx.helper创建新的节点)

node_to_out = next(node for node in model.graph.node if node.output == node_to_rm.input)  # 找到删除节点的上游节点,作为输出节点的前置
out = model.graph.output[0]       # 在原来的输出节点基础上改即可
out.name = node_to_out.output[0]  # 修改为新的输出节点名字
out.type.tensor_type.shape.dim[1].dim_value = 4   # 该维度指channel数,argmax以后为1,改为4,因为该模型有4类,onehot表示
out.type.tensor_type.elem_type = 1    # 1表示float32, 经过argmax后是7,表示int64

附录: onnx的elem_type

elem_type: 1 --> float32
elem_type: 2 --> uint8
elem_type: 3 --> int8
elem_type: 4 --> uint16
elem_type: 5 --> int16
elem_type: 6 --> int32
elem_type: 7 --> int64
elem_type: 8 --> string
elem_type: 9 --> boolean
elem_type: 10 --> float16
elem_type: 11 --> float64
elem_type: 12 --> uint32
elem_type: 14 --> uint64
elem_type: 15 --> complex128
elem_type: 16 --> bfloat16
from: https://blog.csdn.net/weixin_43945848/article/details/122474749

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容