街景字符识别比赛所用的数据集包括图像和JSON标注。训练集数据包括3W张照片,验证集数据包括1W张照片。数据的标注使用JSON格式,并使用文件名进行索引。对于赛题和数据集的更多信息,可参考街景字符编码识别-赛题解析。
下面我们将构建读取比赛的数据集,首先生成数据名列表的csv文件以方便后面dataloader处理:
import os
import csv
DirList = os.listdir(ImgPath)
## write data list
with open(outPath+'train.csv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
for row in DirList:
writer.writerow([row])
读入图像数据和JSON数据:
## 读入图像
from PIL import Image
im = Image.open(mainPath+'mchar_train/mchar_train/000000.png')
## 读入JSON
## 读入后为一个字典对象,key为图像名,value为对应标签
import json
import numpy as np
with open(json_trainDir, 'r') as f:
data = json.load(f)
print(data['000000.png'])
print(data['000000.png']['label'])
输出
{'height': [219, 219],
'label': [1, 9],
'left': [246, 323],
'top': [77, 81],
'width': [81, 96]}
[1, 9]
读入数据后,在载入网络前为了增加训练集数据数量和类型,我们要进行数据增广。
数据增广包括几何变换类如平移,旋转,翻转,缩放;图像色彩分布改变如直方图均衡,亮度色度调整。也有一些针对特定任务的如加噪等。此处详细可参考数据增广之详细理解
对于街景字符识别任务,可利用torchvision很方便的进行数据增广。关于torchvision中transforms的使用,可参考pytorch中transform常用的几个方法
我们将赛题抽象为一个定长字符识别问题,在赛题数据集中大部分图像中字符个数为2-4个,最多的字符个数为6个。因此将问题抽象为6个字符的识别问题,字符abc填充为abcXXX,X取"10",“0”~ "9"对于标签0~9。
import numpy as np
import os
import csv
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
class TrainDataLoader(Dataset):
def __init__(self, root, csvPath, json_Dir):
data = []
self.root = root
with open(csvPath, 'r') as csvfile:
csv_reader = csv.reader(csvfile)
for row in csv_reader:
data.append(row[0])
with open(json_Dir, 'r') as f:
info = json.load(f)
self.dataList = data
self.InfoDict = info
self.num = len(self.dataList)
def __len__(self):
return self.num
def ImgProcess(self, img):
# ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
# RandomRotation(degrees, resample=False, expand=False, center=None) 在(-degrees,+degrees)之间随机旋转
# transforms.ToTensor, 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
transform = transforms.Compose([transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(30),
transforms.ToTensor()])
imTensor = transform(img) # H,W,C,N
return imTensor
def __getitem__(self, idx):
# print('data path: ', self.root+self.dataList[idx])
imgName = self.dataList[idx]
img = Image.open(self.root+imgName)
imgInfo = self.InfoDict[imgName]
imgTensor = self.ImgProcess(img)
label = imgInfo['label']
label += [10]*(6-len(label)) ## 标签字符填充
sample = {'image':imgTensor, 'label':label}
return sample
定义好DataLoader逐batch载入数据
import matplotlib.pyplot as plt
DataRootPath = 'dir-to-your-data'
trainImgPath = dataPath+'mchar_train/mchar_train/'
ValImgPath = dataPath+'mchar_val/mchar_val/'
trainLabelPath = dataPath+'mchar_train.json'
ValLabelPath = dataPath+'mchar_val.json'
BATCH_SIZE = 1
train_dataset = TrainDataLoader(trainImgPath , DataRootPath+'mchar_train/train.csv', trainLabelPath )
train_num = len(train_dataset)
train_loader = DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, shuffle = True)
for step, sample in enumerate(train_loader):
if(step==10): break #输出10个样本观察
imgData = sample['image']
label = sample['label']
print('label:', label)
print('img size: ',imgData.size())
imgNp = imgData.squeeze_(0).numpy().transpose(1,2,0)
plt.imshow(imgNp)
plt.show()
以上,我们就完成了数据的读取和增广,下一步是选取合适的baseline进行训练。