一、数据加载相关函数
def load_data(self, data_path: str) -> list:
    """
    Collect every (image, label) pair from ``data_path`` that has at least one box.

    :param data_path: dataset location, forwarded to ``get_datalist``
    :return: list of dicts, one per usable image, for example
        {'img_path': 'datasets\\train\\img\\img_1.jpg',
         'img_name': 'img_1',
         'text_polys': array([[[377., 117.], [463., 117.], [465., 130.], [378., 130.]],
                              ...], dtype=float32),
         'texts': ['Genaxis Theatre', '[06]', '###', '62-03', 'Carpark', '###', '###'],
         'ignore_tags': [False, False, True, False, False, True, True]}
    """
    samples = []
    for img_path, label_path in get_datalist(data_path):
        annotation = self._get_annotation(label_path)
        if not len(annotation['text_polys']):
            # Images without a single valid box are reported and skipped.
            print('there is no suit bbox in {}'.format(label_path))
            continue
        sample = {'img_path': img_path, 'img_name': pathlib.Path(img_path).stem}
        sample.update(annotation)
        samples.append(sample)
    return samples
def _get_annotation(self, label_path: str) -> dict:
    """
    Parse one label file in the ``x1,y1,x2,y2,x3,y3,x4,y4,transcription`` format.

    Lines that cannot be parsed are reported and skipped. Boxes whose
    ``cv2.contourArea`` is not positive are skipped silently.

    :param label_path: path of a UTF-8 label file
    :return: {'text_polys': np.array of the kept boxes (each reshaped to (-1, 2)),
              'texts': transcriptions of the kept boxes,
              'ignore_tags': True where the transcription is in ``self.ignore_tags``}
    """
    boxes = []
    texts = []
    ignores = []
    with open(label_path, encoding='utf-8', mode='r') as f:
        for line in f:
            # Remove surrounding whitespace and a UTF-8 BOM (decoded as '\ufeff').
            # The old extra strip('\xef\xbb\xbf') removed real characters
            # (e.g. a trailing '»') from labels, so it is gone.
            line = line.strip().strip('\ufeff')
            if not line:
                continue  # blank lines are not malformed annotations
            # maxsplit=8 keeps commas that are part of the transcription (e.g. "1,000").
            params = line.split(',', 8)
            try:
                box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2))
                # Read the label before appending anything, so that a line
                # without a transcription can never leave boxes/texts/ignores
                # with different lengths.
                label = params[8]
                if cv2.contourArea(box) > 0:
                    boxes.append(box)
                    texts.append(label)
                    ignores.append(label in self.ignore_tags)
            except Exception:
                # Best-effort per line. Unlike a bare except, this does not
                # swallow KeyboardInterrupt/SystemExit.
                print('load label failed on {}'.format(label_path))
    return {
        'text_polys': np.array(boxes),
        'texts': texts,
        'ignore_tags': ignores,
    }
二、数据增强
预处理时使用的数据增强方式有：随机水平翻转、仿射变换（旋转）、随机缩放（scale: 0.5–3）和随机裁剪。其中前三种借助 imgaug 库实现。
下面的代码根据配置构建 imgaug 数据增强序列，并将同一组变换同时作用于图片和文本框。
class AugmenterBuilder:
    """Build imgaug augmenters from a list/dict configuration.

    Accepted forms:
      * ``None`` or an empty container -> ``None`` (no augmentation);
      * a list at the root -> ``iaa.Sequential`` of the built children;
      * a nested list ``[name, arg1, arg2, ...]`` -> ``getattr(iaa, name)(arg1, arg2, ...)``;
      * a dict ``{'type': name, 'args': {...}}`` -> ``getattr(iaa, name)(**args)``.
        A missing or empty ``'args'`` means "no keyword arguments".
    Anything else raises ``RuntimeError``.
    """

    def __init__(self):
        pass

    def build(self, args, root=True):
        # Check for an empty value only when the object has a length. Previously
        # len(42) raised TypeError before the RuntimeError below was reachable.
        if args is None or (hasattr(args, '__len__') and len(args) == 0):
            return None
        if isinstance(args, list):
            if root:
                return iaa.Sequential([self.build(value, root=False) for value in args])
            name, *params = args
            return getattr(iaa, name)(*[self.to_tuple_if_list(a) for a in params])
        if isinstance(args, dict):
            cls = getattr(iaa, args['type'])
            # Tolerate a missing 'args' key or an empty (None) 'args' entry.
            kwargs = args.get('args') or {}
            return cls(**{k: self.to_tuple_if_list(v) for k, v in kwargs.items()})
        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def to_tuple_if_list(self, obj):
        """Return ``tuple(obj)`` for lists, ``obj`` unchanged otherwise.

        NOTE(review): presumably done because imgaug interprets tuples and lists
        differently (ranges vs. discrete choices). Confirm against the imgaug docs.
        """
        if isinstance(obj, list):
            return tuple(obj)
        return obj
class IaaAugment:
    """Apply one imgaug pipeline to an image and to its text polygons.

    ``augmenter_args`` is passed to ``AugmenterBuilder().build``. When that
    yields a falsy augmenter, ``__call__`` returns the data untouched.
    """

    def __init__(self, augmenter_args):
        self.augmenter_args = augmenter_args
        self.augmenter = AugmenterBuilder().build(augmenter_args)

    def __call__(self, data):
        """Augment ``data['img']`` and ``data['text_polys']`` in place and return ``data``."""
        image = data['img']
        original_shape = image.shape
        if not self.augmenter:
            return data
        # Freeze the sampled parameters so that the image and the polygons
        # receive the same transform.
        frozen = self.augmenter.to_deterministic()
        data['img'] = frozen.augment_image(image)
        return self.may_augment_annotation(frozen, data, original_shape)

    def may_augment_annotation(self, aug, data, shape):
        """Replace ``data['text_polys']`` with the polygons transformed by ``aug``."""
        if aug is None:
            return data
        data['text_polys'] = np.array(
            [self.may_augment_poly(aug, shape, poly) for poly in data['text_polys']]
        )
        return data

    def may_augment_poly(self, aug, img_shape, poly):
        """Transform one polygon's vertices and return them as a list of (x, y) tuples."""
        on_image = imgaug.KeypointsOnImage(
            [imgaug.Keypoint(vertex[0], vertex[1]) for vertex in poly], shape=img_shape
        )
        moved = aug.augment_keypoints([on_image])[0]
        return [(kp.x, kp.y) for kp in moved.keypoints]
to_deterministic() 返回一个随机参数已固定的确定性增强器，使图片和关键点（文本框顶点）使用同一组随机参数；否则 augment_image 与 augment_keypoints 会各自重新采样（例如一个翻转、另一个不翻转），导致标注框与增强后图片中的目标不对应。
参考链接：imgaug 官方文档 https://imgaug.readthedocs.io/
示例：图一为原图（绿色框为原始标注），图二为增强后的结果（红色框为变换后的标注）。
图一
图二
# Demo: augment an ICDAR 2015 image together with its ground-truth boxes and
# save both images for comparison (green = original boxes, red = augmented boxes).
import cv2
import imgaug
import imgaug.augmenters as iaa
import numpy as np

img = cv2.imread("img_1.jpg", 1)
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError("img_1.jpg")

# NOTE(review): rotate=(10, 10) always rotates by exactly 10 degrees; use a
# range such as (-10, 10) if a random rotation is intended.
sequence = [iaa.Fliplr(p=0.5), iaa.Affine(rotate=(10, 10)), iaa.Resize(size=(0.5, 3))]
seq = iaa.Sequential(sequence)
# Fix the random parameters once. Using `seq` directly would make
# augment_image and augment_keypoints sample different flip/scale values,
# and the boxes would no longer match the augmented image.
aug = seq.to_deterministic()
aug_img = aug.augment_image(img)

pts = [[377, 117, 463, 117, 465, 130, 378, 130],
       [493, 115, 519, 115, 519, 131, 493, 131],
       [374, 155, 409, 155, 409, 170, 374, 170],
       [492, 151, 551, 151, 551, 170, 492, 170],
       [376, 198, 422, 198, 422, 212, 376, 212],
       [494, 190, 539, 189, 539, 205, 494, 206],
       [374, 1, 494, 0, 492, 85, 372, 86]]

# cv2.polylines needs int32 points. np.array of Python ints is int64 on most
# Linux builds, which OpenCV rejects.
for pt in pts:
    cv2.polylines(img, [np.array(pt, dtype=np.int32).reshape(-1, 2)], True, (0, 255, 0))
cv2.imwrite("img_1_gt.jpg", img)

aug_pts = []
for pt in pts:
    pt = np.array(pt).reshape(-1, 2)
    keypoints = [imgaug.Keypoint(p[0], p[1]) for p in pt]
    keypoints = aug.augment_keypoints(
        [imgaug.KeypointsOnImage(keypoints, shape=img.shape)])[0].keypoints
    poly = [(p.x, p.y) for p in keypoints]
    aug_pts.append(poly)

for pt in aug_pts:
    # Round rather than truncate the float keypoints before drawing.
    cv2.polylines(aug_img, [np.round(np.array(pt)).astype(np.int32)], True, (0, 0, 255))
cv2.imwrite("img_1_aug.jpg", aug_img)
随机裁剪图片
从图片中随机裁剪出包含文本的区域（裁剪边界只落在不被需关注文本框覆盖的行/列上，因此不会切断这些文本框），再缩放到 640×640；keep_ratio=True 时等比缩放，右侧/下方补零。
class EastRandomCropData:
    """Randomly crop a text-containing region and resize it to ``size``.

    Crop borders are drawn only from rows/columns that no non-ignored text
    polygon covers, so kept polygons are not cut through.

    :param size: output (width, height)
    :param max_tries: number of random crop candidates tried before falling back
        to the whole image
    :param min_crop_side_ratio: minimum crop side, as a fraction of the image side
    :param require_original_image: stored only; not used by this class
    :param keep_ratio: True -> one scale for both axes with zero padding at the
        bottom/right; False -> stretch the crop to ``size``
    """

    def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1,
                 require_original_image=False, keep_ratio=True):
        self.size = size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio
        self.require_original_image = require_original_image
        self.keep_ratio = keep_ratio

    def __call__(self, data: dict) -> dict:
        """
        Crop ``data['img']`` and map its polygons into the output image.

        :param data: dict with 'img', 'text_polys', 'ignore_tags' and 'texts'
        :return: the same dict with the resized crop and only the polygons (plus
            their texts and ignore tags) that still overlap the output image
        """
        im = data['img']
        text_polys = data['text_polys']
        ignore_tags = data['ignore_tags']
        texts = data['texts']
        all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
        crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
        crop = im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]

        target_w, target_h = self.size
        scale_w = target_w / crop_w
        scale_h = target_h / crop_h
        if self.keep_ratio:
            # Same scale on both axes; the unused bottom/right area stays zero.
            scale = min(scale_w, scale_h)
            scale_x = scale_y = scale
            w = max(1, int(crop_w * scale))  # max(1, ...) keeps cv2.resize valid
            h = max(1, int(crop_h * scale))
            padimg = np.zeros((target_h, target_w) + im.shape[2:], im.dtype)
            padimg[:h, :w] = cv2.resize(crop, (w, h))
            img = padimg
        else:
            # The crop is stretched to `size`, so each axis needs its own scale.
            # Scaling polygons by min(scale_w, scale_h) misplaced them.
            scale_x, scale_y = scale_w, scale_h
            w, h = target_w, target_h
            img = cv2.resize(crop, (target_w, target_h))

        text_polys_crop = []
        ignore_tags_crop = []
        texts_crop = []
        for poly, text, tag in zip(text_polys, texts, ignore_tags):
            poly = ((poly - (crop_x, crop_y)) * (scale_x, scale_y)).tolist()
            if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                text_polys_crop.append(poly)
                ignore_tags_crop.append(tag)
                texts_crop.append(text)
        data['img'] = img
        data['text_polys'] = np.float32(text_polys_crop)
        data['ignore_tags'] = ignore_tags_crop
        data['texts'] = texts_crop
        return data

    def is_poly_in_rect(self, poly, x, y, w, h):
        """Return True when every vertex of ``poly`` lies inside [x, x+w] x [y, y+h]."""
        poly = np.array(poly)
        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
            return False
        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        """Return True when the bounding box of ``poly`` does not touch [x, x+w] x [y, y+h]."""
        poly = np.array(poly)
        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
            return True
        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        """Split a sorted index array into runs of consecutive values (list of arrays)."""
        regions = []
        start = 0
        for i in range(1, axis.shape[0]):
            if axis[i] != axis[i - 1] + 1:
                regions.append(axis[start:i])
                start = i
        # The trailing run was previously dropped, so the right/bottom free
        # strip could never be chosen as a crop border in region-wise mode.
        if axis.shape[0] > 0:
            regions.append(axis[start:])
        return regions

    def random_select(self, axis, max_size):
        """Pick two values from ``axis`` and return them as (min, max), clipped to [0, max_size - 1]."""
        xx = np.random.choice(axis, size=2)
        xmin = np.clip(np.min(xx), 0, max_size - 1)
        xmax = np.clip(np.max(xx), 0, max_size - 1)
        return xmin, xmax

    def region_wise_random_select(self, regions, max_size):
        """Pick two regions (with replacement) and one value from each; return (min, max)."""
        selected_index = list(np.random.choice(len(regions), 2))
        # np.random.choice without `size` returns a scalar. int() on a size-1
        # array is deprecated since NumPy 1.25.
        selected_values = [int(np.random.choice(regions[index])) for index in selected_index]
        return min(selected_values), max(selected_values)

    def crop_area(self, im, text_polys):
        """Choose a crop window whose borders avoid every polygon in ``text_polys``.

        :return: (x, y, w, h). The whole image is returned when the polygons
            cover every row or every column, or when no candidate in
            ``max_tries`` is large enough and contains a polygon.
        """
        h, w = im.shape[:2]
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in text_polys:
            points = np.round(points, decimals=0).astype(np.int32)
            # Clip to the image. After rotation a box can stick out, and a
            # negative slice start would wrap around and mark the wrong end.
            minx = np.clip(np.min(points[:, 0]), 0, w)
            maxx = np.clip(np.max(points[:, 0]), 0, w)
            w_array[minx:maxx] = 1
            miny = np.clip(np.min(points[:, 1]), 0, h)
            maxy = np.clip(np.max(points[:, 1]), 0, h)
            h_array[miny:maxy] = 1
        # Crop borders are only drawn from rows/columns free of text.
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h
        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)
        for _ in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions, w)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions, h)
            else:
                ymin, ymax = self.random_select(h_axis, h)
            crop_w = xmax - xmin
            crop_h = ymax - ymin
            if crop_w < self.min_crop_side_ratio * w or crop_h < self.min_crop_side_ratio * h:
                continue  # area too small
            if any(not self.is_poly_outside_rect(poly, xmin, ymin, crop_w, crop_h)
                   for poly in text_polys):
                return xmin, ymin, crop_w, crop_h
        return 0, 0, w, h
参考链接
1、https://github.com/WenmuZhou/DBNet.pytorch
2、https://github.com/MhLiao/DB