多头注意力机制演示以及位置编码

1 core formula

postitional encoding

可以看到偶数维度和奇数维度的括号内是相同的,只不过一个是正弦一个是余弦。所以在偶数位置[:, 0::2]使用sin(x)之后,只需要把[:, 1::2]使用cos(x)即可。还有一个我觉得比较重要的是,1/10000^(2i/d_model),其中10000这个数字被称为频率基数,随着维度增大,函数的波长会增大(在函数图像上看,sin(x)和cos(x)被拉长了),那么sin(x)/cos(x)对于position的敏感度越小。对于小维度,函数波长小(在函数图像上看,sin(x)和cos(x)被拉短了),所以在图像上看就是左边颜色变化快(函数值变化快),右边颜色变化慢。

那么有个问题:频率基数的作用是什么?
假设所有维度频率相同(如将基数设为 1):
所有维度的波长均为 2π≈6.28。
对于任何位置 pos 和 pos + 6,其编码向量几乎相同(因为 sin(pos+6)≈sin(pos))。

所以为了防止出现这种状况:
就有了多频率的解决方案
通过设置基数 10000:
高频维度(小 i):波长极短,确保相邻位置(如 1 和 2)的编码差异明显。
低频维度(大 i):波长极长,确保远距离位置(如 1 和 100)的编码仍有一定差异。
综合效果:不同位置的编码向量在所有维度上同时完全重复的概率极低,从而保证唯一性。

2 Codes:

# 注意力机制演示

import torch
import torch.nn as nn

# 假设输入
x = torch.randn(2, 4, 512)  # [batch, seq_len, d_model]
n_heads = 8
d_k = 64

# 1. 线性变换 + 分头 
# [2,8,4,64]
# [batch, n_heads, seq_len, d_k]
Q = nn.Linear(512, 512)(x).view(2, 4, 8, 64).transpose(1, 2)  # [2,8,4,64]
K = nn.Linear(512, 512)(x).view(2, 4, 8, 64).transpose(1, 2)  # [2,8,4,64]  
V = nn.Linear(512, 512)(x).view(2, 4, 8, 64).transpose(1, 2)  # [2,8,4,64]

Q.shape, K.shape, V.shape


# 2. 计算注意力(一个头为例)
head_idx = 0
Q_head = Q[:, head_idx]  # [2,4,64]
K_head = K[:, head_idx]  # [2,4,64]
V_head = V[:, head_idx]  # [2,4,64]

# Q_head的形状是 [2, seq_len, d_k] 
# K_head.transpose(-2, -1) 的形状是 [2, d_k, seq_len]
# scores 的形状是 [2, seq_len, seq_len]
# scores 需要除以 sqrt(d_k) 进行缩放,进行数值稳定。
scores = Q_head @ K_head.transpose(-2, -1) / torch.sqrt(torch.tensor(64.0))
# scores 的形状是 [2, seq_len, seq_len] V_head的形状是[2, seq_len, d_k] 
# 最后的attn的形状是:[2, seq_len, d_k]
attn = torch.softmax(scores, dim=-1) @ V_head  # [2,4,64]
attn.shape

# 3. 合并所有头
attention_heads = []
for i in range(8):
    Q_head = Q[:, i]
    K_head = K[:, i]
    V_head = V[:, i]
    attn_head = torch.softmax(Q_head @ K_head.transpose(-2,-1), dim=-1) @ V_head
    attention_heads.append(attn_head)

len(attention_heads)

# 合并操作
stacked = torch.stack(attention_heads, dim=1)  # [2,8,4,64]
combined = stacked.transpose(1, 2).contiguous()  # [2,4,8,64]
combined = combined.view(2, 4, 512)  # [2,4,512]

# 最终线性变换
output = nn.Linear(512, 512)(combined)  # [2,4,512]

# 理解的话,就是Q和K计算一个类似于相关性的东西,V呢就是每一个输入的相当于隐藏状态的东西
# 然后加权聚合,能够突出某些重要相关的输入
# 然后如果有多个头的话,就是能够捕捉好多隐藏状态可能是主语谓语名词形容词巴拉巴拉的
# 最后融合一下



import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

torch.arange(10).unsqueeze(1)
torch.arange(0, 4, 2)

def sinusoidal_positional_encoding(seq_len, d_model):
    """生成正弦余弦位置编码"""
    position = torch.arange(seq_len).unsqueeze(1)  # [seq_len, 1]
# 10000.0是频率基数
    div_term = torch.exp(torch.arange(0, d_model, 2) * 
                        -(np.log(10000.0) / d_model))  # [d_model/2]
    
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度:sin
    pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度:cos
    
    return pe

# 示例:序列长度20,维度64
seq_len, d_model = 20, 64
pe = sinusoidal_positional_encoding(seq_len, d_model)
print(f"位置编码形状: {pe.shape}")  # [20, 64]

import seaborn as sns

def plot_positional_encoding(pe, seq_len=20, d_model=64):
    """Visualize positional encoding with heatmap and line plots
    
    Args:
        pe: Positional encoding matrix (shape: [max_seq_len, d_model])
        seq_len: Number of positions to visualize
        d_model: Dimension of the model
    """
    plt.figure(figsize=(12, 8))
    
    # 1. Heatmap (Position vs Dimension)
    plt.subplot(2, 1, 1)
    sns.heatmap(pe[:seq_len, :].numpy(), 
                cmap='RdBu',  # Red-Blue colormap for positive/negative values
                center=0,     # White at zero
                xticklabels=10,  # Show every 10th dimension label
                yticklabels=5)   # Show every 5th position label
    plt.title('Positional Encoding Heatmap (Position vs Dimension)')
    plt.xlabel('Dimension Index')
    plt.ylabel('Position Index')
    
    # 2. Line Plots for First Few Dimensions
    plt.subplot(2, 1, 2)
    for i in range(6):  # Plot first 6 dimensions
        plt.plot(pe[:seq_len, i].numpy(), 
                label=f'Dim {i}', 
                linewidth=2)
    plt.title('Positional Encoding Values Across Positions')
    plt.xlabel('Position')
    plt.ylabel('Encoding Value')
    plt.legend()
    plt.tight_layout()
    plt.show()

# 生成并绘制
pe = sinusoidal_positional_encoding(seq_len=100, d_model=64)
plot_positional_encoding(pe, seq_len=100, d_model=64)

3 Figures

频率基数=10000

在频率基数=10000的情况下,低维度和高维度均有所差异

频率基数=1

而基数为1的情况下,会出现很多位置编码相似的情况

dim=1024

我把维度设置的高了一些,可以看到,从左到右红色(波谷)和蓝色(波峰)的距离越来越远了,因为波长越来越大了。

4

位置编码除了能够编码绝对位置外(每一个位置有一串独一无二的位置编码),还能够编码token的相对位置,即能够通过下面公式的变换将PE(t+deltat) = Tdeltat * PE(t)
有能力的客官可以自己推导,通过正余弦公式和线性变换等得来(我也不懂,😂,数学太差)


相对位置

5 强力推荐阅读

https://zhuanlan.zhihu.com/p/454482273
是个大佬哈哈哈哈哈哈哈,膜拜

6 旋转位置矩阵

https://blog.csdn.net/v_JULY_v/article/details/134085503

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容