缓解Exposure Bias的一种实现

介绍

seq2seq中的decoder是一个自回归的生成模型,那么在训练阶段,第t步输入的前缀序列是来自真实数据分布的x_{1:(t-1)},这种学习方式称为教师强制(Teacher Forcing)。然而在预测阶段,前缀序列则是来自模型分布的\hat{x}_{1:(t-1)}。由于模型分布和真实数据分布并不严格一致,因此一旦预测前缀\hat{x}_{1:(t-1)}的过程中存在错误,会导致错误传播,使得后续生成的序列偏离真实分布,这个问题称为曝光偏差(Exposure Bias)。

一个简单的想法就是在训练decoder的时候将真实前缀序列x_{1:(t-1)}​中某些位置随机替换成随机词,让decoder不过分依赖前缀输入。但是每一步中不管输入如何选择, 目标输出依然是来自于真实数据. 这可能使得模型预测一些不正确的序列。比如一个真实的序列是 “吃饭”, 如果在第一步生成时使用模型预测的词是 “喝”, 模型就会强制记住 “喝饭” 这个不正确的序列,这个问题被称为过度纠正(Overcorrection)。

方案

ACL2019最佳长文Bridging the Gap between Training and Inference for Neural Machine Translation提出了Oracle Word的概念,也就是说不是随机选取词来替换,而是在word level或者sentence level考虑“合乎情理”的词来替换。

  • Word-Level Oracle是指在第t-1步根据softmax输出的概率分布做词采样,为了增加鲁棒性,可以在概率分布上加入Gumbel noise
  • Sentence-Level Oracle是指先利用beam search获得一些候选翻译结果,再和真实结果计算BLEU值,选择对应最优BLEU值的候选翻译结果作为decoder输入。其中针对候选翻译结果可能和真实结果长度不一样,又引入了Force Decoding技巧。
  • 文中Sampling with Decay技巧是考虑在模型未得到充分训练时,decoder的解码结果可能很不可靠,为了避免模型无法收敛,替换前缀序列概率伴随着训练的step缓慢增加。

实现

初读这篇文章的时候有这样的疑问:文章中的实现都是基于RNN的,怎么在基于Transformer的机器翻译模型中应用以上的方法,难不成为了使用这些技巧,放弃模型的并行性?相比Sentence-Level OracleWord-Level Oracle更易于并行实现,以下我的实现方案:

  1. 预先前向计算一次decoder部分,并映射到字典维度,得到logits
  2. 利用由top_K logits计算的概率分布并采样词,得到候选替换序列
  3. 根据训练步数global_steps计算替换概率p,利用预设的最大概率值截断
  4. 依照p,替换decoder的输入序列

注:预先前向计算后需要使用tf.stop_gradients,防止反向传播时冗余的梯度回传。


# 利用由top_K logits计算的概率分布并采样词,得到候选替换序列
def sample_with_topk(logits, k):
    reshaped_logits = (tf.reshape(logits, [-1, shape_list(logits)[-1]]))
    reshaped_logits_values, reshaped_logits_indices = tf.nn.top_k(input=reshaped_logits, k=k, sorted=True)
    choices = tf.multinomial(reshaped_logits_values, 1)
    choices = tf.concat(
        [tf.expand_dims(tf.cast(tf.range(tf.reduce_prod(shape_list(logits)[0:-1])), dtype=tf.int64), axis=-1),
         choices], axis=-1)

    choices = tf.gather_nd(params=reshaped_logits_indices, indices=choices)
    choices = tf.reshape(choices, shape_list(logits)[:logits.get_shape().ndims - 1])
    return tf.cast(choices, dtype=tf.int64)

# encoder_outpur: encoder所有隐层结果
# encoder_decoder_attention_bias: decoder中计算enc_dec_atten所涉及的mask偏置
# targets: 目标id序列
# decoder_input: decoder输入序列的嵌入向量
# decoder_self_attention_bias: decoder中self_atten的偏置
targets = common_layers.flatten4d3d(targets)
decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(targets, hparams, features=features)

# decoder_output_tmp: 预先计算decoder的最后一层隐层输出
decoder_output_tmp = self.decode(
    decoder_input,
    encoder_output,
    encoder_decoder_attention_bias,
    decoder_self_attention_bias,
    hparams,
    nonpadding=features_to_nonpadding(features, "targets"),
    losses=losses)

# 将隐层向量映射到字典维度
with tf.variable_scope(self._variable_scopes['model_fn'],reuse=tf.AUTO_REUSE):
    logits_tmp = self.top(decoder_output_tmp, features)

# 防止梯度冗余传播
logits_tmp = tf.stop_gradient(logits_tmp)

# 采样得到候选序列
targets_proposal = sample_with_topk(logits_tmp, 10)

# 获取全局训练步数
global_steps = tf.cast(tf.train.get_global_step(), dtype=tf.float32)

# 计算保留概率=1-替换概率,最小保留概率是0.5
p = tf.maximum(1.0 - tf.math.floordiv(global_steps, 10000.) * 0.5 / 75., 0.5)

# 判断本次是保留, 还是替换;0表示保留, 1表示替换
pred = tf.cond(tf.less(tf.random.uniform(shape=(), minval=0, maxval=1), p),
               true_fn=lambda: 0.,
               false_fn=lambda: 1.)

# 随机选择15%序列中位置做替换
mask = tf.less(tf.random_uniform(tf.shape(features["targets_raw"])), 0.15 * pred)

# 利用mask融合原始目标序列和候选目标序列
targets_proposal = (cast_like(mask, targets_proposal) * targets_proposal +
                    cast_like(tf.logical_not(mask), targets_proposal) *
                    cast_like(features["targets_raw"], targets_proposal)) * \
                   cast_like(common_layers.weights_nonzero(features["targets_raw"]), targets_proposal)

# 利用融合的目标序列作为decoder的输入序列,计算decoder隐层向量
with tf.variable_scope(self._variable_scopes['symbol_modality_{}_{}'.format(hparams.problem_hparams.vocab_size["targets"],hparams.hidden_size)],reuse=tf.AUTO_REUSE):
    targets_proposal = self._problem_hparams.modality["targets"].bottom(targets_proposal)
targets_proposal = common_layers.flatten4d3d(targets_proposal)
decoder_input_random, _ = transformer_prepare_decoder(targets_proposal, hparams, features=features)
decoder_output_random = self.decode(
    decoder_input_random,
    encoder_output,
    encoder_decoder_attention_bias,
    decoder_self_attention_bias,
    hparams,
    nonpadding=features_to_nonpadding(features, "targets"),
    losses=losses)

总结

  • 为了缓解过度纠正问题,选择Oracle Word策略不能过于随机,一定程度上需要考虑当前语义。
  • 缓解Exposure Bias问题不能以牺牲并行性为代价。

拓展

  • 苏神的实现方式Seq2Seq中Exposure Bias现象的浅析与对策,并引入对抗训练的概念来缓解Exposure Bias问题,也发人深省。
  • 本质上Exposure Bias问题来源于自回归生成中训练和测试的mismatch,目前利用前缀序列的方式都是离散的,是不是可以连续的利用前缀序列,从而在不失并行性的前提下统一训练与测试两阶段还有待探究。
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容