1 core formula

可以看到偶数维度和奇数维度的括号内是相同的,只不过一个是正弦一个是余弦。所以在偶数位置[:, 0::2]使用sin(x)之后,只需要把[:, 1::2]使用cos(x)即可。还有一个我觉得比较重要的是,1/10000^(2i/d_model),其中10000这个数字被称为频率基数,随着维度增大,函数的波长会增大(在函数图像上看,sin(x)和cos(x)被拉长了),那么sin(x)/cos(x)对于position的敏感度越小。对于小维度,函数波长小(在函数图像上看,sin(x)和cos(x)被拉短了),所以在图像上看就是左边颜色变化快(函数值变化快),右边颜色变化慢。
那么有个问题:频率基数的作用是什么?
假设所有维度频率相同(如将基数设为 1):
所有维度的波长均为 2π≈6.28。
对于任何位置 pos 和 pos + 6,其编码向量几乎相同(因为 sin(pos+6)≈sin(pos))。
所以为了防止出现这种状况:
就有了多频率的解决方案
通过设置基数 10000:
高频维度(小 i):波长极短,确保相邻位置(如 1 和 2)的编码差异明显。
低频维度(大 i):波长极长,确保远距离位置(如 1 和 100)的编码仍有一定差异。
综合效果:不同位置的编码向量在所有维度上同时完全重复的概率极低,从而保证唯一性。
2 Codes:
# 注意力机制演示
import torch
import torch.nn as nn
# 假设输入
x = torch.randn(2, 4, 512) # [batch, seq_len, d_model]
n_heads = 8
d_k = 64
# 1. 线性变换 + 分头
# [2,8,4,64]
# [batch, n_heads, seq_len, d_k]
Q = nn.Linear(512, 512)(x).view(2, 4, 8, 64).transpose(1, 2) # [2,8,4,64]
K = nn.Linear(512, 512)(x).view(2, 4, 8, 64).transpose(1, 2) # [2,8,4,64]
V = nn.Linear(512, 512)(x).view(2, 4, 8, 64).transpose(1, 2) # [2,8,4,64]
Q.shape, K.shape, V.shape
# 2. 计算注意力(一个头为例)
head_idx = 0
Q_head = Q[:, head_idx] # [2,4,64]
K_head = K[:, head_idx] # [2,4,64]
V_head = V[:, head_idx] # [2,4,64]
# Q_head的形状是 [2, seq_len, d_k]
# K_head.transpose(-2, -1) 的形状是 [2, d_k, seq_len]
# scores 的形状是 [2, seq_len, seq_len]
# scores 需要除以 sqrt(d_k) 进行缩放,进行数值稳定。
scores = Q_head @ K_head.transpose(-2, -1) / torch.sqrt(torch.tensor(64.0))
# scores 的形状是 [2, seq_len, seq_len] V_head的形状是[2, seq_len, d_k]
# 最后的attn的形状是:[2, seq_len, d_k]
attn = torch.softmax(scores, dim=-1) @ V_head # [2,4,64]
attn.shape
# 3. 合并所有头
attention_heads = []
for i in range(8):
Q_head = Q[:, i]
K_head = K[:, i]
V_head = V[:, i]
attn_head = torch.softmax(Q_head @ K_head.transpose(-2,-1), dim=-1) @ V_head
attention_heads.append(attn_head)
len(attention_heads)
# 合并操作
stacked = torch.stack(attention_heads, dim=1) # [2,8,4,64]
combined = stacked.transpose(1, 2).contiguous() # [2,4,8,64]
combined = combined.view(2, 4, 512) # [2,4,512]
# 最终线性变换
output = nn.Linear(512, 512)(combined) # [2,4,512]
# 理解的话,就是Q和K计算一个类似于相关性的东西,V呢就是每一个输入的相当于隐藏状态的东西
# 然后加权聚合,能够突出某些重要相关的输入
# 然后如果有多个头的话,就是能够捕捉好多隐藏状态可能是主语谓语名词形容词巴拉巴拉的
# 最后融合一下
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.arange(10).unsqueeze(1)
torch.arange(0, 4, 2)
def sinusoidal_positional_encoding(seq_len, d_model):
"""生成正弦余弦位置编码"""
position = torch.arange(seq_len).unsqueeze(1) # [seq_len, 1]
# 10000.0是频率基数
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(np.log(10000.0) / d_model)) # [d_model/2]
pe = torch.zeros(seq_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度:sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度:cos
return pe
# 示例:序列长度20,维度64
seq_len, d_model = 20, 64
pe = sinusoidal_positional_encoding(seq_len, d_model)
print(f"位置编码形状: {pe.shape}") # [20, 64]
import seaborn as sns
def plot_positional_encoding(pe, seq_len=20, d_model=64):
"""Visualize positional encoding with heatmap and line plots
Args:
pe: Positional encoding matrix (shape: [max_seq_len, d_model])
seq_len: Number of positions to visualize
d_model: Dimension of the model
"""
plt.figure(figsize=(12, 8))
# 1. Heatmap (Position vs Dimension)
plt.subplot(2, 1, 1)
sns.heatmap(pe[:seq_len, :].numpy(),
cmap='RdBu', # Red-Blue colormap for positive/negative values
center=0, # White at zero
xticklabels=10, # Show every 10th dimension label
yticklabels=5) # Show every 5th position label
plt.title('Positional Encoding Heatmap (Position vs Dimension)')
plt.xlabel('Dimension Index')
plt.ylabel('Position Index')
# 2. Line Plots for First Few Dimensions
plt.subplot(2, 1, 2)
for i in range(6): # Plot first 6 dimensions
plt.plot(pe[:seq_len, i].numpy(),
label=f'Dim {i}',
linewidth=2)
plt.title('Positional Encoding Values Across Positions')
plt.xlabel('Position')
plt.ylabel('Encoding Value')
plt.legend()
plt.tight_layout()
plt.show()
# 生成并绘制
pe = sinusoidal_positional_encoding(seq_len=100, d_model=64)
plot_positional_encoding(pe, seq_len=100, d_model=64)
3 Figures

在频率基数=10000的情况下,低维度和高维度均有所差异

而基数为1的情况下,会出现很多位置编码相似的情况

我把维度设置的高了一些,可以看到,从左到右红色(波谷)和蓝色(波峰)的距离越来越远了,因为波长越来越大了。
4
位置编码除了能够编码绝对位置外(每一个位置有一串独一无二的位置编码),还能够编码token的相对位置,即能够通过下面公式的变换将PE(t+deltat) = Tdeltat * PE(t)
有能力的客官可以自己推导,通过正余弦公式和线性变换等得来(我也不懂,😂,数学太差)

5 强力推荐阅读
https://zhuanlan.zhihu.com/p/454482273
是个大佬哈哈哈哈哈哈哈,膜拜