Notes on encoding a sequence with BERT, extracting the self-attention matrices from the model, and visualizing them as a heatmap. The visualization code is adapted from the blog post "attention机制的热度图" (heatmaps for the attention mechanism).
Getting the attention weights from the encoder:
from transformers import BertTokenizer, BertModel
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
print(inputs)

# ask the model to also return the attention weights of every layer
outputs = model(**inputs, output_attentions=True)
last_hidden_states = outputs.last_hidden_state

# attentions is a tuple with one tensor per layer,
# each of shape [batch_size, num_heads, seq_len, seq_len]
attention = outputs.attentions
print(len(attention), attention[0].shape)

# layer 0, batch item 0, head 0 -> one [seq_len, seq_len] attention matrix
result = attention[0][0][0]
print(result.shape)
print(result)
result1 = result.detach().numpy()
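Continuing from the code above, any other layer or head can be inspected the same way; the layer/head indices below are only illustrative assumptions, not values used in the original note:

# assumed indices for illustration; bert-base-uncased has 12 layers and 12 heads
layer, head = 11, 3
per_head = attention[layer][0][head]          # [seq_len, seq_len] matrix of a single head
head_mean = attention[layer][0].mean(dim=0)   # average attention over all heads of that layer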
Visualization:
import matplotlib.pyplot as plt

def plot_attention(data, X_label=None, Y_label=None):
    '''
    Plot the attention heatmap.
    Args:
        data: attn_matrix with shape [ty, tx], cut before 'PAD'
        X_label: list of size tx, encoder tags
        Y_label: list of size ty, decoder tags
    '''
    fig, ax = plt.subplots(figsize=(20, 8))  # set figure size
    heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9)

    # set axis tick labels
    if X_label is not None and Y_label is not None:
        xticks = range(0, len(X_label))
        ax.set_xticks(xticks, minor=False)  # major ticks
        ax.set_xticklabels(X_label, minor=False, rotation=45)

        yticks = range(0, len(Y_label))
        ax.set_yticks(yticks, minor=False)
        ax.set_yticklabels(Y_label, minor=False)

        ax.grid(True)
    plt.show()
tokens = ['[CLS]', 'hello', ',', 'my', 'dog', 'is', 'cute', '[SEP]']
plot_attention(result1, tokens, tokens)
X_label and Y_label are the lists of tokens from the encoded input sequence.
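Instead of hardcoding the label list, the tokens can also be recovered from the tokenizer. This is a sketch (not part of the original note) that reuses the inputs, tokenizer, and result1 objects defined above; convert_ids_to_tokens is a standard transformers tokenizer method:

# recover the wordpiece tokens so the axis labels always match the attention matrix
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
print(tokens)  # ['[CLS]', 'hello', ',', 'my', 'dog', 'is', 'cute', '[SEP]']
plot_attention(result1, tokens, tokens)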
This is only a demonstration record; it is not meant to show any meaningful attention pattern.