Attention heatmap visualization

This note records encoding a sequence with BERT, extracting the self-attention matrices from its layers, and visualizing them with a heatmap. The plotting code is adapted from the blog post attention机制的热度图.
Getting the attention from the encoder:

from transformers import BertTokenizer, BertModel
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
print(inputs)

# output_attentions=True makes the model also return the attention weights of every layer
outputs = model(**inputs, output_attentions=True)
last_hidden_states = outputs.last_hidden_state
attention = outputs.attentions   # tuple with one tensor per layer, each [batch, num_heads, seq_len, seq_len]
print(len(attention), attention[0].shape)

# layer 0, first sample in the batch, head 0 -> a [seq_len, seq_len] attention matrix
result = attention[0][0][0]
print(result.shape)
print(result)
result1 = result.detach().numpy()
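For bert-base-uncased, attention is a tuple with 12 entries (one per layer), each of shape [batch_size, num_heads, seq_len, seq_len]; the example sentence tokenizes to 8 tokens, so each entry is [1, 12, 8, 8]. A minimal sketch (not from the referenced blog, assuming the snippet above has been run) that stacks the tuple into one tensor so any layer/head matrix can be indexed directly:

import torch
# stack the per-layer tensors: tuple of [1, 12, 8, 8] -> [12, 1, 12, 8, 8], then drop the batch dimension
all_attn = torch.stack(attention, dim=0).squeeze(1)
print(all_attn.shape)            # torch.Size([12, 12, 8, 8]) = [layer, head, seq_len, seq_len]
layer, head = 0, 0
result = all_attn[layer, head]   # the same matrix as attention[0][0][0]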

Visualization:

import matplotlib.pyplot as plt
import numpy as np

def plot_attention(data, X_label=None, Y_label=None):
  '''
    Plot the attention heatmap
    Args:
      data: attn_matrix with shape [ty, tx], cut before 'PAD'
      X_label: list of size tx, encoder tags
      Y_label: list of size ty, decoder tags
  '''
  fig, ax = plt.subplots(figsize=(20, 8))  # set figure size
  heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9)

  # Set axis labels
  if X_label is not None and Y_label is not None:
    # pcolor draws each cell between integer coordinates,
    # so put the ticks at the cell centers (i + 0.5) to align the labels
    ax.set_xticks(np.arange(len(X_label)) + 0.5, minor=False)
    ax.set_xticklabels(X_label, minor=False, rotation=45)

    ax.set_yticks(np.arange(len(Y_label)) + 0.5, minor=False)
    ax.set_yticklabels(Y_label, minor=False)

    ax.grid(True)
  plt.show()

# the labels follow the WordPiece tokens produced by bert-base-uncased (lowercased, with [CLS]/[SEP])
plot_attention(result.detach().numpy(),
               ['[CLS]', 'hello', ',', 'my', 'dog', 'is', 'cute', '[SEP]'],
               ['[CLS]', 'hello', ',', 'my', 'dog', 'is', 'cute', '[SEP]'])

X_label and Y_label are the lists of tokens of the encoded input sequence.
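A minimal sketch (assuming the snippets above have been run): the labels can also be derived from the tokenizer instead of being typed by hand, and any layer/head can be plotted by changing the indices (the choice below is just an arbitrary example):

# derive the token labels directly from the input ids so they always match the sequence
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
layer, head = 11, 3   # arbitrary example; bert-base-uncased has 12 layers and 12 heads
plot_attention(attention[layer][0][head].detach().numpy(), tokens, tokens)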



This is only a demonstration note; it is not meant to show what the attention patterns actually capture.
