Multi-head attention 多头注意力机制

Multi-head attention

本文基于《dive into deep learning》-pytorch

代码参考 《dive into deep learning》-pytorch

multi-head attention

基本信息

我们可以会希望注意力机制可以联合使用不同子空间的key,value,query的表示。因此,不是只用一个attention pooling,query、key、value可以被h个独立学到的线性映射转换。最后,h个attention pooling输出concat 并且再次通过一个线性映射得到最后的输出。

这种设计就是multi-head attention, h个attention pooling输出中的每一个就是一个头。使用全连接层来实现线性转换。

multi-attention1.png

理解纠错

【我过去有一个误解,就是multi-head是和CNN类似的机制,用多个的W降维,之后再计算多个注意力分数,再concat。直到我用pytorch中自带的multi-head attention,要求num_heads是hidden层维度可以整除的数,才发现这里的multi-head是针对子空间的】
【但是这里可以理解,用同样的维度,训练多个空间,可以更好地增强表达能力】

这部分解答可以参考:

https://www.zhihu.com/question/350369171 -transformer中multi-head attention中每个head为什么要进行降维?(实际上用切割来表示更为准确)

https://www.zhihu.com/question/446385446 - BERT中,multi-head 768*64*12与直接使用768*768矩阵统一计算,有什么区别?

对于 Multi-Head Attention,简单来说就是多个 Self-Attention 的组合,但多头的实现不是循环的计算每个头,而是通过 transposes and reshapes,用矩阵乘法来完成的。

In practice, the multi-headed attention are done with transposes and reshapes rather than actual separate tensors. —— 来自 google BERT 源代码注释

Transformer中把 d ,也就是hidden_size/embedding_size这个维度做了reshape拆分,具体可以看对应的 pytorch 代码

hidden_size (d) = num_attention_heads (m) * attention_head_size (a),也即 d=m*a

【↑作者:海晨威
链接:https://www.zhihu.com/question/350369171/answer/1718672303
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。】

正如回答中有写:

transformer中multi-head attention中每个head为什么要进行降维? - LooperXX的回答 - 知乎 https://www.zhihu.com/question/350369171/answer/860552006

回到题主的问题上来,如果只使用 one head 并且维度为 d_model ,相较于 8 head 并且维度为d_model/8,存在高维空间下学习难度较大的问题,文中实验也证实了这一点,于是将原有的高维空间转化为多个低维子空间并再最后进行拼接,取得了更好的效果,十分巧妙。

在实现的时候,multi-head把维度从[batch, len, embeding]变为[batch, len, head, embeding/head], 然后head就是多头,对每一个 embeding/head部分计算对应的attention。

pytorch实现

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)#映射到numhiddens
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
    def forward(self,queries,keys,values,valid_lens):
        #注意最后的 [batch_size` * `num_heads`,number of  key-value pairs,num_hiddens` / `num_heads]
        #这里涉及到reshape操作
        queries = transpose_qkv(self.W_q(queries),self.num_heads)#batch,seq,embed -> batch*num_head,seq,embed/num_head
        keys = transpose_qkv(self.W_k(keys),self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:#相当于每个batch扩充num_heads遍
            valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)
        print(queries.shape)#10,4,20
        print(values.shape)#10,6,20
        print(")*&^%$^&*()")
        output = self.attention(queries, keys, values, valid_lens)#attention计算是transpose之后的向量
        #得到,batch×head, seq,embed/head的矩阵,每一个embed/head是这一部分词向量子空间的attention加权和。
        weights= self.attention.attention_weights
        print(weights.shape)#10,4,6  query: 2 4 100 key: 2,6,100 ,一共10组,每组 4×6,query和key的交互值
        
        output_concat = transpose_output(output, self.num_heads)#transpose的逆运算
        return self.W_o(output_concat)#最后做一次线性变换 #2,4,100
        
def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)#batch seq head embed/head

    X = X.permute(0, 2, 1, 3) # batch head seq embed/head

    return X.reshape(-1, X.shape[2], X.shape[3])# batch×head, seq,embed/head
def transpose_output(X, num_heads):# batch×head, seq,embed/head
    """Reverse the operation of `transpose_qkv`"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# batch,head, seq,embed/head
    X = X.permute(0, 2, 1, 3)## batch,seq,head, embed/head
    return X.reshape(X.shape[0], X.shape[1], -1)#batch,seq,embed
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))#2,4,100
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))#2,6,100
attention(X, Y, Y, valid_lens).shape #2,4,100 query有4个,得到4个对应的结果
#中间 attention weight大小是 10,4,6

self-attention

输入和输出大小一样

query,key,value一样

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))#2,4,100
attention(X, X, X, valid_lens).shape #2,4,100
  • 自注意力机制中,query,key,value来自于相同的空间
  • CNN和self-attention都有利于并行运算,self-attention有要求最短的最大路径长度。但是由于复杂度是序列长度的平方,长序列会计算比较慢。
  • 为了使用序列顺序信息,我们可以通过向输入表示添加位置编码注入绝对位置或相对位置信息,如transformer的 PositionalEncoding

补充
Transformer/CNN/RNN的对比(时间复杂度,序列操作数,最大路径长度) - Gordon Lee的文章 - 知乎
https://zhuanlan.zhihu.com/p/264749298
https://spaces.ac.cn/archives/4765
↑对self-attention的分析也很好,self-attention有不能充分编入位置信息的硬伤等

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,525评论 6 507
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,203评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,862评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,728评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,743评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,590评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,330评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,244评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,693评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,885评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,001评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,723评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,343评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,919评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,042评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,191评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,955评论 2 355

推荐阅读更多精彩内容