背景
最近在AXERA M55H工具链做一个语义分割模型量化。
AXERA官方文档表示argmax只能接在conv算子后,而我的deeplabv3+模型最后两个节点是resize上采样接argmax。
试了一下模型转换(编译、量化成M55H支持的.joint模型),果然在执行argmax相关操作时报错。
于是只能手动将onnx文件的argmax节点删除,在后处理来做argmax了。
ONNX删除节点
由于模型只有结尾处有一个argmax节点,所以直接找到op_type == "ArgMax"的节点将其删除即可。
node_to_rm = next(node for node in model.graph.node if node.op_type == "ArgMax")
model.graph.node.remove(node_to_rm)
onnx.save(model, dst_model)
此时用生成的新model推理会报错,大概错误信息是output节点不在graph中。
查了一些资料发现model.graph.output和model.graph.node是平行的存在,也就是说输出节点是区别于中间节点独立存储在model.graph.output中的。(输入节点也类似)
上述的操作删除了最后一个argmax节点,但是没有删除输出节点。并且,一个graph必须包含1个以上的输入和输出节点。所以我们需要删除原有的输出节点并创建新的,即更换输出节点。
ONNX更换输出节点
model.graph.output是一个list,包含所有输出节点。
目前包含一个输出节点,就是之前的经过argmax的分割特征图。
输出节点跟普通节点的数据结构不同,它包含了节点名、输出的数据结构等信息,
因此只需要在现有节点基础上进行如下修改即可: (也可以通过onnx.helper创建新的节点)
node_to_out = next(node for node in model.graph.node if node.output == node_to_rm.input) # 找到删除节点的上游节点,作为输出节点的前置
out = model.graph.output[0] # 在原来的输出节点基础上改即可
out.name = node_to_out.output[0] # 修改为新的输出节点名字
out.type.tensor_type.shape.dim[1].dim_value = 4 # 该维度指channel数,argmax以后为1,改为4,因为该模型有4类,onehot表示
out.type.tensor_type.elem_type = 1 # 1表示float32, 经过argmax后是7,表示int64
附录: onnx的elem_type
elem_type: 1 --> float32
elem_type: 2 --> uint8
elem_type: 3 --> int8
elem_type: 4 --> uint16
elem_type: 5 --> int16
elem_type: 6 --> int32
elem_type: 7 --> int64
elem_type: 8 --> string
elem_type: 9 --> boolean
elem_type: 10 --> float16
elem_type: 11 --> float64
elem_type: 12 --> uint32
elem_type: 14 --> uint64
elem_type: 15 --> complex128
elem_type: 16 --> bfloat16
from: https://blog.csdn.net/weixin_43945848/article/details/122474749