Transformer系列:Greedy Search贪婪搜索解码流程原理解析

关键词:TransformerGreedy Search贪婪搜索

前言

在本系列前文Transformer系列:图文详解Decoder解码器原理中已经介绍了Decoder解码器在训练阶段的网络结构,本节介绍解码器在预测阶段的工作方式。

内容摘要

  • 解码器预测流程简述
  • 解码器自注意力层Q,K,V分配和维护
  • 解码器预测阶段源码分析
  • Greedy Search贪婪搜索简述
  • 解码器和贪婪搜索总结

解码器预测流程简述

Encoder-Decoder这类框架需要在解码器中分别拿到前文已经翻译的输入,以及编码器的输出这两个输入,一起预测出下一个翻译的单词。在训练阶段,一个句子通过右移一位的方式转化为从第二个词到最后一个词的逐位预测任务,一个答案句子通过shift right构造出两个句子分别作为输入和预测目标,如图所示

训练shifted right方式

训练阶段虽然输入的是完整的句子,但是配合下三角掩码使得自注意力层每个单词的信息聚合只能利用到该词和该词之前的单词,不许偷看后词,和实际的部署预测场景一致。
在预测场景下,解码器的工作流程变得简单,不再需要掩码,设置初始位置为<start>单词,从<start>开始逐位置开始预测下一个单词,而预测出的单词会和之前已经预测的结果进行拼接,以此继续迭代预测,不断重复这个过程直到预测到<end>截止。


逐位预测方式

解码器自注意力层Q,K,V分配和维护

解码器有两个注意力层,分别是自注意力层和交互注意力层,其中交互注意力层中的采用全局可见的编码器输出,因此K,V固定为完整的编码器输出,而自注意力层不同,为了达到不偷看的目的,Q,K,V需要动态维护的。先上结论

Q:当前位置(上一个预测出的)单词的信息embedding
K:截止到当前位置的所有单词的信息embedding,保证每个词不能偷看到后词信息
V:截止到当前位置的所有单词的信息embedding,保证每个词不能偷看到后词信息

一般Q位置输入为[B,L,D]三维矩阵,其中L代表文本长度seq length,Q的seq length决定了自注意力输出的seq length,在预测阶段的解码器自注意层,Q的seq length长度为1,只需要输出当前词的信息即可,输出也是一个token的embedding表征,这个token天然地代表下一个需要预测的词,原因是shifted right的方式使得上一个单词的Decoder输出embedding成为了下一个单词计算概率分布的依据。
在训练阶段由于Q输入的是带有下三角掩码的完整答案句子,而K,V直接取和Q一致,不需要额外再给到K,V,但是在预测阶段由于每次Q只输入最新的一个token信息,除了<start>位置Q,K,V三者相等,从第二个位置之后再也无法根据Q直接得到K,V,需要在每一步预测的时候记录下历史的K,V,然后将最新的Q信息拼接组合起来,注意K,V只能采用拼接的方式而不能采用模型重新计算的结果,如果重新计算代表K,V已经偷看了后词信息,计算示意图如下

预测阶段对解码器Q,K,V的维护示意图

以一个2层的Decoder为例,上图中||代表拼接,<start> out代表第一层Decoder之后输出的token embedding,output1代表第二层Decoder的输出,output1直接用于翻译单词的概率分布预测。
对于Q:在第一层以源文本的embedding作为Q,在第二层以源文本的embedding经过第一层Decoder的输出embedding作为Q。在<start>位置由于文本总长度也只有1,此时Q,K,V三者一致,从第二个单词开始,Q为当前位置单词。
对于K,V:K,V是截止到当前词之前的所有单词信息,此时需要额外对第一层和第二层的K,V进行维护,具体方式是将Q信息拼接到历史K,V上,对于第一层Decoder实际上是将源文本原始embedding拼接来得到完整的截止到当前的句子信息,而对于第二层Dcoder,这种拼接方式保证了对前词的Sefl Attention结果不做任何修改,例如在输入2的第二层,在<start> out后面直接拼接上<我> out,完整沿用了<start> out而没有重新计算,保证了<start>的自注意表征没有偷看到后词。


解码器预测源码分析

在Keras实现的Transformer源码中(github:attention-is-all-you-need-keras),decode_batch_greedy函数以Greedy Search贪婪搜索策略为基础,实现了Transformer的逐位预测过程。

def decode_batch_greedy(src_seq, encode_model, decode_model, start_mark, end_mark, max_len=128):
    ...

函数的输入如下:

  • src_seq:输入的待翻译文本
  • encode_model:编码器模型
  • decode_model:解码器模型
  • start_mark:<start>位置对应的数值id标志
  • end_mark:<end>位置对应的数值id标志
  • max_len:最长预测步长,即翻译的结果文本长度最大值为128

首先计算出待翻译文本在编码器的输出,该输出在整个预测过程的所有步长中全局可用

enc_ret = encode_model(src_seq).numpy()

然后开始准备逐步预测,初始化一个长度永远为1的当前词的数字id矩阵,初次使用以<start>位置开始,矩阵每一行代表该批次下的一条样本。

bs = src_seq.shape[0]
# TODO 该批次下每个样本从<start>开始
target_one = np.zeros((bs, 1), dtype='int32')
target_one[:, 0] = start_mark  # 2

在逐步预测之前定了各层的K,V列表,ended标志列表,解码结果列表

dec_outputs = [np.zeros((bs, 1, d_model)) for _ in range(n_dlayers)]
ended = [0 for x in range(bs)]
decoded_indexes = [[] for x in range(bs)]
  • dec_outputs:K,V维护列表,用于记录历史K,V,已经每步拼接当前词的信息形成新的K,V
  • ended:该批次下每条样本的结束标志,如果已经结束该列表对应索引位置下改为1
  • decoded_indexes:翻译预测结果,list to list格式,记录了每条样本的翻译结果,每条翻译结果为字符的数字id列表

逐步预测代码在一个循环体中,代码注解如下

    for i in range(max_len - 1):
        print("max_len-1 {}".format(i))
        # TODO src_seq的作用是在交互注意力层添加encoder信息的mask
        # TODO target_one是上一个预测出的单词
        # TODO target_one和dec_outputs,是随着步长变化的,编码器的信息不变
        # TODO 在解码器内部 target_one 和 dec_outputs 会进行相加
        outputs = [x.numpy() for x in decode_model([target_one, src_seq, enc_ret] + dec_outputs)]
        # TODO 这个output [1, 1, softmax]是decoder最终输出+Linear映射的得分
        # TODO new_dec_outputs 上一个预测词的原始embedding输入,和经过一层decoder之后的embedding输入
        new_dec_outputs, output = outputs[:-1], outputs[-1]
        for dec_output, new_out in zip(dec_outputs, new_dec_outputs):
            # TODO new_out 永远等于[batch_size, 1, 256]
            # TODO 更新前词(不包含当前词)的向量信息
            dec_output[:, -1, :] = new_out[:, 0, :]
        # TODO [[batch_size, n+1, embedding], [batch_size, n+1, embedding]]
        # TODO 这个地方预留0是为了在Decoder里面和当前词相加,获得完整的k,v
        dec_outputs = [np.concatenate([x, np.zeros_like(new_out)], axis=1) for x in dec_outputs]
        # TODO output[:, 0, :] => [1, softmax]
        # TODO 每次取当前位置的最大预测结果 贪婪 不一定是全局最优
        sampled_indexes = np.argmax(output[:, 0, :], axis=-1)
        for ii, sampled_index in enumerate(sampled_indexes):
            if sampled_index == end_mark:
                # TODO 该批次下某条样本已经翻译结束
                ended[ii] = 1
            if not ended[ii]:
                # TODO 如果没有结束,收集每条样本的翻译token
                # TODO 之前已经结束的,也不会再增加新的翻译结果了
                decoded_indexes[ii].append(sampled_index)
        # TODO 该批次下所有样本翻译结束
        if sum(ended) == bs:
            break
        # TODO target_one 永远只记录上一个预测出来的最新的单词
        target_one[:, 0] = sampled_indexes

在循环体中不断更新target_one,dec_outputs,其中target_one只表征当前一个单词,dec_outputs表征截止当前词之前在两层Decoder的K,V信息,每步计算出最终的Decoder输出output,使用贪婪搜索策略argmax来拿到每步得分最大的单词作为预测结果。

sampled_indexes = np.argmax(output[:, 0, :], axis=-1)

如果某条样本没有预测到结束则添加到最新翻译结果,否则不再添加新的预测单词

if sampled_index == end_mark:
    ended[ii] = 1
if not ended[ii]:
    decoded_indexes[ii].append(sampled_index)

最终等到该批次样本哦每一条都预测到end,跳出循环体任务结束

if sum(ended) == bs:
    break

Greedy Search贪婪搜索简述

Greedy Search贪婪搜索比较简单,每一步都选择概率最大的单词输出,最后组成整个句子输出。这种方法给出的结果一般情况结果比较差,因为只考虑了每一步的最优解,并不一定是全局最优,因为贪婪搜索会错过隐藏在低概率单词后面的高概率单词,另外当如果每个位置选错了,后续位置生成的内容很可能也是错误的,具有错误的累加效果。
贪婪搜索由于每步只需要关注最大得分,考虑的因素少,因此实现容易,执行速度快,因此在源码中作者将Greedy Search封装在decode_sequence_fast方法下,代表对序列的快速解码。

贪婪搜索示意图

解码器和贪婪搜索总结

  • 1.解码器在预测阶段从初始的<start>占位符开始,利用历史前词和全局编码器结果,逐位预测下一个单词,知道预测到end停止
  • 2.解码器的输入作为Query,Query只输入当前最新单词信息,自注意力层的Key,Value需要手动维护历史并且拼接当前最新信息
  • 3.贪婪搜索的优点是实现简单,可作为解码的快速实现,但是它只考虑了每一步的局部最优并不一定是全局最优,往往效果较差
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,826评论 6 506
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,968评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,234评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,562评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,611评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,482评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,271评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,166评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,608评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,814评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,926评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,644评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,249评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,866评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,991评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,063评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,871评论 2 354

推荐阅读更多精彩内容