介绍
seq2seq中的decoder是一个自回归的生成模型,那么在训练阶段,第t步输入的前缀序列是来自真实数据分布的,这种学习方式称为教师强制(Teacher Forcing)。然而在预测阶段,前缀序列则是来自模型分布的
。由于模型分布和真实数据分布并不严格一致,因此一旦预测前缀
的过程中存在错误,会导致错误传播,使得后续生成的序列偏离真实分布,这个问题称为曝光偏差(Exposure Bias)。
一个简单的想法就是在训练decoder的时候将真实前缀序列中某些位置随机替换成随机词,让decoder不过分依赖前缀输入。但是每一步中不管输入如何选择, 目标输出依然是来自于真实数据. 这可能使得模型预测一些不正确的序列。比如一个真实的序列是 “吃饭”, 如果在第一步生成时使用模型预测的词是 “喝”, 模型就会强制记住 “喝饭” 这个不正确的序列,这个问题被称为过度纠正(Overcorrection)。
方案
ACL2019最佳长文Bridging the Gap between Training and Inference for Neural Machine Translation提出了Oracle Word的概念,也就是说不是随机选取词来替换,而是在word level或者sentence level考虑“合乎情理”的词来替换。
- Word-Level Oracle是指在第t-1步根据softmax输出的概率分布做词采样,为了增加鲁棒性,可以在概率分布上加入Gumbel noise。
- Sentence-Level Oracle是指先利用beam search获得一些候选翻译结果,再和真实结果计算BLEU值,选择对应最优BLEU值的候选翻译结果作为decoder输入。其中针对候选翻译结果可能和真实结果长度不一样,又引入了Force Decoding技巧。
- 文中Sampling with Decay技巧是考虑在模型未得到充分训练时,decoder的解码结果可能很不可靠,为了避免模型无法收敛,替换前缀序列概率伴随着训练的step缓慢增加。
实现
初读这篇文章的时候有这样的疑问:文章中的实现都是基于RNN的,怎么在基于Transformer的机器翻译模型中应用以上的方法,难不成为了使用这些技巧,放弃模型的并行性?相比Sentence-Level Oracle,Word-Level Oracle更易于并行实现,以下我的实现方案:
- 预先前向计算一次decoder部分,并映射到字典维度,得到logits
- 利用由top_K logits计算的概率分布并采样词,得到候选替换序列
- 根据训练步数global_steps计算替换概率p,利用预设的最大概率值截断
- 依照p,替换decoder的输入序列
注:预先前向计算后需要使用tf.stop_gradients,防止反向传播时冗余的梯度回传。
# 利用由top_K logits计算的概率分布并采样词,得到候选替换序列
def sample_with_topk(logits, k):
reshaped_logits = (tf.reshape(logits, [-1, shape_list(logits)[-1]]))
reshaped_logits_values, reshaped_logits_indices = tf.nn.top_k(input=reshaped_logits, k=k, sorted=True)
choices = tf.multinomial(reshaped_logits_values, 1)
choices = tf.concat(
[tf.expand_dims(tf.cast(tf.range(tf.reduce_prod(shape_list(logits)[0:-1])), dtype=tf.int64), axis=-1),
choices], axis=-1)
choices = tf.gather_nd(params=reshaped_logits_indices, indices=choices)
choices = tf.reshape(choices, shape_list(logits)[:logits.get_shape().ndims - 1])
return tf.cast(choices, dtype=tf.int64)
# encoder_outpur: encoder所有隐层结果
# encoder_decoder_attention_bias: decoder中计算enc_dec_atten所涉及的mask偏置
# targets: 目标id序列
# decoder_input: decoder输入序列的嵌入向量
# decoder_self_attention_bias: decoder中self_atten的偏置
targets = common_layers.flatten4d3d(targets)
decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(targets, hparams, features=features)
# decoder_output_tmp: 预先计算decoder的最后一层隐层输出
decoder_output_tmp = self.decode(
decoder_input,
encoder_output,
encoder_decoder_attention_bias,
decoder_self_attention_bias,
hparams,
nonpadding=features_to_nonpadding(features, "targets"),
losses=losses)
# 将隐层向量映射到字典维度
with tf.variable_scope(self._variable_scopes['model_fn'],reuse=tf.AUTO_REUSE):
logits_tmp = self.top(decoder_output_tmp, features)
# 防止梯度冗余传播
logits_tmp = tf.stop_gradient(logits_tmp)
# 采样得到候选序列
targets_proposal = sample_with_topk(logits_tmp, 10)
# 获取全局训练步数
global_steps = tf.cast(tf.train.get_global_step(), dtype=tf.float32)
# 计算保留概率=1-替换概率,最小保留概率是0.5
p = tf.maximum(1.0 - tf.math.floordiv(global_steps, 10000.) * 0.5 / 75., 0.5)
# 判断本次是保留, 还是替换;0表示保留, 1表示替换
pred = tf.cond(tf.less(tf.random.uniform(shape=(), minval=0, maxval=1), p),
true_fn=lambda: 0.,
false_fn=lambda: 1.)
# 随机选择15%序列中位置做替换
mask = tf.less(tf.random_uniform(tf.shape(features["targets_raw"])), 0.15 * pred)
# 利用mask融合原始目标序列和候选目标序列
targets_proposal = (cast_like(mask, targets_proposal) * targets_proposal +
cast_like(tf.logical_not(mask), targets_proposal) *
cast_like(features["targets_raw"], targets_proposal)) * \
cast_like(common_layers.weights_nonzero(features["targets_raw"]), targets_proposal)
# 利用融合的目标序列作为decoder的输入序列,计算decoder隐层向量
with tf.variable_scope(self._variable_scopes['symbol_modality_{}_{}'.format(hparams.problem_hparams.vocab_size["targets"],hparams.hidden_size)],reuse=tf.AUTO_REUSE):
targets_proposal = self._problem_hparams.modality["targets"].bottom(targets_proposal)
targets_proposal = common_layers.flatten4d3d(targets_proposal)
decoder_input_random, _ = transformer_prepare_decoder(targets_proposal, hparams, features=features)
decoder_output_random = self.decode(
decoder_input_random,
encoder_output,
encoder_decoder_attention_bias,
decoder_self_attention_bias,
hparams,
nonpadding=features_to_nonpadding(features, "targets"),
losses=losses)
总结
- 为了缓解过度纠正问题,选择Oracle Word策略不能过于随机,一定程度上需要考虑当前语义。
- 缓解Exposure Bias问题不能以牺牲并行性为代价。
拓展
- 苏神的实现方式Seq2Seq中Exposure Bias现象的浅析与对策,并引入对抗训练的概念来缓解Exposure Bias问题,也发人深省。
- 本质上Exposure Bias问题来源于自回归生成中训练和测试的mismatch,目前利用前缀序列的方式都是离散的,是不是可以连续的利用前缀序列,从而在不失并行性的前提下统一训练与测试两阶段还有待探究。