前言
之前,测试通了PyTorch版的yolo-v2/v3, ssd-mobilenetv1/v2目标检测代码。 相对于测试,如何用自己的数据训练一个目标检测模型才更令人兴奋。俗话曰:兵马未动,粮草先行, 在训练之前,首先需要准备好训练数据。
在许多例子中,一般都用VOC, COCO格式的数据集进行训练和测试。对于我们自己的数据,一般不是VOC/COCO格式的数据,所以一个比较笨的方法就是写一个脚本进行数据格式转换,再不济可以手动创建文件夹,直接把相应的数据复制到指定的目录,这样很麻烦。麻烦的地方主要在于:1. VOC中标签都是1张图像对应一个xml文件, xml结构数据本身解析相对麻烦,不如JSON,YAML轻巧。 2. 电脑中需要保存2份数据,一份用作VOC格式数据, 另一份是原始数据。
下面将直接使用原始数据,使用PyTorch提供的类对数据进行简单封装,实现数据集的索引和读取。然后再转换为VOC格式的数据。
快速起见, 采用一个公开数据集,Wider Face, 这个数据集用于做人脸检测,训练集合包含12k的图像,而且提供人脸矩形框标签
目的
- PyTorch实现对widerFace数据的API封装
- 将widerFace数据转换为VOC格式的数据
开发环境
- Ubuntu 18.04
- pycharm
- Anaconda3, python3.6
- PyTorch 1.0, torchvision
widerFace 人脸检测数据集
Wider Face
-
BaiDuYun链接: https://pan.baidu.com/s/1HjEsIzkQtS5ea2mOVoRFtA
标签
训练集
简单起见,将wider_face_train_bbx_gt.txt 复制到训练集合所在路径, images文件夹包含图像。
使用PyTorch Dataset API进行封装
- 代码
在PyTorch中,数据集定义很简单,按照PyTorch提供的套路就可以。 一般的, 首先定义一个类继承torch.utils.data.Dataset, 然后override __len__() 和 __getitem__() 方法。
- __len__(): 返回数据集容量大小
- __getitem__(): 返回数据集迭代时候每一个样本及其标签数据。
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transfroms
import matplotlib.pyplot as plt
import os
import PIL.Image as Image
import PIL
import cv2
import numpy as np
class WiderFaceDataset(Dataset):
    """Dataset over a WIDER FACE style ground-truth text file.

    The ground-truth file is read as a flat list of lines.  A line ending in
    ``.jpg`` or ``.bmp`` names an image (relative to ``images_folder``); the
    line after it holds the number of faces, and that many following lines
    each start with the ``x y w h`` fields of one face rectangle.

    Args:
        images_folder: Directory that the image names are relative to.
        ground_truth_file: Path of the annotation text file.
        transform: Optional callable applied to the opened PIL image.
        target_transform: Optional callable applied to every ``[x, y, w, h]``
            rectangle individually.
    """

    # Suffixes that mark an annotation line as an image entry.
    _IMAGE_SUFFIXES = ('.jpg', '.bmp')

    def __init__(self, images_folder, ground_truth_file, transform=None, target_transform=None):
        super(WiderFaceDataset, self).__init__()
        self.images_folder = images_folder
        self.ground_truth_file = ground_truth_file
        with open(ground_truth_file, 'r') as f:
            self.ground_truth = [line.rstrip() for line in f]
        self.images_name_list = []
        # Image name -> index of its line in ``self.ground_truth``.  Built once
        # so that __getitem__ does not rescan the whole file for every sample.
        self._line_of_image = {}
        for line_no, line in enumerate(self.ground_truth):
            if line.endswith(self._IMAGE_SUFFIXES):
                self.images_name_list.append(line)
                # setdefault keeps the first occurrence of a duplicated name.
                self._line_of_image.setdefault(line, line_no)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image entries found in the ground-truth file."""
        return len(self.images_name_list)

    def __getitem__(self, index):
        """Return a dict with keys ``'image'``, ``'label'`` and ``'image_name'``.

        ``'image'`` is the (optionally transformed) PIL image, ``'label'`` the
        list of (optionally transformed) ``[x, y, w, h]`` rectangles and
        ``'image_name'`` the full path of the image file.
        """
        image_name = self.images_name_list[index]
        # Locate the image's line in the annotation file.
        loc = self._search(image_name)
        rects = self._read_rects(loc)
        image_path = os.path.join(self.images_folder, image_name)
        image = PIL.Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            rects = [self.target_transform(rect) for rect in rects]
        return {'image': image, 'label': rects, 'image_name': image_path}

    def _read_rects(self, loc):
        """Parse the face rectangles of the image whose name is on line ``loc``.

        Raises:
            ValueError: If the face count or a rectangle line is malformed, or
                the file ends before all announced rectangles were read.  A
                ValueError is used on purpose: an IndexError escaping from
                ``__getitem__`` would silently end ``for sample in dataset``.
        """
        face_nums = int(self.ground_truth[loc + 1])
        first = loc + 2
        lines = self.ground_truth[first:first + face_nums]
        if len(lines) != face_nums:
            raise ValueError(
                'ground truth ends early: expected {} boxes after line {}, found {}'.format(
                    face_nums, loc + 1, len(lines)))
        rects = []
        for line in lines:
            # split() without an argument tolerates repeated/trailing blanks.
            fields = line.split()[:4]
            if len(fields) != 4:
                raise ValueError('malformed box line: {!r}'.format(line))
            rects.append([int(value) for value in fields])
        return rects

    def _search(self, image_name):
        """Return the line index of ``image_name``, or None if it is not an image entry."""
        return self._line_of_image.get(image_name)
if __name__ == '__main__':
    # Demo: load the first training sample and display it twice, once through
    # OpenCV with the face rectangles drawn and once through matplotlib using
    # the tensor produced by the transform.
    images_folder = '/media/weipenghui/Extra/WiderFace/WIDER_train/images'
    # Only the path is needed; the dataset opens and closes the file itself,
    # so no file handle is opened (and leaked) here.
    ground_truth_file = '/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt'
    dataset = WiderFaceDataset(images_folder=images_folder,
                               ground_truth_file=ground_truth_file,
                               transform=transfroms.ToTensor(),
                               target_transform=lambda x: torch.tensor(x))
    var = next(iter(dataset))
    image_transformed = var['image']
    label_transformed = var['label']
    image_name = var['image_name']

    # Move the first axis of the tensor to the end and rescale to uint8 so
    # that matplotlib can display it.
    image_transformed = image_transformed.numpy().transpose((1, 2, 0))
    image_transformed = np.floor(image_transformed * 255).astype(np.uint8)

    # Draw every labelled rectangle on a fresh copy read by OpenCV.
    image = cv2.imread(image_name)
    for rect in label_transformed:
        x, y, w, h = [value.item() for value in rect]
        cv2.rectangle(image, pt1=(x, y), pt2=(x + w, y + h), color=(255, 0, 0))
    cv2.imshow('image', image)
    cv2.waitKey(0)
    plt.imshow(image_transformed)
    plt.show()
    # To walk the whole dataset instead:
    # for i, sample in enumerate(dataset):
    #     print(i, sample['image'])
    #
    # print(len(dataset))
widerFace 转换为VOC格式数据
VOC格式
VOC数据的目录
标签xml文件
图像
xml格式
<annotation>
<folder>VOC2007</folder>
<filename>000001.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>341012865</flickrid>
</source>
<owner>
<flickrid>Fried Camels</flickrid>
<name>Jinky the Fruit Bat</name>
</owner>
<size>
<width>353</width>
<height>500</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>dog</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>48</xmin>
<ymin>240</ymin>
<xmax>195</xmax>
<ymax>371</ymax>
</bndbox>
</object>
<object>
<name>person</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>8</xmin>
<ymin>12</ymin>
<xmax>352</xmax>
<ymax>498</ymax>
</bndbox>
</object>
</annotation>
VOC转换过程
VOC的标签采用xml文件表示,因此我们需要将图像的标签写入到xml文件中。参考了许多资料,好多都是不完整的代码,于是就参考着写了一个xml生成的代码。
为了简单一些,再定义一个类WiderFaceSample表示数据集中的一个样本。 使用WiderFace进行人脸检测,目标只有2类,即人脸和背景, 因此在xml中直接将object/name 属性固定写为face。
import dataset
import numpy as np
import os
import shutil
from lxml.etree import Element, SubElement, tostring
import pprint
from xml.dom.minidom import parseString
from xml.dom.minidom import Document
class WiderFaceSample:
    """One WIDER FACE sample that can be exported as VOC-style files.

    Attributes are filled in by the caller before any ``save_*`` method is used.
    """

    def __init__(self):
        # List of [x, y, w, h] face rectangles.
        self.face_rects = []
        # Path of the source image file; used as the copy source in save_image.
        self.image_name = ''
        self.image = None
        self.image_width = 0
        self.image_height = 0

    def save_image(self, folder_path, new_name):
        """Copy the source image to ``folder_path/new_name``."""
        shutil.copy(src=self.image_name, dst=os.path.join(folder_path, new_name))

    def save_label_to_txt(self, folder_path, new_name):
        """Write one ``x y w h`` line per face rectangle to ``folder_path/new_name``."""
        with open(os.path.join(folder_path, new_name), 'w') as f:
            f.writelines(
                '{} {} {} {}\n'.format(rect[0], rect[1], rect[2], rect[3])
                for rect in self.face_rects
            )

    def save_label_to_xml(self, folder_path, new_xml_name, new_image_name):
        """Write the VOC annotation XML for this sample to ``folder_path/new_xml_name``.

        ``new_image_name`` is stored in the ``<filename>`` element.
        """
        doc = self._generate_xml('WiderFace_VOC', new_image_name, self.face_rects)
        xml_bytes = doc.toprettyxml(indent='\t', encoding='utf-8')
        # Write the encoded bytes directly: decoding them and writing in text
        # mode would re-encode with the locale's default encoding, which may
        # disagree with the "utf-8" announced in the XML declaration.
        with open(os.path.join(folder_path, new_xml_name), 'wb') as f:
            f.write(xml_bytes)

    @staticmethod
    def _add_text_child(doc, parent, tag, text):
        """Append ``<tag>text</tag>`` to ``parent`` and return the new element."""
        element = doc.createElement(tag)
        element.appendChild(doc.createTextNode(str(text)))
        parent.appendChild(element)
        return element

    def _generate_xml(self, folder_str, filename_str, face_rects):
        """Build the VOC annotation document for this sample.

        Args:
            folder_str: Text of the ``<folder>`` element.
            filename_str: Text of the ``<filename>`` element.
            face_rects: Iterable of ``[x, y, w, h]`` rectangles; each becomes
                one ``<object>`` named ``face`` whose ``<bndbox>`` holds
                ``x``, ``y``, ``x + w`` and ``y + h``.

        Returns:
            The ``xml.dom.minidom.Document`` rooted at ``<annotation>``.
        """
        # https://www.cnblogs.com/haigege/p/5712854.html
        # https://www.cnblogs.com/zjutzz/p/6847848.html
        # https://www.cnblogs.com/qw12/p/6185126.html
        doc = Document()
        annotation = doc.createElement('annotation')
        doc.appendChild(annotation)

        self._add_text_child(doc, annotation, 'folder', folder_str)
        self._add_text_child(doc, annotation, 'filename', filename_str)

        size = doc.createElement('size')
        self._add_text_child(doc, size, 'width', self.image_width)
        self._add_text_child(doc, size, 'height', self.image_height)
        # Depth is always written as 3.
        self._add_text_child(doc, size, 'depth', 3)
        annotation.appendChild(size)

        self._add_text_child(doc, annotation, 'segmented', 0)

        for rect in face_rects:
            obj = doc.createElement('object')
            # The object name is fixed to "face"; pose / truncated / difficult
            # are written as constants.
            self._add_text_child(doc, obj, 'name', 'face')
            self._add_text_child(doc, obj, 'pose', 'Left')
            self._add_text_child(doc, obj, 'truncated', 1)
            self._add_text_child(doc, obj, 'difficult', 0)

            # Convert [x, y, w, h] into corner coordinates.
            x_min, y_min = rect[0], rect[1]
            x_max, y_max = rect[0] + rect[2], rect[1] + rect[3]
            bndbox = doc.createElement('bndbox')
            self._add_text_child(doc, bndbox, 'xmin', x_min)
            self._add_text_child(doc, bndbox, 'ymin', y_min)
            self._add_text_child(doc, bndbox, 'xmax', x_max)
            self._add_text_child(doc, bndbox, 'ymax', y_max)
            obj.appendChild(bndbox)
            annotation.appendChild(obj)
        return doc
if __name__ == '__main__':
    # The WIDER FACE dataset as wrapped by the dataset module (no transforms,
    # so every sample carries a PIL image and plain-int rectangles).
    original_dataset = dataset.WiderFaceDataset(
        images_folder='/media/weipenghui/Extra/WiderFace/WIDER_train/images',
        ground_truth_file='/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt')
    voc_root = '/media/weipenghui/Extra/WiderFace/WiderFace_VOC'

    # Create the VOC directory tree.  makedirs(exist_ok=True) also creates
    # voc_root itself and keeps the script re-runnable.
    folders = ['Annotations', 'ImageSets', 'JPEGImages', 'SegmentationClass', 'SegmentationObject']
    for folder in folders:
        os.makedirs(os.path.join(voc_root, folder), exist_ok=True)
    for subset in ('Main', 'Layout', 'Segmentation'):
        os.makedirs(os.path.join(voc_root, 'ImageSets', subset), exist_ok=True)

    # Must match the 'Annotations' directory created above.
    annotations_dir = os.path.join(voc_root, 'Annotations')
    images_dir = os.path.join(voc_root, 'JPEGImages')
    trainval_path = os.path.join(voc_root, 'ImageSets', 'Main', 'trainval.txt')

    wfsample = WiderFaceSample()
    # The with-block guarantees trainval.txt is flushed and closed.
    with open(trainval_path, 'w') as train_txt:
        for i, sample in enumerate(original_dataset, 1):
            # Samples are renamed to a zero-padded running number: 000001, ...
            stem = str(i).zfill(6)
            wfsample.image_name = sample['image_name']
            wfsample.face_rects = sample['label']
            wfsample.image_width, wfsample.image_height = sample['image'].size[0], sample['image'].size[1]
            # Copy the image.
            wfsample.save_image(images_dir, stem + '.jpg')
            # Write the XML annotation.
            wfsample.save_label_to_xml(annotations_dir, new_xml_name=stem + '.xml',
                                       new_image_name=stem + '.jpg')
            # Record the sample id in the image-set list.
            train_txt.write(stem + '\n')
            print('Write: {}'.format(i))
最终的结果
-
主目录
-
Annotations
- ImageSets
- JPEGImages