# Define the function that generates the XML annotation file
# coding:utf-8
import os
import numpy as np
import json
from PIL import Image
import cv2
def xml(num, width, height, labelname, box, imageName, imagePath):
    """Write a Pascal VOC-style annotation XML file.

    The whole document is assembled in memory first and written in one go,
    so a bad argument (e.g. too few boxes) fails before the output file is
    created or truncated.

    :param num: path of the XML file to write (overwritten if it exists)
    :param width: image width, written to <size>/<width>
    :param height: image height, written to <size>/<height>
    :param labelname: sequence of object labels, one <object> per entry
    :param box: sequence of (xmin, ymin, xmax, ymax) coordinates;
        box[i] belongs to labelname[i]
    :param imageName: image file name, written to <filename>
    :param imagePath: image path, written to <path>
    :return: the file object that was written; it is already closed
    :raises IndexError: if box (or one of its entries) is too short
    """
    # Local import keeps the function self-contained. The ``from`` form is
    # used because the bare name ``xml`` refers to this function here.
    from xml.sax.saxutils import escape

    lines = [
        '<annotation>\n',
        ' <folder>JPEGImages</folder>\n',
        # Escape free text so '&', '<' and '>' cannot break the markup.
        ' <filename>' + escape(str(imageName)) + '</filename>\n',
        ' <path>' + escape(str(imagePath)) + '</path>\n',
        ' <source>\n',
        ' <database>Unknown</database>\n',
        ' </source>\n',
        ' <size>\n',
        ' <width>' + str(width) + '</width>\n',
        ' <height>' + str(height) + '</height>\n',
        ' <depth>3</depth>\n',
        ' </size>\n',
        ' <segmented>0</segmented>\n',
    ]
    print(len(labelname))
    for i, name in enumerate(labelname):
        coords = box[i]
        lines.extend([
            ' <object>\n',
            ' <name>' + escape(str(name)) + '</name>\n',
            ' <pose>Unspecified</pose>\n',
            ' <truncated>0</truncated>\n',
            ' <difficult>0</difficult>\n',
            ' <bndbox>\n',
            ' <xmin>' + str(coords[0]) + '</xmin>\n',
            ' <ymin>' + str(coords[1]) + '</ymin>\n',
            ' <xmax>' + str(coords[2]) + '</xmax>\n',
            ' <ymax>' + str(coords[3]) + '</ymax>\n',
            ' </bndbox>\n',
            ' </object>\n',
        ])
    lines.append('</annotation>')

    # No XML declaration is written, so parsers assume UTF-8; write UTF-8
    # explicitly instead of relying on the platform's default encoding.
    # The context manager guarantees the data is flushed and the handle
    # released before returning.
    with open(num, 'w', encoding='utf-8') as xml_file:
        xml_file.writelines(lines)
    return xml_file
# Load the annotations
def load_annoataion(p):
    """Load boxes from a comma-separated annotation text file.

    Every non-blank line is split on commas. Fields 0, 1, 4 and 5 are
    parsed as integers and kept as one box [f0, f1, f4, f5] only when all
    four values are greater than 1. Each kept box gets the fixed label
    'text'.

    :param p: path of the annotation file (read as UTF-8; a leading BOM
        is ignored)
    :return: tuple ``(boxes, tags)`` - an int32 array of the kept boxes
        and a string array holding one 'text' label per box
    :raises ValueError: if a checked field is not an integer
    :raises IndexError: if a non-blank line has too few fields
    """
    # NOTE(review): presumably each line is x1,y1,x2,y2,x3,y3,x4,y4,... so
    # fields (0, 1) and (4, 5) are opposite corners - confirm with the data.
    text_polys = []
    text_tags = []
    label = 'text'
    with open(p, "r", encoding='UTF-8-sig') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) carry no annotation.
            if not line.strip():
                continue
            fields = line.split(',')
            corners = []
            # Parse lazily and stop at the first value that is not > 1,
            # matching the short-circuit of the original `and` chain.
            for k in (0, 1, 4, 5):
                value = int(fields[k])
                if value <= 1:
                    break
                corners.append(value)
            else:
                text_polys.append(corners)
                text_tags.append(label)
    # `np.str` was removed in NumPy 1.24; the builtin `str` is the
    # equivalent dtype.
    return np.array(text_polys, dtype=np.int32), np.array(text_tags, dtype=str)
# Main program
base_dir = "data"
# Script entry point: writes one annotation XML per entry found in the
# image directory, using the same fixed image size, label and box for all.
if __name__ == "__main__":
    xml_path = r'C:\Users\YYQ\Desktop\11_xml\\'  # output directory for the generated XML files
    img_path = r'C:\Users\YYQ\Desktop\11'  # directory holding the training images
    os.makedirs(xml_path, exist_ok=True)

    # List the directory once so the printed count matches what is processed.
    images = os.listdir(img_path)
    print(len(images))

    # The same size, label and box are written for every image.
    w = 1920
    h = 1080
    tag_list = ["waste_bag"]
    box_list = [(1086, 480, 1306, 582)]

    for image in images:
        print(image)
        # splitext handles extensions of any length ('.jpeg', '.png', none),
        # unlike slicing off a fixed four characters.
        stem, _ext = os.path.splitext(image)
        saveXml = os.path.join(xml_path, stem + ".xml")
        print(saveXml)
        imagePath = os.path.join(img_path, image)
        print(imagePath)
        xml(saveXml, w, h, tag_list, box_list, image, imagePath)