import io
import numpy as np
import torch
import torch.onnx
from model.nets.yolo4 import YoloBody
from conf.my_conf import anchors_path, classes_path, model_path_train
# Module-level compute device: CUDA device index 0 when CUDA is available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def get_classes(classes_path, encoding=None):
    """Load class names from a text file, one name per line.

    Args:
        classes_path: Path to the class-names file.
        encoding: Text encoding passed to ``open``. ``None`` (the default)
            keeps the platform's locale encoding, matching previous behavior.

    Returns:
        list[str]: Class names in file order, with surrounding whitespace
        removed. Blank lines are skipped. A trailing empty line therefore
        does not add a bogus ``""`` class, which would inflate the class
        count derived from this list.
    """
    with open(classes_path, encoding=encoding) as f:
        # Strip each line, then drop the ones that end up empty.
        return [name for name in (line.strip() for line in f) if name]
def get_anchors(anchors_path):
    """Read anchor sizes from the first line of a comma-separated file.

    Only the first line is parsed. Every comma-separated token is converted
    to ``float``.

    Returns:
        numpy.ndarray: Array reshaped to ``(-1, 3, 2)`` (groups of three
        (w, h)-style pairs), with the order of the groups reversed along
        the first axis.
    """
    with open(anchors_path) as handle:
        first_line = handle.readline()
    values = list(map(float, first_line.split(',')))
    grouped = np.array(values).reshape([-1, 3, 2])
    # Reverse the group order (last group of three pairs comes first).
    return grouped[::-1]
def _load_matching_weights(model, model_path):
    """Copy checkpoint tensors into ``model`` where both name and shape match.

    Entries whose name is not in the model's state dict, or whose shape
    differs, are skipped instead of raising. This lets a partially
    compatible checkpoint still be loaded.

    Args:
        model: Module that receives the weights.
        model_path: Path to a checkpoint file holding a state dict.

    Returns:
        tuple[int, int]: (number of tensors copied from the checkpoint,
        number of entries in the model's state dict).
    """
    model_dict = model.state_dict()
    # The model and the dummy input stay on CPU for export, so load the
    # checkpoint straight to CPU instead of staging it on the GPU.
    pretrained_dict = torch.load(model_path, map_location="cpu")
    # Check membership first: indexing model_dict directly raised KeyError
    # for any checkpoint entry the model does not have.
    matched = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return len(matched), len(model_dict)


def test(model_path=r'D:\slife\service_project\object-detections-yolov4-1\data\weights\Epoch75-Total_Loss1.3492-Val_Loss3.9141.pth',
         onnx_path="YOLO.onnx", input_size=64):
    """Build YoloBody, load trained weights and export it to ONNX.

    Args:
        model_path: Checkpoint (state dict) to load. The default is the
            original hard-coded checkpoint path.
        onnx_path: Output file for the exported ONNX model.
        input_size: Height and width of the random dummy input used for
            tracing. The export uses a fixed input shape (no dynamic axes),
            so choose the resolution the ONNX model will be run at.
    """
    # -------------------------------#
    # Load anchors and class names
    # -------------------------------#
    class_names = get_classes(classes_path)
    anchors = get_anchors(anchors_path)
    num_classes = len(class_names)

    # Build the model. get_anchors reshapes to (-1, 3, 2), so len(anchors[0])
    # is the number of anchor pairs per group (3).
    model = YoloBody(len(anchors[0]), num_classes)

    print('Loading weights into state dict...')
    loaded, total = _load_matching_weights(model, model_path)
    # Mismatched tensors are skipped silently by the loader, so report how
    # much was actually loaded rather than exporting untrained weights unnoticed.
    print(f'Loaded {loaded}/{total} tensors from {model_path}')

    # Trace in inference mode so layers such as BatchNorm/Dropout (if present)
    # use their evaluation behavior in the exported graph.
    model.eval()

    dummy_input1 = torch.randn(1, 3, input_size, input_size)
    input_names = ["actual_input_1"]
    # NOTE(review): only one output name is given. If YoloBody returns several
    # tensors, the remaining outputs get auto-generated names. Confirm.
    output_names = ["output1"]
    # Multi-input example: build extra dummy inputs and pass them as a tuple.
    # dummy_input2 = torch.randn(1, 3, input_size, input_size)
    # dummy_input3 = torch.randn(1, 3, input_size, input_size)
    # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), onnx_path, verbose=True, input_names=input_names, output_names=output_names)
    torch.onnx.export(model, dummy_input1, onnx_path, verbose=True,
                      input_names=input_names, output_names=output_names)
if __name__ == "__main__":
    # Only run the ONNX export when executed as a script, not on import.
    test()
将 YoloBody 替换为需要转换的模型，并修改 model_path、输入（dummy_input）以及导出的 ONNX 模型文件名，然后运行脚本即可。
注意：上面代码中被注释掉的 dummy_input2、dummy_input3 以及对应的 torch.onnx.export 调用，是导出多输入模型的示例。