探索 Seq2Seq 模型及 Attention 机制

1 什么是 Seq2Seq ?

Seq2Seq 是一个 Encoder-Decoder 结构的神经网络,它的输入是一个序列(Sequence),输出也是一个序列(Sequence),因此而得名 “Seq2Seq”。在 Encoder 中,将可变长度的序列转变为固定长度的向量表达,Decoder 将这个固定长度的向量转换为可变长度的目标的信号序列。

如下是 Seq2Seq 模型工作的流程:

最基础的 Seq2Seq模型 包含了三个部分(上图有一部分没有显示的标明),即 Encoder、Decoder 以及连接两者的中间状态向量 C,Encoder通过学习输入,将其编码成一个固定大小的状态向量 C(也称为语义编码),继而将 C 传给Decoder,Decoder再通过对状态向量 C 的学习来进行输出对应的序列。

当然,在模型的训练阶段,工作方式如下所示:

这里的每一个 Box 代表了一个 RNN 单元,通常是 LSTM 或者 GRU 。其实,Basic Seq2Seq 是有很多弊端的,首先 Encoder 将输入编码为固定大小状态向量(hidden state)的过程实际上是一个“信息有损压缩”的过程。如果信息量越大,那么这个转化向量的过程对信息造成的损失就越大。同时,随着 sequence length的增加,意味着时间维度上的序列很长,RNN 模型也会出现梯度弥散。最后,基础的模型连接 Encoder 和 Decoder 模块的组件仅仅是一个固定大小的状态向量,这使得Decoder无法直接去关注到输入信息的更多细节。由于 Basic Seq2Seq 的种种缺陷,随后引入了 Attention 的概念以及 Bi-directional encoder layer 等,能够取得更好的表现。

总结起来说,基础的 Seq2Seq 主要包括 Encoder,Decoder,以及连接两者的固定大小的 State Vector。

2 Attention

seq2seq 模型虽然强大,但如果仅仅是单一使用的话,效果会大打折扣。注意力模型就是基于 Encoder-Decoder 框架下的一种模拟 Human 注意力直觉的一种模型。

人脑的注意力机制本质上是一种注意力分配的模型,比如说我们在阅读一篇论文的时候,在某个特定时刻注意力肯定只会在某一行文字描述,在看到一张图片时,我们的注意力肯定会聚焦于某一局部。随着我们的目光移动,我们的注意力肯定又聚焦到另外一行文字,另外一个图像局部。所以,对于一篇论文、一张图片,在任意一时刻我们的注意力分布是不一样的。这便是著名的注意力机制模型的由来。

早在计算机视觉目标检测相关的内容学习时,我们就提到过注意力机制的思想,目标检测中的 Fast R-CNN 利用 RoI(兴趣区域)来更好的执行检测任务,其中 RoI 便是注意力模型在计算机视觉上的应用。

注意力模型的使用更多是在自然语言处理领域,在机器翻译等序列模型应用上有着更为广泛的应用。在自然语言处理中,注意力模型通常是应用在经典的 Encoder-Decoder 框架下的,也就是 RNN 中著名的 N vs M 模型,seq2seq 模型正是一种典型的 Encoder-Decoder 框架。

Encoder-Decoder 作为一种通用框架,在具体的自然语言处理任务上还不够精细化。换句话说,单纯的Encoder-Decoder 框架并不能有效的聚焦到输入目标上,这使得像 seq2seq 的模型在独自使用时并不能发挥最大功效。比如说在上图中,编码器将输入编码成上下文变量 C,在解码时每一个输出 Y 都会不加区分的使用这个 C 进行解码。而注意力模型要做的事就是根据序列的每个时间步将编码器编码为不同 C,在解码时,结合每个不同的 C 进行解码输出,这样得到的结果会更加准确,如下所示:

简单的注意力模型通常有以上三个公式来描述:
1)计算注意力得分
2)进行标准化处理
3)结合注意力得分和隐状态值计算上下文状态 C 。

其中 u 为解码中某一时间步的状态值,也就是匹配当前任务的特征向量,vi 是编码中第 i 个时间步的状态值,a() 为计算 u 和 vi 的函数。a()通常可以取以下形式:

Attention works by first, calculating an attention vector, a, that is the length of the source sentence. The attention vector has the property that each element is between 0 and 1, and the entire vector sums to 1. We then calculate a weighted sum of our source sentence hidden states, H, to get a weighted source vector, w.

w = \sum_i a_ih_i

We calculate a new weighted source vector every time-step when decoding, using it as input to our decoder RNN as well as the linear layer to make a prediction.

Attention模型的出现是上述的seq2seq模型存在缺陷,即无论之前的encoder的context有多长,包含多少信息量,最终都要被压缩成一个几百维的vector。这意味着context越大,decoder的输入之一的last state 会丢失越多的信息。对于机器翻译问题,意味着输入sentence长度增加后,最终decoder翻译的结果会显著变差。

Attention 注意力机制提供了一个可以和远距离单词保持联系的方式, 解决了一个 vector 保存信息不足的问题。

Attention实质上是一种 content-based addressing 的机制,即从网络中某些状态集合中选取与给定状态较为相似的状态,进而做后续的信息抽取;

说人话就是: 首先根据 Encoder 和 Decoder 的特征计算权值,然后对Encoder的特征进行加权求和,作为Decoder的输入,其作用是将Encoder的特征以更好的方式呈献给Decoder,即:并不是所有 context 都对下一个状态的生成产生影响,Attention 就是选择恰当的context用它生成下一个状态。

图中, 明确的标出了,最后一个state包含了整个句子的信息。 但是问题是,一个 vector 是无法包含这么多信息的。 这样势必导致比较远的单词信息的丢失。 另一个优化上的问题是, 求微分时,较远的单词会受到 diminishing gradient 的影响, 导致失去了long term dependency 的关系。

Attention 注意力机制提供了一个可以和远距离单词保持联系的方式, 解决了一个vector保存信息不足的问题。

image.png

注意力机制的计算发生在decoder 的每一个步骤, 包含了四个步骤。首先decoder state 和encoder 所有的source state 进行softmax, 计算, 算出attention weights. 比如Fig. 6 里的 "the" hidden state, 就去和原文中的"les pauvres sont demunis"对应的state 进行softmax 计算。基于这个attention weights (attention distribution 在Fig. 6里), 我们算出一个上下文向量(attention output)。 这个向量是通过加权平均的 source state.attention output 再和 decoder hidden state 拼接起来, 最后算出y2. 下面的公式总结了这三步计算过程。

image.png

3 小任务:反转一个变长序列

设计网络结构,反转一个变长序列(最大长度N=20),即246910000反转为196420000,其中1-9为需要反转的有效字符,0为补位字符

考核点:

  1. 序列长度很长时,如何记住前序信息
  2. output structure是序列的优化方法

Constraints:

  1. 结构设计,不用调参,固定 batch=32, lr=0.02, optimizer=Adam, epoch=1
  2. 将序列中的数字映射到8维空间作为输入
  3. 禁止直接将 input 与 output 层相连

3.1 解决思路

对于这样一个回文序列问题,使用神经网络来做的话首当其冲需要考虑的就是序列长度的问题,因为回文序列的 Long Term Dependency 很远,长度一旦超过网络的学习能力的 capacity 之后,神经网络便学习不到相应的关系,进而就会产生错误的结果。

Seq2Seq 模型很适合解决这样的问题,特别是引入 Attention 机制后,记住前序信息的能力有明显的增强。另外,针对该问题的特点,相信使用双向的循环神经网络(BiLSTM or BiGRU)会有更好的表现。

当然,还有一种思路就是纯粹使用 Attention 来做,摒弃掉 RNN 部分,结合使用 Position Embedding 来“记住”序列的顺序或者位置信息。

至于 output structure是序列的优化方法 这一句想到现在也没有理解其含义...

3.2 模型实现

那么,现在就开始动起手来实现吧!

3.2.1 数据预处理

在神经网络中,对于文本的数据预处理无非是将文本转化为模型可理解的数字,这里都比较熟悉,不作过多解释。但在这里我们需要加入以下四种字符,<PAD>主要用来进行字符补全,<EOS>和<GO>都是用在Decoder端的序列中,告诉解码器句子的起始与结束,<UNK>则用来替代一些未出现过的词或者低频词。

<PAD>: 补全字符。
<GO>/<SOS>: 解码器端的句子起始标识符。
<EOS>: 解码器端的句子结束标识符。
<UNK>: 低频词或者一些未登陆词等。

我们首先需要对target端的数据进行一步预处理。在我们将target中的序列作为输入给Decoder端的RNN时,序列中的最后一个字母(或单词)其实是没有用的。我们来用下图解释:


我们此时只看右边的Decoder端,可以看到我们的target序列是[<go>, W, X, Y, Z, <eos>],其中<go>,W,X,Y,Z是每个时间序列上输入给RNN的内容,我们发现,<eos>并没有作为输入传递给RNN。因此我们需要将target中的最后一个字符去掉,同时还需要在前面添加<go>标识,告诉模型这代表一个句子的开始。


这个图代表我们的predict阶段,在这个阶段,我们没有target data,这个时候前一阶段的预测结果就会作为下一阶段的输入。

当然,predicting虽然与training是分开的,但他们是会共享参数的,training训练好的参数会供predicting使用。

目前为止我们已经完成了整个模型的构建,但还没有构造batch函数,batch函数用来每次获取一个batch的训练样本对模型进行训练。

在这里,我们还需要定义另一个函数对batch中的序列进行补全操作。这是啥意思呢?我们来看个例子,假如我们定义了batch=2,里面的序列分别是

[['h', 'e', 'l', 'l', 'o'],
['w', 'h', 'a', 't']]
那么这两个序列的长度一个是5,一个是4,变长的序列对于RNN来说是没办法训练的,所以我们这个时候要对短序列进行补全,补全以后,两个序列会变成下面的样子:

[['h', 'e', 'l', 'l', 'o'],
['w', 'h', 'a', 't', '<PAD>']]
这样就保证了我们每个batch中的序列长度是固定的。

To train we run the input sentence through the encoder, and keep track of every output and the latest hidden state. Then the decoder is given the <SOS> token as its first input, and the last hidden state of the encoder as its first hidden state.

"Teacher forcing" is the concept of using the real target outputs as each next input, instead of using the decoder's guess as the next input. Using teacher forcing causes it to converge faster but when the trained network is exploited, it may exhibit instability http://minds.jacobs-university.de/sites/default/files/uploads/papers/ESNTutorialRev.pdf.

关于 <SOS>和<EOS> : 它们都是用在 Decoder 端的序列中,告诉解码器句子的起始与结束。所以,target 序列需要加上 <SOS>
在 seq2seq 模型的预测流程中,Encoder 需要依靠 <EOS> 来判断序列的结束,所以对于 source 序列需要在结尾加上 <EOS> 。

在 seq2seq 模型的训练流程中,

  1. target sequence 加上 <EOS>,train decoder input + <SOS>

  2. predic decoder <SOS>

作为补充,参考一些案例代码,同时也实现了 TensorFlow 版本。

Github Repository:Seq2Seq-Attention

3.2 思考与改进

Encoder-Decoder 模型确实

PyTorch 相关 API 概述

PyTorch 中的词嵌入是通过函数 nn.Embedding(m, n) 来实现的,其中 m 表示所有的单词数目,n 表示词嵌入的维度。

nn.GRU(hidden_size, hidden_size)

nn.Linear(hidden_size, output_size)

nn.LogSoftmax(dim=1)

F.relu(output)

编程实现的知识点

Python Dict get

dict.get(key, default=None)

params:
- key: 字典中需要查找的键。
- default: 如果指定键的值不存在时候,返回该默认值。

总结

模型分为 encoder 和 decoder 两个部分,decoder 部分比较简单,就是一层 Embedding 层加上两层 GRU。之前处理的 batch 的格式主要是为了使用pack_padded_sequence 和 pad_packed_sequence 这两个类对GRU输入输出批量处理。一定要注意各个变量的shape。

FAQs:

1.为什么要在 target sequence 前面添加 <GO>/<SOS> 呢?不能直接输入target吗?在末尾添加<eos>倒是可以理解。

因为预测阶段你没有 target sequence 呀。首先,train阶段,因为你知道 target sequence 数据,当然可以直接输,但是 predict 的时候咋办,你不知道第一个词是啥,这个时候一般用<GO>启动

train 的时候对 decoder_input 预处理也要加<go>,这是因为训练和预测需要从同样的输入开始(就是<go>)。比如source是common,target是 cmmnoo,当encode结束,开始decode时,给decoder的第一个输入都是<go>,train的阶段,你是希望网络能学习到第一个输出是c,这样在predict阶段,你用<go>启动网络才能输出c。如果不加<go>,你第一个输入就是c了,网络怎么能学习到用<go>做输入时该输出什么呢

参考文献

作者:涛涛江水向坡流
链接:https://www.jianshu.com/p/ead1b446a124
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。


作者:qq_18603599
来源:CSDN
原文:https://blog.csdn.net/qq_18603599/article/details/80581115
版权声明:本文为博主原创文章,转载请附上博文链接!


最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,390评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,821评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,632评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,170评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,033评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,098评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,511评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,204评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,479评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,572评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,341评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,213评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,576评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,893评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,171评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,486评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,676评论 2 335

推荐阅读更多精彩内容

  • 最近人工智能随着AlphaGo战胜李世乭这一事件的高关注度,重新掀起了一波新的关注高潮,有的说人工智能将会如何超越...
    MiracleJQ阅读 2,781评论 2 1
  • 近日,谷歌官方在 Github开放了一份神经机器翻译教程,该教程从基本概念实现开始,首先搭建了一个简单的NMT模型...
    MiracleJQ阅读 6,296评论 1 11
  • 神经网络。《Make Your Own Neural Network》,用非常通俗易懂描述讲解人工神经网络原理用代...
    利炳根阅读 4,953评论 0 7
  • 作者 | 武维AI前线出品| ID:ai-front 前言 自然语言处理(简称NLP),是研究计算机处理人类语言的...
    AI前线阅读 2,546评论 0 8
  • 最近看到网上都是关于重生的小说,但我觉得这个设置有个逻辑错误,比如人的性情和事情发展,都是相互作用下产生的,...
    张毓巧_阅读 76评论 0 0