1 背景介绍
在上一篇文章中,我们介绍了Encoder-Decoder架构,但是这个架构有个问题:如果想建模长序列,encoder最后一个时间步的隐状态很可能会丢失最初时间步的信息。所以我们想,能不能让decoder在解码时动态地关注encoder中不同时间步的信息,这样就能够克服这个信息瓶颈。这跟我们考四六级的翻译题很相似,比如做中译英时,每翻译一个单词,我们可能都会看一下这个单词附近的主语啊、副词啊、动词啊、形容词啊之类的。所以Bahdanau attention的核心就是:decoder解码时,用decoder的当前隐状态与encoder的output(形状为(src_len, batch_size, hidden_dim),也就是每一个时间步的隐状态)计算一个attention score。说得直白点,就是衡量encoder中哪一个时间步对解码器的当前时间步影响较大;计算attention score的过程相当于一个小型的前馈神经网络。
2 流程图

3 本人手绘流程图

可以看到encoder_output(src_len, batch_size, hidden_dim)分出来两路:一路与decoder_hid计算attention score,这个attention score你可以理解成源语言中哪些单词(时间步)与decoder的当前隐藏状态相关性更高,那么decoder就会对这个单词(时间步)给予更多的关注,也就是所谓的注意力;另一路则与attention score经过softmax归一化后得到的attention权重相乘,再在dim=0(时间步)的维度上求和,得到的上下文向量形状是(batch_size, hidden_dim)。你会发现哎,这个矩阵的形状跟没有注意力机制的encoder-decoder中传给decoder的上下文向量(即encoder最后一个时间步的隐状态)完全一样呢?对,是一模一样。但是实际上由于每一个时间步的attention权重不一样,越大的权重在广播相乘的时候会给予相对应的时间步的隐藏状态更大的比重,求和之后,该时间步的隐藏状态在上下文向量中所占的比重也会增加,也可以说模型“注意”到了该时间步,这也是Bahdanau attention的精髓所在。
4 demo (手动实现)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm
import spacy
import random
from datasets import load_dataset
import torch.nn.functional as F
# Seed Python's and PyTorch's random number generators so that runs are
# reproducible (weight initialisation, shuffling, teacher forcing).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
# Request deterministic cuDNN kernels (only has an effect on GPU).
torch.backends.cudnn.deterministic = True
import os
import pickle
# Load Multi30k, caching the downloaded DatasetDict in a local pickle file so
# later runs skip the download. The path is bound once (in both branches) so
# the name `pickle_file` is always defined after this block.
pickle_file = "multi30k_dataset.pkl"
if os.path.exists(pickle_file):
    # NOTE(security): pickle.load can execute arbitrary code stored in the
    # file. Only load a cache that this script wrote itself; never a file
    # obtained from elsewhere.
    with open(pickle_file, "rb") as f:
        ds = pickle.load(f)
else:
    ds = load_dataset("multi30k")
    # Dump to a temporary file and rename it into place. An interrupted dump
    # therefore cannot leave a truncated cache behind that would make the
    # next run's pickle.load fail.
    tmp_file = pickle_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(ds, f)
    os.replace(tmp_file, pickle_file)
# Load the spaCy pipelines whose tokenizers are used for English and German.
try:
    spacy_en = spacy.load('en_core_web_sm')
    spacy_de = spacy.load('de_core_news_sm')
except OSError:
    import sys

    print("请先安装spacy语言模型:")
    print("python -m spacy download en_core_web_sm")
    print("python -m spacy download de_core_news_sm")
    # exit() is the interactive-shell helper injected by the `site` module,
    # and without an argument it exits with status 0. sys.exit(1) is always
    # available and reports the failure to the caller.
    sys.exit(1)
def tokenize_en(text):
    """Return the token strings of English ``text`` from spaCy's tokenizer.

    Only the tokenizer of the ``spacy_en`` pipeline is invoked, not the full
    pipeline.
    """
    doc = spacy_en.tokenizer(text)
    return [tok.text for tok in doc]
def tokenize_de(text):
    """Return the token strings of German ``text`` from spaCy's tokenizer.

    Only the tokenizer of the ``spacy_de`` pipeline is invoked, not the full
    pipeline.
    """
    doc = spacy_de.tokenizer(text)
    return [tok.text for tok in doc]
from torchtext.vocab import Vocab
from collections import Counter
import torchtext
def yield_tokens(data_iter, tokenizer, lang):
    """Lazily yield ``tokenizer(example[lang])`` for each example in ``data_iter``."""
    yield from (tokenizer(example[lang]) for example in data_iter)
def build_vocab_compatible(data_iter, tokenizer, lang, min_freq=2):
    """Build a vocabulary for one language column of a dataset split.

    Works with both torchtext vocab APIs. The newer one has
    ``Vocab.set_default_index`` and is built with
    ``build_vocab_from_iterator``. The legacy one is constructed directly
    from a ``Counter``.

    Args:
        data_iter: iterable of examples; each example is indexed with ``lang``.
        tokenizer: callable mapping a sentence to a list of token strings.
        lang: key selecting the sentence in each example (e.g. 'de', 'en').
        min_freq: minimum token count required for a token to enter the vocab.

    Returns:
        A torchtext vocab containing the special tokens
        '<unk>', '<pad>', '<sos>', '<eos>'.
    """
    specials = ['<unk>', '<pad>', '<sos>', '<eos>']
    if hasattr(Vocab, 'set_default_index'):
        # Newer torchtext: the Vocab class no longer takes a Counter in its
        # constructor (see torchtext.vocab docs), so use the factory function.
        vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(data_iter, tokenizer, lang),
            min_freq=min_freq,
            specials=specials,
        )
        # Route lookups of unseen tokens to the '<unk>' index.
        vocab.set_default_index(vocab['<unk>'])
    else:
        # Legacy torchtext: build the vocab from token counts.
        counter = Counter()
        for tokens in yield_tokens(data_iter, tokenizer, lang):
            counter.update(tokens)
        vocab = Vocab(counter, min_freq=min_freq, specials=specials)
        # NOTE(review): the legacy Vocab has no set_default_index; presumably
        # its own lookup falls back to '<unk>'. Confirm for the installed
        # torchtext version.
    return vocab
src_vocab = build_vocab_compatible(ds['train'], tokenize_de, 'de', min_freq=2)
trg_vocab = build_vocab_compatible(ds['train'], tokenize_en, 'en', min_freq=2)
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over time-major encoder outputs.

    For every source position j:
        score_j = va(tanh(Wa(encoder_outputs[j]) + Ua(decoder_hidden)))
    The scores are softmax-normalised over the source axis (dim 0). The
    context vector is the resulting weighted sum of the encoder outputs.
    """

    def __init__(self, hidden_dim):
        super(BahdanauAttention, self).__init__()
        self.hidden_dim = hidden_dim
        # Wa projects the encoder outputs and Ua the decoder state. The
        # attribute names are kept so that saved state_dicts still load.
        self.Wa = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Ua = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.va = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        """Attend over ``encoder_outputs`` with the current decoder state.

        Args:
            decoder_hidden: (batch, hidden_dim) decoder state for this step.
            encoder_outputs: (src_len, batch, hidden_dim) encoder states.
            mask: optional (src_len, batch) tensor. Positions where it is
                False/0 (e.g. source padding) receive exactly zero weight.
                Every column must keep at least one True entry; otherwise
                that column's softmax is NaN. ``None`` (the default) attends
                to every position, as before.

        Returns:
            (context_vector, attention_weights) with shapes
            (batch, hidden_dim) and (src_len, batch). The weights sum to 1
            over dim 0.
        """
        # (batch, hid) -> (1, batch, hid) so that it broadcasts over src_len.
        decoder_hidden = decoder_hidden.unsqueeze(0)
        scores = self.va(torch.tanh(
            self.Wa(encoder_outputs) + self.Ua(decoder_hidden)
        )).squeeze(2)  # (src_len, batch)
        if mask is not None:
            # -inf becomes exactly 0 after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=0)
        context_vector = (attention_weights.unsqueeze(2) * encoder_outputs).sum(dim=0)
        return context_vector, attention_weights
class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder over time-major token ids.

    The forward and backward directions are summed (not concatenated), so
    outputs and states all have feature size ``hid_dim``.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # Bidirectional, time-major (batch_first=False) LSTM.
        self.rnn = nn.LSTM(
            emb_dim,
            hid_dim,
            n_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=False,
        )
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): fc is not used by forward(); it is kept so that the
        # parameter set and state_dict keys stay unchanged.
        self.fc = nn.Linear(hid_dim * 2, hid_dim)

    def _merge_directions(self, state):
        """Sum the two directions of an (n_layers*2, batch, hid) LSTM state."""
        forward_dir, backward_dir = state.view(
            self.n_layers, 2, -1, self.hid_dim
        ).unbind(1)
        return forward_dir + backward_dir

    def forward(self, src):
        """Encode ``src`` of shape (src_len, batch_size).

        Returns:
            outputs: (src_len, batch, hid_dim), both directions summed.
            hidden, cell: (n_layers, batch, hid_dim), both directions summed.
        """
        tokens = self.dropout(self.embedding(src))
        rnn_out, (h_n, c_n) = self.rnn(tokens)
        fwd_out, bwd_out = rnn_out.split(self.hid_dim, dim=2)
        return fwd_out + bwd_out, self._merge_directions(h_n), self._merge_directions(c_n)
class Decoder(nn.Module):
    """Single-step LSTM decoder conditioned on an attention context vector."""

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hid_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.attention = attention
        # Each step's LSTM input is [token embedding ; context vector].
        self.rnn = nn.LSTM(
            emb_dim + hid_dim, hid_dim, n_layers, dropout=dropout, batch_first=False
        )
        # Vocabulary logits are computed from [LSTM output ; context vector].
        self.fc_out = nn.Linear(hid_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one time step.

        Args:
            input: (batch,) token ids for this step.
            hidden, cell: (n_layers, batch, hid_dim) LSTM state from the
                previous step.
            encoder_outputs: encoder outputs forwarded to the attention module.

        Returns:
            (prediction, hidden, cell, attention_weights), where prediction
            has shape (batch, output_dim). hidden/cell are the updated LSTM
            states.
        """
        step_emb = self.dropout(self.embedding(input.unsqueeze(0)))  # (1, batch, emb)
        # Attend with the top layer's state from the previous step.
        context, weights = self.attention(hidden[-1], encoder_outputs)
        lstm_in = torch.cat([step_emb, context.unsqueeze(0)], dim=2)
        # The LSTM state size depends only on hid_dim, not on the input width.
        lstm_out, (hidden, cell) = self.rnn(lstm_in, (hidden, cell))
        logits = self.fc_out(torch.cat([lstm_out.squeeze(0), context], dim=1))
        return logits, hidden, cell, weights
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run the full decoding loop.

        Args:
            src: (src_len, batch) source token ids.
            trg: (trg_len, batch) target token ids; ``trg[0]`` is the first
                decoder input.
            teacher_forcing_ratio: probability, drawn independently at every
                step, of feeding the gold token rather than the model's
                argmax prediction.

        Returns:
            (outputs, attention_weights) of shapes (trg_len, batch, vocab)
            and (trg_len, src_len, batch). Row 0 of both stays all zeros,
            because decoding starts at t = 1.
        """
        trg_len, batch_size = trg.shape[0], trg.shape[1]
        logits = torch.zeros(trg_len, batch_size, self.decoder.output_dim, device=self.device)
        attn_maps = torch.zeros(trg_len, src.shape[0], batch_size, device=self.device)

        encoder_outputs, hidden, cell = self.encoder(src)
        step_input = trg[0, :]
        for t in range(1, trg_len):
            step_logits, hidden, cell, step_attn = self.decoder(
                step_input, hidden, cell, encoder_outputs
            )
            logits[t, :, :] = step_logits
            attn_maps[t, :, :] = step_attn
            use_gold = random.random() < teacher_forcing_ratio
            step_input = trg[t] if use_gold else step_logits.argmax(1)
        return logits, attn_maps
# Hyperparameters
INPUT_DIM = len(src_vocab)  # source (German) vocabulary size
OUTPUT_DIM = len(trg_vocab)  # target (English) vocabulary size
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512  # shared by encoder, decoder and attention
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
BATCH_SIZE = 128
N_EPOCHS = 40
CLIP = 1  # max gradient norm, passed to clip_grad_norm_ in train()
LEARNING_RATE = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Keep this construction order: each constructor draws its initial weights
# from the RNG seeded at the top of the script, so reordering changes the
# initialisation. The attention module is owned by (registered inside) the
# decoder.
attention = BahdanauAttention(HID_DIM).to(device)
encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)
decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, attention).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print(f"模型参数量: {count_parameters(model):,}")
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Target-side padding positions contribute nothing to the loss.
TRG_PAD_IDX = trg_vocab['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
def collate_fn(batch):
    """Turn a list of {'de': ..., 'en': ...} examples into padded id tensors.

    Each sentence is tokenized, wrapped in '<sos>' ... '<eos>' and mapped to
    vocabulary indices. The results are padded with the respective '<pad>'
    index.

    Returns:
        (src_batch, trg_batch), each of shape (max_len, batch_size), since
        pad_sequence is used with its default batch_first=False.
    """
    def encode(text, tokenizer, vocab):
        return torch.tensor([vocab[tok] for tok in ['<sos>', *tokenizer(text), '<eos>']])

    sources = [encode(ex['de'], tokenize_de, src_vocab) for ex in batch]
    targets = [encode(ex['en'], tokenize_en, trg_vocab) for ex in batch]
    return (
        pad_sequence(sources, padding_value=src_vocab['<pad>']),
        pad_sequence(targets, padding_value=trg_vocab['<pad>']),
    )
# All three splits share batch size and collate function; only the training
# split is shuffled.
_loader_kwargs = {'batch_size': BATCH_SIZE, 'collate_fn': collate_fn}
train_loader = DataLoader(ds['train'], shuffle=True, **_loader_kwargs)
valid_loader = DataLoader(ds['validation'], **_loader_kwargs)
test_loader = DataLoader(ds['test'], **_loader_kwargs)
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean per-batch loss.

    The model is called with its default teacher-forcing ratio. Gradients are
    clipped to a total norm of ``clip`` before every optimizer step.
    """
    model.train()
    running = 0
    for src, trg in tqdm(iterator, desc='Training'):
        src = src.to(device)
        trg = trg.to(device)
        optimizer.zero_grad()
        logits, _ = model(src, trg)
        vocab_size = logits.shape[-1]
        # Position 0 is the '<sos>' input and is never predicted, so skip it.
        loss = criterion(logits[1:].view(-1, vocab_size), trg[1:].view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        running += loss.item()
    return running / len(iterator)
def evaluate(model, iterator, criterion):
    """Return the mean per-batch loss over ``iterator``.

    Runs without gradients and without teacher forcing, so the model is fed
    its own predictions.
    """
    model.eval()
    running = 0
    with torch.no_grad():
        for src, trg in tqdm(iterator, desc='Evaluating'):
            src = src.to(device)
            trg = trg.to(device)
            logits, _ = model(src, trg, teacher_forcing_ratio=0)
            vocab_size = logits.shape[-1]
            # Skip position 0 (the '<sos>' input), as in train().
            loss = criterion(logits[1:].view(-1, vocab_size), trg[1:].view(-1))
            running += loss.item()
    return running / len(iterator)
# Train for N_EPOCHS, checkpointing whenever validation loss improves.
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_loader, criterion)
    improved = valid_loss < best_valid_loss
    if improved:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-model.pt')
    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

# Evaluate the best checkpoint (not necessarily the last epoch) on the test split.
model.load_state_dict(torch.load('best-model.pt'))
test_loss = evaluate(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.3f}')
def translate_sentence(sentence, model, src_vocab, trg_vocab, device, max_len=50):
    """Greedily translate one German sentence with ``model``.

    Decoding starts from '<sos>', feeds each argmax prediction back in, and
    stops once '<eos>' is produced or after ``max_len`` steps.

    Returns:
        (translation, attention_weights). ``translation`` is the predicted
        tokens joined by single spaces, without a trailing '<eos>'.
        ``attention_weights`` holds one (src_len,) array per decoding step,
        including the step that produced '<eos>' when there was one.
    """
    model.eval()
    tokens = ['<sos>'] + tokenize_de(sentence) + ['<eos>']
    src_indexes = [src_vocab[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)  # (src_len, 1)
    eos_idx = trg_vocab['<eos>']

    trg_indexes = []
    attention_weights = []
    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(src_tensor)
        input_token = trg_vocab['<sos>']
        for _ in range(max_len):
            trg_tensor = torch.LongTensor([input_token]).to(device)
            output, hidden, cell, attn_weights = model.decoder(trg_tensor, hidden, cell, encoder_outputs)
            pred_token = output.argmax(1).item()
            trg_indexes.append(pred_token)
            # (src_len, 1) -> (src_len,); squeeze only the batch axis.
            attention_weights.append(attn_weights.squeeze(1).cpu().numpy())
            input_token = pred_token
            if pred_token == eos_idx:
                break

    # Drop the final token only if it really is '<eos>'. When max_len is hit
    # without '<eos>', the last token is a real word and must be kept.
    if trg_indexes and trg_indexes[-1] == eos_idx:
        trg_indexes = trg_indexes[:-1]

    # The newer torchtext Vocab exposes get_itos(); the legacy one has .itos.
    itos = trg_vocab.get_itos() if hasattr(trg_vocab, 'get_itos') else trg_vocab.itos
    trg_tokens = [itos[i] for i in trg_indexes]
    return ' '.join(trg_tokens), attention_weights
# German sample sentences used to inspect the trained model's translations.
examples = [
    "Ein Mann läuft auf der Straße.",
    "Eine Frau liest ein Buch.",
    "Kinder spielen im Park.",
    "Ein Mann läuft auf der Straße und ein Hund folgt ihm.",
    "Ein kleines Mädchen malt ein Bild.",
    "Der Himmel ist blau und die Sonne scheint.",
    "Eine schwarze Frau und ein weißer Mann arbeiten in einer Fabrikumgebung und packen Gläser mit Kerzen in Kartons.",
    "Zwei Kinder spielen mit einem Ball im Garten.",
    "Das Wetter ist heute sehr schön und die Vögel singen.",
    "Ein älterer Herr sitzt auf einer Bank und füttert die Tauben.",
    "Die Katze liegt auf dem Sofa und schläft.",
    "Ein Junge fährt mit seinem Fahrrad durch die Stadt.",
    "Eine Gruppe von Menschen tanzt auf einer Party.",
    "Der Lehrer erklärt den Schülern die Mathematikaufgaben.",
    "Das Kind baut eine Sandburg am Strand.",
    "Eine Frau kocht in der Küche und bereitet das Abendessen zu.",
    "Der Hund bellt laut, als der Postbote ankommt.",
    "Ein Paar spaziert Hand in Hand im Park.",
    "Die Kinder lachen und spielen im Schnee.",
    "Ein Musiker spielt Gitarre auf der Straße.",
]
for ex in examples:
    # Only the translated string is needed here; attention weights are dropped.
    translation = translate_sentence(ex, model, src_vocab, trg_vocab, device)[0]
    print(f"德文: {ex}")
    print(f"英文: {translation}\n")
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def plot_attention(attention, source_sentence, target_sentence):
    """Display a heatmap of attention weights.

    Args:
        attention: sequence of per-decoding-step weight vectors (anything
            ``np.array`` can stack). Rows are target steps and columns are
            source tokens.
        source_sentence: x-axis labels, one per source token.
        target_sentence: y-axis labels, one per decoding step.
    """
    _, axis = plt.subplots(figsize=(10, 10))
    heat = np.array(attention)
    grid_style = {'cmap': 'viridis', 'square': True, 'linewidths': 0.5, 'linecolor': 'gray'}
    sns.heatmap(
        heat,
        xticklabels=source_sentence,
        yticklabels=target_sentence,
        ax=axis,
        **grid_style,
    )
    # Re-apply the tick labels to control their rotation and alignment.
    axis.set_xticklabels(source_sentence, rotation=45, ha='right')
    axis.set_yticklabels(target_sentence, rotation=0)
    axis.set_xlabel("Source Words")
    axis.set_ylabel("Target Words")
    axis.set_title("Attention Weights")
    plt.tight_layout()
    plt.show()
# Translate one sentence and visualise its attention matrix.
example = 'Ein Musiker spielt Gitarre auf der Straße.'
translation, attention_weights = translate_sentence(example, model, src_vocab, trg_vocab, device)
translation  # notebook-style display; a no-op when run as a plain script
src_tokens = ['<sos>'] + tokenize_de(example) + ['<eos>']
# The y-axis labels must line up one-to-one with the rows of
# attention_weights. The translation is a ' '-join of vocabulary tokens, so
# splitting on whitespace recovers them. Re-tokenizing with spaCy may split
# tokens such as '<unk>' and break the alignment.
trg_tokens = translation.split()
# Label any remaining row(s) as '<eos>'. These are the final decoding step,
# normally the one that emitted '<eos>'.
trg_tokens += ['<eos>'] * (len(attention_weights) - len(trg_tokens))
plot_attention(attention_weights, src_tokens, trg_tokens)
5 figure
由于Bahdanau attention的一个核心的特点就是,能够计算一个attention score,所以我下面展示几个在翻译过程中的attention score的矩阵热图,方便大家理解。





有个有意思的地方是,target language中的 "A" 并没有与 "Ein" 表现出很高的相关性,反而与名词的相关性很高。我觉得可能是模型捕捉到了:前面冠词(a/an/the)的形式是由名词或者语义来决定的。
6 结语
写在最后,Bahdanau attention以及其他注意力机制的出现,让transformer的诞生有了“山雨欲来风满楼”之势。由于Bahdanau attention是基于RNN的,它有个最大的缺点:encoder每一个时间步的隐藏状态都是串行计算得来的,计算效率低。而横空出世的transformer改良了这个缺点,变成了完全并行,且输入和输出都完全基于注意力机制,所以计算效率更高,也能更好地建模长序列的依赖。下期将会介绍大名鼎鼎的Transformer。
7 感悟
实际上,我在至少半年之前就已经接触了transformer,但是我是从RNN直接跳到了transformer,而且学习的时候我的任务也不是机器翻译。过了半年之后,我从机器翻译的角度入手,从RNN到encoder-decoder架构,再到Bahdanau attention,再到transformer,逻辑顺滑了很多,学习起来也津津有味。一言以蔽之:在学习的路上,没有傻瓜,要么是老师不行,要么是切入角度不适合你自己,如是而已。