Attention机制(原理+代码)
直接进入正题,在介绍Attention机制之前需要知道什么是seq2seq模型,也就是Encoder-Decoder模型,下面对seq2seq及逆行简单介绍。
1.seq2seq模型
作为RNN模型的一种变体:N vs M(N,M意思是输入和输出不是等长),此结构又称为Encoder-Decoder模型,也就是我们常说的seq2seq模型。seq2seq模型的出现解决了许多应用的问题,比如解决了传统的序列等长的问题,在机器翻译等领域得到了很好的运用。
seq2seq模型先将输入数据编码成一个上下文向量c:
这里的上下文向量c有多种方式,可以将最后一个隐状态赋值给c,也可以将最后一个隐状态做变量赋值给c,也可以将所有的状态变量赋值给c。
得到c之后,然后用另外一个RNN模型进行解码,也就是Decoder过程。你可以理解为将c作为新的作为输入到Decoder结构中。
输入的情况分为上面两种,这种结构不限制输入和输出的序列长度,所以在许多领域得到了应用,但其本身存在着有些问题,在Encoder-Decoder结构中,Encoder把所有的输入序列都编码成一个统一的语义特征c再解码,因此,c中必须包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。如机器翻译问题,当要翻译的句子较长时,一个c可能存不下那么多信息,就会造成翻译精度的下降,为了解决这个问题,采用了attention机制。
2.Attention机制
Attention机制的定义:
1.给定一组向量集合values,以及一个向量query,attention机制是一种根据query计算values的加权求和的机制。
2.attention的重点就是这个集合values中的每个value的“权值”的计算方法。
3.有时候也把这种attention的机制叫做query的输出关注了(或者说叫考虑到了)原文的不同部分。(Query attends to the values)
举例:刚才seq2seq中,哪个是query,哪个是values?
each decoder hidden state attends to the encoder hidden states (decoder的第t步的hidden state——st是query,encoder的hidden state是values)
从定义来看Attention的感性认识:
The weighted sum is a selective summary of the information contained in the values, where the query determines which values to focus on.
3.attention的计算变体
首先,从大的概念来讲,针对attention的变体主要有两种方式:
1.一种是attention向量的加权求和和计算方式上进行创新;
2.另一种是attention score(匹配度或者叫权值)的计算方式上进行创新。
当然还有一种就是把二者都有改变的结合性创新,或者是迁移性创新,比如借鉴CNN的Inception思想等等,后续会提到一点,详细的应该是在以后可能要将的Transformer里面会详细提到。
我们先针对第一种方法讲讲区别,其实虽然名字变来变去,他们的差异没有那么多。
3.1 针对attention向量计算方式变体
大概分成这么几种:
Soft attention, global attention, 动态attention
Hard attention
‘半硬半软’的attention(local attention)
静态attention
强制前向attention
Soft attention, global attention, 动态attention
这三个其实就是Soft attention,也就是我们上面讲过的那种最常见的attention,是在求注意力分配概率分布的时候,对于输入句子X中任意一个单词都给出了个概率,是概率分布,把attention变量(context vecor)用ctc_tct表示,attention得分在经过了softmax后的权值用alpha表示
Hard attention
Soft是给每个单词都赋予一个单词match概率,那么如果不这样做,直接从输入句子里面找到特定的单词,然后把目标句子单词和这个单词对齐,而其他收入句子中的单词硬性地认为对齐概率为0,就是Hard Attention Model思想。
5.代码示例
引自:https://blog.csdn.net/sun_xiao_kai/article/details/95873046?utm_term=attention%E4%BB%A3%E7%A0%81python&utm_medium=distribute.pc_aggpage_search_result.none-task-blog-2allsobaiduweb~default-0-95873046&spm=3001.4430