本来想把关于CTC的所有东西都写在一篇文章,但后面发现内容太多,遂拆分成如下几个部分:
CTC算法详解之训练篇
CTC算法详解之测试篇
CTC算法详解之总结展望篇(待更)
引言
在日常生活中,许多数据是序列化的,比如语音、文字和图像文本等。在处理序列任务时,一个经典的思路是“分而治之”,把输入序列拆分成最小语义单元,然后将序列任务转换成对单元的分类任务。然而在实际应用中,把序列中的单元精准地分开是很难的,人工标注的代价也很大,可不可以直接对序列数据进行“端到端”地预测呢?CTC(Connectionist Temporal Classification)则是解决了这样一个问题。CTC算法可以让以端到端地方式对序列数据进行学习,在语音识别、图像文字识别等领域取得了很好的应用效果。
本文先对CTC用于序列任务的流程做了大致介绍,并定义了相关的符号。然后再训练一节中介绍了如何通过前后向算法计算CTC的损失函数。yudonglee的博客[2]给了我很多帮助,训练一节中的很多图也是采用他的。我在写的过程中也是在不断学习,有错误和不到位的地方希望大家指出。
任务定义
CTC用于序列任务,流程大致如下:神经网络把输入序列转换成序列在字典上的概率分布,从这个分布中我们可以得到若干条路径,每个路径都可以转换成输出序列,我们的任务就是找到输出概率最大的序列。具体的符号定义如下:
输入序列长度为,用表示神经网络,用于提取序列特征,网络的输出为,长度也为,用 表示输出单元的激活概率,即序列在时刻被分类成的概率,定义在类别集合上,为任务字典符号集,比如在文字识别任务中可以定义成中英文字符,为CTC的 保留符号,用于分隔标签中的不同符号单元。CTC是一种过分割的序列解码算法,比如标签中的一个字符a可能在译码路径中被切分成多个连续的a,而标签中也可能存在连续但应该被区分的字符,比如apple中出现了两个p,那这时候译码路径要在这两个不同的p之间插入至少一个。由网络的输出可以计算任意译码路径的概率。与输入等长,我们最终是要得到标签,的长度小于的长度,所以还要定义一个映射来将译码路径转换为标签。映射规则为:移除所有空白符号并合并所有的重复连续符号,比如。可以看到,这个映射是多对一映射(many-to-one),也就是说正确的标签可以来自许多不同的路径(不管是黑猫还是白猫,只要能捉到耗子就是好猫),后面我们会重点研究many-to-one所带来的计算速度和可微问题,尤其是在训练阶段。整个序列识别的流程可以参见下图:
训练
为了使整个网络可用梯度下降优化,训练过程中必须算出可导的CTC损失函数,CTC也采用了常规的分类任务的最大似然误差(maximum likelihood error):。因为B是many-to-one映射的缘故,,计算Loss要穷举所有可行路径,然而穷举所有的路径是非常困难的,因为其空间复杂度为(N为字典大小,T为路径长度),所以[1]借鉴了HMM中的前后向算法(Forward-Backward Algorithm,FBA),这是一种动态规划算法,下面我们来说一下算法思路。
在种路径中,只有很少的一部分路径是有效的,我们只需要考虑这一小部分路径就行了。当我们把所有可行路路径列出来,会发现,如果按时间展开译码过程,我们可以以递推的方式计算出某个节点的前向(时间增大的方向)或后向(时间减小的方向)路径概率总和。这也是算法名称的由来。我们会先用一个”apple”的例子来直观解释FBA算法的递推关系,最后给出计算式。
给定一个标签,长度为,为了找出所有满足的路径,我们要构建一个拓展标签,它是在原始标签的首尾和每个字符中间加上空格符号得到的,长度为,比如当时。我们接下来的搜索过程都在由展开的搜索栅格上进行。
然而并不是在图3上的任意一条路径都是合法的,合法的路径要满足如下几点条件:
(1)转换只能向右或右下(纵轴上单调非减)
(2)相同的字符间至少有一个空格,否则标签中的连续相同字符会被错误地合并;
(3)除{blank}符号外不能跳过;
(4)路径起点必须从前两个符号开始,即或;
(5)路径必须在最后两个符号结束。
最终所有可能的路径如下:
读者可以自行验证,在由构成的搜索栅格上,遵守上述5条规则可以得到所有的正确的路径。我们并不需要穷举所有的种路径就能计算出想要的结果,这就是动态规划的核心思想:提前剔除掉不可能的结果,在更小的搜索空间上进行计算。
如何计算图5路径的概率总和呢?我们首先定义为t时刻取值为s的全部前缀路径概率总和:
累乘符号表示同一路径上的不同节点概率相乘,累加符号表示不同路径的概率相加。比如(t2,a)这个点的左边和左上角各有一条前向路径,即为这两条路径的概率之和。
可以用递推方式求得,我们分三种情况讨论,在t时刻:
(1)s取值为时,,参见图6的红圈节点,有效的前向路径可来自左边或左上,左侧没有有效路径意味着。
(2)s取值和s-2取值一样时,,参见图6的蓝圈节点。
(3)其余情况下,,参见图6的黑色圈。
所有情况可汇总如下:
初态:
最终的CTC损失函数为:。因为整个计算过程涉及到的运算都是可微的,所以可以用链式求导计算导数,进行反向传播。类似的,也可以用反向路径概率和来表示损失函数,读者可在[1]或[2]中找到相应内容。
参考资料
[1] Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks
[2] CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)
[3] Facebook大规模文本检测与识别系统Rosetta
[4] CTC Networks and Language Models: Prefix Beam Search Explained