文献编号:5
文献著作信息:
Vision Transformer for Fast and Efficient Scene Text Recognition
论文地址
代码地址
18 May 2021
研究主题:
Scene text recognition
Transformer
Data augmentation
研究问题:
低参数量,低计算量的STR模型
主要是精度不变情况下的提速
研究原因:
因为大多数只考虑了识别的精度,并没有考虑到移动设备的需求
我的收获和疑问
为了平衡准确性、速度和效率的重要性,作者建议利用视觉转换器(Vit)的简单和高效的优势。例如数据高效图像转换器(Deit)
Deit.pdf (arxiv.org)
ViT证明,仅使用transformer的encoder(好几个连起来)就可以实现ImageNet识别中得到SOTA结果。
ViT继承了transformer的所有特性,包括速度和计算效率
作者的框架也是这样做的,因为图片,也需要位置编码
用了【我的参考文献3】的框架,相同框架下,才能更好的比较不同模型的性能
MJ和ST 各用50%,如果用100%性能会下降
再自己写论文时,可以把自己的参数设置写成表格呈现给审稿人
研究设计:
作者试图平衡准确性、速度和效率。准确性是指识别文本的正确性。速度是通过单位时间内处理多少文本图像来衡量的。效率可以用处理一张图像所需的参数和计算(如FLOPS)的数量近似表示。参数的数量反映内存需求,而FLOPS估计完成任务所需的指令数量。理想的STR是精确和快速的,而只需要很少的计算资源。
研究发现:
使用Deit的模型权重,Deit简单地是通过知识蒸馏训练的VIT
对于机器来说,在人类环境中阅读文本是一项具有挑战性的任务,因为符号的可能外观不同。图2显示了受曲率、字体样式、模糊、旋转、噪声、几何图形、照明、遮挡和分辨率影响的文本的示例。还有许多其他因素可能会影响文本图像,例如天气条件、相机传感器缺陷、运动、照明等。
研究结论:
通过针对STR的数据增强,ViTSTR可以显著提高准确性,特别是对于不规则数据集。
当规模扩大时,ViTSTR保持在前沿,以平衡精度、速度和计算要求。
带问题看论文:
相关工作
字符串以正确的顺序标识图像中文本的每个字符。与通常只有一类对象的对象识别不同,对于给定的文本图像,可以有零个或多个字符。因此,STR模型更加复杂。与许多视觉问题类似,早期的方法[24,38]使用手工制作的特征,导致性能较差。深度学习极大地推动了STR领域的发展。
2019年,Baek等人提出【我的参考文献编号3】。[1]提出了一个对现代STR设计模式进行建模的框架。图3显示了STR的四个阶段或模块。从广义上讲,即使是最近提出的基于变压器的模型、无递归序列对序列文本识别器(NRTR)[29]和自注意文本识别网络(SATRN)[18]等方法也可以适用于校正-特征提取(Backbone)-序列建模-预测框架
校正阶段去除文字图像的失真,使文本水平或规范化。这使得特征提取(Backbone)模块更容易确定不变特征。薄板样条(TPS)[3]通过寻找和校正基准点来模拟畸变。RARE(带有自动校正的健壮文本识别器)[31]、STAR-Net(空间注意残留网络)[21]和TRBA (TPS- resnet - bilstm -Attention)[1]使用TPS。ESIR(端到端可训练场景文本识别)[41]采用迭代校正网络,显著提高了文本识别模型的性能。在某些情况下,没有采用整改,如CRNN卷积循环神经网络[30],R2AM(带有注意力建模的递归循环神经网络)[17],GCRNN(门控循环卷积神经网络)[36]和Rosetta[4]
特征提取(Backbone)阶段的作用是自动确定每个字符符号的不变特征。STR在对象识别任务中使用相同的特征提取器,如VGG[32]、ResNet[11]和CNN的一个变体RCNN[17]。Rosetta, STAR-Net和TRBA使用ResNet。利用VGG提取RARE和CRNN特征。R2AM和GCRNN建立在RCNN的基础上。基于变压器的模型NRTR和SATRN使用定制的CNN块来提取变压器编码器-解码器文本识别的特征
预测阶段检查由主干或序列建模产生的特征,以达到字符预测序列。CTC(连接主义时间分类)[8]通过有效地对所有可能的输入-输出序列对齐进行求和,最大限度地提高了输出序列的可能性。CTC的替代方案是注意力机制[2],它学习图像特征和符号之间的对齐。CRNN, GRCNN, Rosetta和STAR-Net使用CTC。R2AM, RARE和TRBA是基于注意力的
与自然语言处理(NLP)一样,变压器通过并行的自我注意和预测克服了序列建模和预测。这就产生了一个快速有效的模型。如图3所示,基于电流互感器的STR模型仍然需要一个骨干和一个变压器编码器-解码器。最近,ViT[7]证明了它可以在ImageNet1k[28]分类上仅使用变压器编码器,但在非常大的数据集(如ImageNet21k和JFT-300M)上预先训练它,从而击败诸如ResNet[11]和efficiency entnet[33]等深度网络的性能。DeiT[34]证明了ViT不需要大数据集,甚至可以获得更好的结果,但必须使用知识蒸馏[13]进行训练。ViT是使用预先训练的DeiT权重的基础,是我们提出的快速有效的STR称为ViTSTR的基础。如图3所示,ViTSTR是一个非常简单的模型,只有一级,可以轻松地将基于变压器的STR的参数数量和FLOPS减少一半。
ViT和ViTSTR之间的唯一区别是预测头。ViTSTR必须识别具有正确序列顺序和长度的多个字符,而不是单一对象类识别。预测是并行进行的
在原始的ViT中,使用与可学习类嵌入相对应的输出向量进行对象类别预测。在ViTSTR中,这对应于[GO]令牌。此外,我们不再只提取一个输出向量,而是从编码器中提取多个特征向量。这个数字等于数据集中文本的最大长度加上两个[GO]和[s]令牌。我们使用[GO]标记标记文本预测的开始,并使用[s]标记注明结尾或空格。[s]在每个文本预测的末尾重复,直到最大序列长度,以标记文本字符之后没有任何内容。
图5显示了一个编码器块内的层。每个输入都经过层归一化(LN)。多头自注意层(Multi-head Self-Attention layer, MSA)确定特征向量之间的关系。Vaswani等人[35]发现,使用多个头部而不是一个头部可以让模型共同关注来自不同位置的不同表示子空间的信息。头部数为h。多层感知器(Multilayer Perceptron, MLP)进行特征提取。它的输入也是层规范化的。MLP由2层组成,GELU激活[12]。LN的输出与MSA/MLP之间存在残差连接。
or l = 1…L为编码器块的深度或数量
for i = 1…S是[GO]和[S]令牌的最大文本长度加2。表1总结了ViTSTR配置。
作者用了【我的参考文献3】的框架
为了对不同的模型做出公平的评价,一个统一的框架是很重要的。统一的框架确保在评估中使用一致的训练和测试条件。下面的讨论描述了在性能比较中一直存在争议的训练数据集和测试数据集。使用不同的训练和测试数据集可能会严重倾向于支持或反对某种性能报告。
数据集
由于缺乏大数据集的真实数据,STR模型训练的实践是使用合成数据。使用两个流行的数据集:1)MJSynth (MJ)[14]或也称为Synth90k和2)SynthText (ST)[9]。
MJ
MJSynth (MJ)是一个合成生成的数据集,由890万逼真的文字图像组成。MJSynth被设计成有3层:1)背景,2)前景和3)可选的阴影/边框。它使用了1400种不同的字体。字体的字距、粗细、下划线和其他属性是不同的。MJSynth还利用了不同的背景效果,边界/阴影渲染,基础着色,投影失真,自然图像混合和噪声。
ST
SynthText (ST)是另一个由550万单词图像合成生成的数据集。SynthText是通过在自然图像上混合合成文本生成的。它使用场景几何、纹理和表面法线来自然地混合和扭曲图像中物体表面上的文本渲染。与MJSynth类似,SynthText的文本使用随机字体。文字图像是从嵌入合成文本的自然图像中裁剪出来的
在STR框架中,每个数据集占整个列车数据集的50%。将两个数据集100%地结合在一起会导致性能下降[我的参考文献3]。图6显示了来自MJ和ST的示例图像
测试数据集是由几个小的公开的自然图像文本STR数据集组成的。这些数据集通常分为两组:1)常规和2)不规则
常规数据集的文本图像是正面的,水平的,并且有最小的失真。IIIT5K-Words[23],街景文本(SVT) [37], ICDAR2003 (IC03)[22]和ICDAR2013 (IC13)[16]被认为是常规数据集。同时,不规则数据集包含具有挑战性外观的文本,如弯曲、垂直、透视、低分辨率或扭曲。ICDAR2015 (IC15)[15]、SVT Perspective (SVTP)[25]和CUTE80 (CT)[27]属于不规则数据集。图7显示了来自规则和不规则数据集的样本。对于两个数据集,只有测试分割用于评估
规则数据集
IIT5K包含3000张用于测试的图像。图像大多来自街景,如招牌、品牌标志、门牌号或路牌。
SVT有647张图片用于测试。文本图像是从谷歌街景图片裁剪。
IC03包含来自ICDAR2003健壮阅读比赛的1,110张测试图像。图像是从自然场景中捕捉的。在删除长度小于3个字符的单词后,结果是860张图像。然而,另外7张图片被发现丢失了。因此,该框架还包含867个测试图像版本。—IC13是
IC03的扩展,共享类似的镜像。IC13是为ICDAR2013健壮阅读比赛而创建的。在文献和框架中,使用了两个版本的测试数据集:1)857和2)1015。
不规则的数据集
IC15有ICDAR2015健壮阅读比赛的文本图片。许多图像模糊、嘈杂、旋转,有时分辨率很低,因为这些图像是使用谷歌眼镜拍摄的,佩戴者处于无约束运动状态。文献和框架中使用了两个版本:1)1811张和2)2077张图像。2077个版本包含旋转、垂直、透视和弯曲的图像。
SVTP有645张来自谷歌街景的测试图像。大多数是商业标牌的图片。-
CT专注于从衬衫和产品标志中捕获的弯曲文本图像。该数据集有288张图像。
表2列出了框架中推荐的培训配置。我们复制了几个强基线模型的结果:CRNN, R2AM, GCRNN, Rosetta, RARE, STAR-Net和TRBA,以与ViTSTR进行公平的比较。我们使用不同的随机种子对所有模型进行至少5次训练。保存测试数据集上表现最好的权重以获得平均评估分数。
对于ViTSTR,我们使用相同的列车配置,除了输入被调整为224 × 224,以匹配预训练的DeiT[34]的尺寸。在训练ViTSTR之前,会自动下载DeiT预训练的权重文件。ViTSTR可以端到端训练,没有冻结参数
表3和表4显示了不同模型的性能得分。我们报告了准确性、速度、参数数量和FLOPS,以得到折衷的总体情况,如图1所示。为了准确性,我们在大多数STR模型的大小写敏感训练和大小写不敏感评估中遵循框架评估协议。对于速度,报告的数字是基于2080Ti GPU上的模型运行时间。与其他模型基准(如[19,20])不同,在评估之前,我们不旋转垂直文本图像(例如,表5 IC15)。
数据增强
使用专门针对STR的数据增强配方可以显著提高ViTSTR的准确性,在图8中,我们可以看到不同之处
数据扩充会改变图像,但不会改变其中文本的含义。表3显示,对不同的图像变换(如反转、弯曲、模糊、噪声、扭曲、旋转、拉伸/压缩、透视和收缩)应用RandAugment[6]后,ViTSTR-TINY的通用性提高了+1.8%,ViTSTR-Small的通用性提高了+1.6%,ViTSTR-Base的通用性提高了1.5%。准确率提高最大的是不规则数据集,例如CT(+9.2%极小,+6.6%小和基本)、SVTP(+3.8%极小,+3.3%小,+1.8%基本)、IC15 1,811(+2.7%极小,+2.6%小,+1.7%基本)和IC15 2,077(+2.5%极小,+2.2%小,+1.5%基本)。
注意力
图9显示了ViTSTR读出文本图像时的注意图。当注意力适当地集中在每个字符上时,ViTSTR也会关注相邻的字符。也许,上下文是在单个符号预测期间放置的。
STR模型的 性能惩罚
在STR模型中每增加一个阶段,就会获得一个精度,但代价是速度变慢和计算量增加。例如,RARE↪→TRBA提高了2.2%的准确率,但需要388m的参数,并将任务完成速度降低了4 msec/image。像STAR-Net↪→TRBA那样将CTC阶段替换为Attention,将计算速度从8.8 msec/张图像显著降低到22.8 msec/张图像,从而获得额外的2.5%的精度。事实上,从CTC到Attention的变化所带来的放缓,与在管道中添加BiLSTM或TPS相比,是> 10倍。在ViTSTR中,从小版本到小版本的过渡需要增加嵌入尺寸和头部数量。不需要额外的阶段。为了获得2.3%的精度,性能损失是参数数量增加16.1M。从微小到基本,获得3.4%的精度的性能惩罚是额外的80.4M参数。在这两种情况下,速度几乎没有变化,因为我们在MLP和MSA中使用了相同的并行张量点积、softmax和加法运算层的变压器编码器。只有张量维度增加,导致任务完成速度降低0.2到0.3 msec/图像。与多级STR不同,额外的模块需要额外的连续的前向传播层,这不能并行化,从而导致显著的性能损失
失败案例
表5显示了ViTSTR-Small在每个测试数据集中失败的预测样本。导致预测错误的主要原因是相似符号混淆(如8和B, J和I),脚本字体(如Inc中的I),字符眩光,垂直文本,严重弯曲的文本图像和部分遮挡的符号。请注意,在某些情况下,即使是人类读者也很容易犯错误。然而,人类使用语义来解决歧义。语义已经在最近的STR方法中使用了[26,39]
代码阅读
def get_args(is_train=True):
parser = argparse.ArgumentParser(description='STR')
# for test
parser.add_argument('--eval_data', required=not is_train, help='path to evaluation dataset')
parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets')
parser.add_argument('--calculate_infer_time', action='store_true', help='calculate inference timing')
parser.add_argument('--flops', action='store_true', help='calculates approx flops (may not work)')
# for train
parser.add_argument('--exp_name', help='Where to store logs and models')
parser.add_argument('--train_data', required=is_train, help='path to training dataset')
parser.add_argument('--valid_data', required=is_train, help='path to validation dataset')
parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting')
parser.add_argument('--workers', type=int, help='number of data loading workers. Use -1 to use all cores.', default=4)
parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
parser.add_argument('--saved_model', default='', help="path to model to continue training")
parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
parser.add_argument('--sgd', action='store_true', help='Whether to use SGD (default is Adadelta)')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode')
""" Data processing """
parser.add_argument('--select_data', type=str, default='MJ-ST',
help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
parser.add_argument('--batch_ratio', type=str, default='0.5-0.5',
help='assign ratio for each selected data in the batch')
parser.add_argument('--total_data_usage_ratio', type=str, default='1.0',
help='total data usage ratio, this ratio is multiplied to total number of data.')
parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
parser.add_argument('--rgb', action='store_true', help='use rgb input')
parser.add_argument('--character', type=str,
default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
""" Model Architecture """
parser.add_argument('--Transformer', action='store_true', help='Use end-to-end transformer')
choices = ["vitstr_tiny_patch16_224", "vitstr_small_patch16_224", "vitstr_base_patch16_224", "vitstr_tiny_distilled_patch16_224", "vitstr_small_distilled_patch16_224"]
parser.add_argument('--TransformerModel', default=choices[0], help='Which vit/deit transformer model', choices=choices)
parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
parser.add_argument('--FeatureExtraction', type=str, required=True,
help='FeatureExtraction stage. VGG|RCNN|ResNet')
parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. None|CTC|Attn')
parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
parser.add_argument('--input_channel', type=int, default=1,
help='the number of input channel of Feature extractor')
parser.add_argument('--output_channel', type=int, default=512,
help='the number of output channel of Feature extractor')
parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
# selective augmentation
# can choose specific data augmentation
parser.add_argument('--issel_aug', action='store_true', help='Select augs')
parser.add_argument('--sel_prob', type=float, default=1., help='Probability of applying augmentation')
parser.add_argument('--pattern', action='store_true', help='Pattern group')
parser.add_argument('--warp', action='store_true', help='Warp group')
parser.add_argument('--geometry', action='store_true', help='Geometry group')
parser.add_argument('--weather', action='store_true', help='Weather group')
parser.add_argument('--noise', action='store_true', help='Noise group')
parser.add_argument('--blur', action='store_true', help='Blur group')
parser.add_argument('--camera', action='store_true', help='Camera group')
parser.add_argument('--process', action='store_true', help='Image processing routines')
# use cosine learning rate decay
parser.add_argument('--scheduler', action='store_true', help='Use lr scheduler')
parser.add_argument('--intact_prob', type=float, default=0.5, help='Probability of not applying augmentation')
parser.add_argument('--isrand_aug', action='store_true', help='Use RandAug')
parser.add_argument('--augs_num', type=int, default=3, help='Number of data augment groups to apply. 1 to 8.')
parser.add_argument('--augs_mag', type=int, default=None, help='Magnitude of data augment groups to apply. None if random.')
# for comparison to other augmentations
parser.add_argument('--issemantic_aug', action='store_true', help='Use Semantic')
parser.add_argument('--isrotation_aug', action='store_true', help='Use ')
parser.add_argument('--isscatter_aug', action='store_true', help='Use ')
parser.add_argument('--islearning_aug', action='store_true', help='Use ')
# orig paper uses this for fast benchmarking
parser.add_argument('--fast_acc', action='store_true', help='Fast average accuracy computation')
parser.add_argument('--infer_model', type=str,
default=None, help='generate inference jit model')
parser.add_argument('--quantized', action='store_true', help='Model quantization')
parser.add_argument('--static', action='store_true', help='Static model quantization')
args = parser.parse_args()
return
传参
opt = get_args()
模型
请忽略缩进,需要源代码可去github上下载
class Model(nn.Module):
def __init__(self, opt):
super(Model, self).__init__()
self.opt = opt
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction,
'ViTSTR': opt.Transformer}
""" Transformation """
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')
if opt.Transformer:
self.vitstr= create_vitstr(num_tokens=opt.num_class, model=opt.TransformerModel)
return
""" FeatureExtraction """
if opt.FeatureExtraction == 'VGG':
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'RCNN':
self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'ResNet':
self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
else:
raise Exception('No FeatureExtraction module specified')
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
""" Sequence modeling"""
if opt.SequenceModeling == 'BiLSTM':
self.SequenceModeling = nn.Sequential(
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
self.SequenceModeling_output = opt.hidden_size
else:
print('No SequenceModeling module specified')
self.SequenceModeling_output = self.FeatureExtraction_output
""" Prediction """
if opt.Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == 'Attn':
self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
else:
raise Exception('Prediction is neither CTC or Attn')
def forward(self, input, text, is_train=True, seqlen=25):
""" Transformation stage """
if not self.stages['Trans'] == "None":
input = self.Transformation(input)
if self.stages['ViTSTR']:
prediction = self.vitstr(input, seqlen=seqlen)
return prediction
""" Feature extraction stage """
visual_feature = self.FeatureExtraction(input)
visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
visual_feature = visual_feature.squeeze(3)
""" Sequence modeling stage """
if self.stages['Seq'] == 'BiLSTM':
contextual_feature = self.SequenceModeling(visual_feature)
else:
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
""" Prediction stage """
if self.stages['Pred'] == 'CTC':
prediction = self.Prediction(contextual_feature.contiguous())
else:
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)
return prediction
没有数据增强的训练
RANDOM=$$
CUDA_VISIBLE_DEVICES=0 python3 train.py --train_data data_lmdb_release/training
--valid_data data_lmdb_release/evaluation --select_data MJ-ST
--batch_ratio 0.5-0.5 --Transformation None --FeatureExtraction None \
--SequenceModeling None --Prediction None --Transformer
--TransformerModel=vitstr_tiny_patch16_224 --imgH 224 --imgW 224
--manualSeed=$RANDOM --sensitive
无特征提取,序列模型,只有transformer