普通attention
根据相似度的计算方法不同大致可以分为两种:
- 点积(大多数)
- 全连接网络
点积
首先根据torch的nn.LSTM(Bi_direction)
可以==>
output, (h, c) = nn.Lstm()
# output的size为seq_len, batch_size, 2*hidden_size
# h的size为2, batch_size, hidden_size
首先定义一个函数进行相似度的计算
def sim(encode_state, decode_state):# encode_s指的是encode的output, decode_s指的是上一步的隐藏状态
# encode_s: (seq_len, batch_size, 2*hidden_size)
# decode_s "(2, batch_size, hidden_size)"
# 这是个双向的lstm所以状态要拿两个
decode_s = torch.cat((decode_s[0], decode_s[1]), 1) # batch_size, 2*hidden_size
encode_s = encode_s.permute(0, 1) # batch_size, seq_len, 2*hidden_size
decode_s = decode_s.unsqueeze(2) # batch_size, 2*hidden_size, 1
sim = torch.bmm(encode_s, decode_s) # batch_size, seq_len, 1