attention

普通attention

c_t = \sum_{t=1}^Tsoftmax(sim(Q,K))V

根据相似度的计算方法不同大致可以分为两种:

  • 点积(大多数)
  • 全连接网络

点积

首先根据torch的nn.LSTM(Bi_direction)可以==>

output, (h, c) = nn.Lstm()
# output的size为seq_len, batch_size, 2*hidden_size
# h的size为2, batch_size, hidden_size

首先定义一个函数进行相似度的计算

def sim(encode_state, decode_state):# encode_s指的是encode的output, decode_s指的是上一步的隐藏状态
    # encode_s: (seq_len, batch_size, 2*hidden_size)
    # decode_s "(2, batch_size, hidden_size)"
    # 这是个双向的lstm所以状态要拿两个
    decode_s = torch.cat((decode_s[0], decode_s[1]), 1) # batch_size, 2*hidden_size
    encode_s = encode_s.permute(0, 1) # batch_size, seq_len, 2*hidden_size
    decode_s = decode_s.unsqueeze(2) # batch_size, 2*hidden_size, 1
    sim = torch.bmm(encode_s, decode_s) # batch_size, seq_len, 1
    
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容