一、最近工作
很久都没有写简书了,不知不觉进入这个热爱的行业已经七个月了,这其中基于机器学习做过简单的商品销量预测(还有seq2seq的方案),作为主力算法开发人员搭建了基于Tensorflow Hub的AutoML图像分类模块,其中空闲时间还学习了yolo、faster-rcnn等目标检测算法。现在的主要工作是开发在自然场景下端到端的OCR算法。代码Train部分整个流程已经写完,但是由于是自己写的算法,缺乏预训练模型,也就是必须要自己重新训练一个完整的网络。这其中也有一个巨大的难题,数据。
二、人工合成数据
网上公开的OCR方面的真实数据集少之又少,显然直接用真实数据集去训练效果肯定大打折扣,所以想到借助于人工合成的数据集预训练网络后在用真实数据集进行训练。感谢VGG 实验室这篇 CVPR2016 的Paper,提供了人工合成数据的方法,更重要的是公布了代码!!!并且提供了41G的基于谷歌图像生成的OCR数据!!!我这里因为懒直接使用了已经生成的数据,如果对字体方面有特殊需求的可以尝试自己去人工合成。附上论文和Github地址。
Paper:Synthetic Data for Text Localisation in Natural Images
GitHub : Source code
三、数据转换
下载下来的数据包含200个文件夹分别存放的图片,还有一个mat格式的标注信息。之前写好的读取数据的脚本的代码都是读取XML文件,这里的mat文件我也做了转换,因为标注的labels也需要做一定的清洗。我也就直接放代码了。其中涉及了mat文件格式的读取分析数据结构,XML文件的生成(之前没有生成过XML文件),python多进程处理(这部分比较重要,在数据处理上用的比较多,写过好几个处理数据的脚本都用了多进程,打算找个时间把这部分也总结一下)。
import os
import numpy as np
import scipy.io as sio
from PIL import Image
from xml.dom.minidom import Document
import multiprocessing as mp
import pandas as pd
def main_fun(i, impath,
imnames,
labels,
wordbb,
xml_path):
# 得到图片的路径和宽高
image_path = os.path.join(impath, imnames[0])
im = Image.open(image_path)
imgwidth, imgheight = im.size
# 创建DOM文档对象
doc = Document()
# 创建根元素
root = doc.createElement('annotation')
doc.appendChild(root)
# 创建filename size
doc_filename = doc.createElement('filename')
doc_size = doc.createElement('size')
# 创建 width height depth
doc_width = doc.createElement('width')
doc_height = doc.createElement('height')
doc_depth = doc.createElement('depth')
# 添加子节点
doc_size.appendChild(doc_width)
doc_size.appendChild(doc_height)
doc_size.appendChild(doc_depth)
root.appendChild(doc_filename)
root.appendChild(doc_size)
# 写入filename width height depth
doc_filename.appendChild(doc.createTextNode(imnames[0]))
doc_width.appendChild(doc.createTextNode(str(imgwidth)))
doc_height.appendChild(doc.createTextNode(str(imgheight)))
doc_depth.appendChild(doc.createTextNode('3'))
nums = wordbb.shape[-1]
# 提取label
label_list = []
for label in labels.tolist():
ll = label.strip().split('\n')
for l in ll:
label_list.extend(l.split())
if nums != len(label_list):
return {image_path: i}
# 确定bounding box 的坐标
bbox = wordbb.transpose((2, 1, 0))
for idx in range(nums):
p1 = bbox[idx].min(axis=0)
p2 = bbox[idx].max(axis=0)
# 创建object作为root子节点
doc_object = doc.createElement('object')
root.appendChild(doc_object)
# 创建name节点并作为object子节点,写入数据
doc_name = doc.createElement('name')
doc_object.appendChild(doc_name)
doc_name.appendChild(doc.createTextNode(label_list[idx]))
# 创建boundbox节点并作为object子节点
doc_bndbox = doc.createElement('bndbox')
doc_object.appendChild(doc_bndbox)
# 创建x1 x2 y1 y2节点为bndbox的子节点,写入数据
doc_x1 = doc.createElement('x1')
doc_y1 = doc.createElement('y1')
doc_x2 = doc.createElement('x2')
doc_y2 = doc.createElement('y2')
# 添加bndbox节点
doc_bndbox.appendChild(doc_x1)
doc_bndbox.appendChild(doc_y1)
doc_bndbox.appendChild(doc_x2)
doc_bndbox.appendChild(doc_y2)
# 写入object数据
doc_x1.appendChild(doc.createTextNode(str(p1[0])))
doc_y1.appendChild(doc.createTextNode(str(p1[1])))
doc_x2.appendChild(doc.createTextNode(str(p2[0])))
doc_y2.appendChild(doc.createTextNode(str(p2[1])))
xml_name = imnames[0].replace('/', '_') + '.xml'
path = os.path.join(xml_path, xml_name)
f = open(path, 'w')
doc.writexml(
f,
indent='\t',
newl='\n',
addindent='\t',
encoding='utf-8')
f.close()
if i % 1000 == 0:
print('steps:{}'.format(i))
return {'success': i}
if __name__ == '__main__':
# 为了方便我这里直接写死三个路径
mat_path = '/home/wjj/ocr_data/gt.mat'
impath = '/home/wjj/ocr_data/SynthText'
xml_path = '/home/wjj/ocr_data/SynthXML'
if not os.path.exists(xml_path):
os.mkdir(xml_path)
# 读取matlab文件
data = sio.loadmat(mat_path)
print('successful load {}'.format(mat_path))
labels = data['txt'][0]
imnames = data['imnames'][0]
wordbb = data['wordBB'][0]
# 创建进程池
pool = mp.Pool(25)
dict = dict()
result = []
for idx in range(len(labels)):
d = pool.apply_async(main_fun, (idx, impath,
imnames[idx],
labels[idx],
wordbb[idx],
xml_path,))
result.append(d)
# if idx % 1000 == 0:
# print('steps:{}'.format(idx))
pool.close()
pool.join()
for res in result:
dict.update(res.get())
df = pd.DataFrame(index=dict.keys(), data=np.array(list(dict.values())))
df.to_csv('/home/wjj/ocr_data/error_images.csv')
四、最后
因为懒,很多细节都没有写。但是之前很多学到的东西一直遗忘,还是希望自己能够坚持多写,哪怕只是简简单单的贴个代码。希望自己训练的模型能够达到理想的效果。