弥合机器翻译中训练和推理之间的差距
链接: https://arxiv.org/pdf/1906.02448
机器翻译基于语境序列化地生成预测的目标单词。在训练时,它用真实单词作为语境,推理时则必须从零开始生成整个序列。这个是否提供语境的矛盾导致了错误的累积。此外,单词级别的训练需要生成序列和真实序列之间的严格匹配,这会导致对已经是合理的翻译的过度矫正。该文用采样不仅来自真实序列的语境单词也有来自模型训练时预测序列的单词的方式来解决这些问题,预测序列是来自句子级别的最优选择。在中英和英德翻译任务上的实验结果说明该方法在多种数据集上都有明显的提升。
该文的模型首先从预测单词中选取oracle单词并从oracle单词和真实单词中进行采样以作为语境。同时,oracle单词的选取不只是逐个单词的搜索,而且带有句子级别的评估即BLEU,这使交叉熵成对匹配的限制有了更大的灵活性。训练初始阶段,模型给予真实单词语境更高的概率 。当模型逐渐收敛时,oracle单词更频繁地被选取作为语境。在这种机制下,模型在推理时有机会学习处理错误并对可选翻译的过度矫正的恢复能力。该方法在RNN搜索模型和更强大的Transformer模型上进行证明。
方法
oracle单词选择
一般来说,在第 j 步预测,NMT模型需要真实单词 作为语境单词来预测第 j 个目标单词时 ,因此可以选择oracle单词 来模拟语境单词。oracle单词应该与真实单词相似或者是同义词。使用不同的策略能产生不同的oracle单词。一种选择是单词级别的逐个搜索。此外可以扩大为beam搜索然后使用句子级别的度量来对候选翻译进行排序。
单词级别oracle
对第 { j - 1} 个解码步,选择单词级别oracle的直接方法是以最高的概率从单词分布中选取单词。实践中可以使用简单有效地从分类分布中采样的Gumbel-Max技术(Gumbel, 1954; Maddison et al., 2014)以获得更强健的单词级别oracle。注意Gumbel噪音仅用来选择oracle,它不影响训练的损失函数。
句子级别oracle
首先使用beam搜索获得k个最佳的候选翻译,beam搜索也可以使用Gumbel噪音来获得每个单词的生成。然后用BLEU分数评估每个翻译,并使用最高BLEU分数的翻译作为oracle句子。但有一个问题,模型每一步从真实单词和句子级别oracle采样时,两个序列应该有相同数量的单词,然而简单beam搜索解码算法并不能保证这一点。基于这一点,该文介绍强制解码方法以确保两个序列有同样的长度。
强制解码
真实单词序列的长度为,强制解码的目的是生成个单词,末尾附加句子结束(EOS)符号。因此beam搜索时当候选单词以EOS结尾且长度长于或短于时,会强制它生成个单词。
衰减采样
在训练开始阶段,模型频繁使用作为会导致很低的收敛率,甚至会受陷入局部最优点。另一方面,如果语境仍然以大概率从真实单词进行选取,那么模型可能在推理阶段不知道如何行动。因此真实单词的选择概率 p 不能是固定的,应该随着训练的进行逐步下降。借用但有所不同于Bengio et al.(2015)的思想,定义 p 为基于训练迭代次数 e (从0开始)的衰减函数:
其中是超参数。函数是严格单调递减函数。
实验及结果
该研究在NIST中文到英文和 WMT'14 英文到德文翻译任务上进行实验。实验的模型包括RNNsearch、SS-NMT(在RNN搜索基础上使用计划采样(SS)方法Bengio et al.(2015)) 、MIXER(混合了递增交叉熵强化学习的方法,句子级别度量为BLEU,平均回报由带有单层线性回归器的离线方法获得。)以及OR-NMT(该文提出的方法)。表1显示实验结果,结果显示该文提出的方法带了很大的性能提升。