更新日志
- 【2019/01/07 02:43】开始记录自己用SVT数据集训练CRNN模型,写这句话的时候,我的毕业论文还没写完and明天是最后抽盲审的deadline,但是不知为何,我完全不方甚至有心情来写一波博文。
- 【2019/01/07 15:06】早上老师打电话来让我去把盲审抽了,感觉完全不方and也没抽中,感到墙裂舒适。
正文
由于实验室项目需要(实际上是写毕业论文时没有数据),随便上网搜了个数据集,于是发现了SVT数据集
>>>SVT数据集地址点这里<<<
因为不需要识别英文字母,所以用这货,没有别的原因,而且后来看到CRNN的论文里面也使用了SVT进行测试,且达到了很高的acc,所以也让我下定决心用SVT来训练PyTorch下的CRNN。这里使用github上某位大佬写的crnn pytorch版本。
>>>crnn.pytorch代码地址点这里<<<
准备SVT数据集
首先当然是用上面那个链接把SVT下载下来,压缩包中主要包含的是一个.m文件,一个.mat文件以及茫茫多的png文件,且分为训练集,测试集以及一个包含额外训练集的extra包,这里我就只用train文件夹进行训练,用test文件夹下的数据进行测试。由于项目中使用的是lmdb,且懒得再重新写一个Dataset,因此决定生成SVT的lmdb文件,同时也可以学习一下lmdb数据集具体的使用。
>>>生成lmdb数据集的方法看这里<<<
通过下载这个git目录tool下的create_dataset.py,生成自己的lmdb数据。
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
"""
Create LMDB dataset for CRNN training.
ARGS:
outputPath : LMDB output path
imagePathList : list of image path
labelList : list of corresponding groundtruth texts
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
create_dataset函数里主要要注意的参数是imagePathList和labelList,一个是图像文件名的列表,一个是标签的列表,注意,这里不是标签文件名的列表,如果使用了文件名的列表,在生成batch时会报错说字典里找不到“/”或者“.”,原因十分显然了,不要问我为什么知道的(也可能只有我会这么蠢吧,在这里折腾了半天)。在得到了训练集和测试集的lmdb数据后,就可以开始训练模型了。
安装warpctc的pytorch版本
在训练之前,需要装一个CTCLoss函数作为criterion,因为用0.4.0版本的话是没有这个东西的,这里就是天坑之一。按照上面crnn.pytorch代码的索引,来到warp-ctc这里clone下来然后make。按照教程装好后,就能用CTCLoss啦!。。。?
>>>warp-ctc的安装看这里<<<
安装好了warp-ctc可以试着import一下看看是否安装成功,就像这样↓↓↓
当然也可以试试自带的例子(反正我是没试过。。)
开始训练
在安好warpctc就可以使用CTCLoss来算损失函数了,下面是大佬给的训练的命令。
python train.py --adadelta --trainRoot {train_path} --valRoot {val_path} --cuda
然后就会发现。。。
看了别人发的issues,发现也有人这里报错,作者表示PyTorch 0.4下这里还没修好。
>>>issues 139<<<
可以通过加个参数--random_sample解决这个问题。改好后,就可以开始训练了,燃鹅,事情没有这么简单。使用SVT训练,我发现我的Loss全是0。
于是我死皮赖脸地又去翻别人的issues了,发现果然有老哥跟我一样。
>>>issues 167<<<
通过上面这个链接,看到warpctc那边也有人发现了出现这个问题的原因:
下面的评论中有人尝试使用了CPU进行了训练,结果是正确的。我自己测试了一下,发现的确如此。为了解决这个问题,我重新安装了至少5遍的pytorch,试了0.4.0,试了0.4.1,试了1.0.0和1.0.12【具体的版本号忘了】,也用conda和pip分别尝试了多次后,依旧不行。直到某一次在安装warpctc跟这个老哥报了一样的错:
>>>warpctc_pytorch with cpp_extension<<<
我看到了Soumith大佬的评论↓↓↓
原来1.0已经自带了CTCLoss啊,真是……%¥%¥#()*了,于是我马上重新装回了1.0,然后把warpctc给换成了nn.CTCLoss(),燃鹅这还没完,新的问题又出现了,在训练了几个iteration后损失函数会变成nan。
官方的答复在这,好像的确是个bug。并发了一个推,my god。
>>>具体的问题在这<<<
解决方法是用钩子函数把nan的梯度直接变为0。
终于啊,苍天啊,终于可以训练了,然后你又会发现adam训练不出来,于是我直接用了adadelta,并且把学习率调整为了0.01,这里就不再废话。
训练过程大概是这个样子的↓↓↓
华中大佬的论文里的正确率是在98.3%貌似,我这里50个epoch里最高只有95.5%。接下打算自己训练一下FOTS试试。