和gnmt的区别在于:
train环节
for segNum in range(int(max_decoder_input_len / 10)):
tmp_max_encoder_input_len = segNum * 10 + max_encoder_input_len#当前最大值
# choose a bucket(depends on the encoder length)
bucket_id = choose_bucket(tmp_max_encoder_input_len)
encoder_inputs, decoder_inputs, target_weights = model.pad_pair(encoder_inputs_ori,
decoder_inputs_ori,
segNum, FLAGS.segment_length,
bucket_id, segNum == 0)
_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
target_weights, bucket_id, False)
step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
loss += step_loss / FLAGS.steps_per_checkpoint
即:每次根据最长长度取不同的bucket,然后调用pad_pair
函数来生成encoder_inputs, decoder_inputs
,step步骤和正常一样。
decode环节
main contribution在这里。
generate_response
是直接负责的函数,它返回最终生成的回答句子。具体过程是:
- 准备工作,如果encoder那边的长度超过限定值,则要截取一部分;并为
src
和dst_generated
pad上GO和EOS的ID。 - 调用get_segment返回两个候选的candidate_response,和这两个候选的prob_response(每一个单词的概率之和)
- 根据以上公式计算每个候选得分。
- 循环进行两个步骤直到EOS或长度超过限制。