从零理解图注意力网络(GAT):理论推导与 PyTorch 实现
最近在看论文需要用到图神经网络,学习过程干脆就把算法推导和代码实现记录下来,下面连接是我理解图神经网络GAT的过程链接地址
图卷积网络(GCN)通过归一化邻接矩阵实现了高效的邻居信息聚合,但其聚合权重完全由图结构(度数)决定,缺乏灵活性。图注意力网络(GAT)则引入可学习的注意力机制,让模型动态决定“应该关注哪些邻居”,从而在异配图、噪声边等场景下表现更优。
本文将:
- 从数学原理出发,手算一个 4 节点小图的 GAT 注意力过程;
- 对比原始 GAT 的拼接注意力与 Transformer 风格的点积注意力;
- 在 Cora 数据集上进行公平实验,分析性能差异;
- 提供两种注意力机制的完整 PyTorch 实现。
一、GAT 理论笔记
1. 基本符号与数据示例
节点特征矩阵
:
其中第一列是年龄,第二列是分数。
邻接矩阵
:
表示节点
与
相连(例如是朋友)。
另外约定如果 是向量,
表示向量
的第
个元素,如果
是矩阵,
表示向量
的第
行往往表示第
个样本或者第
个节点
💡 注意:数据与 GCN 示例略有不同,单纯是草稿本上誊抄的时候写错了。
2. GAT 核心思想:可学习的邻居权重
GCN 的聚合权重为 ,仅由度数决定。
而 GAT 认为:即使两个邻居度数相同,若特征与中心节点更“匹配”,就应赋予更高权重。
为此,GAT 引入注意力机制,动态计算每条边的重要性。
3. GAT核心公式分步拆解
Step1. 进行线性变换
首先对输入的节点特征矩阵进行线性变换:
其中是第
层输入,
是可以学习的权重参数。
Step2. 计算邻居间的注意力系数
然后拼接并用可学习向量 打分:
其中是可以学习的长度为2
的参数向量,
表示把节点
和节点
的特征向量拼接到一起,
表示节点
对节点
的注意力系数
下面是手动计算,我们只计算第一层,当
时
,为了方便手工计算假设
,即使
,且
(1)先给邻间矩阵A添加自环得到
(2)由于节点1和自身是连接的,把他们拼接到一起得到
节点1和节点2,节点1和节点4也是好朋友,同样拼接得到
(3)计算节点1和其邻居的注意力系数
Step3. 计算注意力得分公式为
接着上面的例子
Step4. 加权聚合其公式为
其中是非线性激活函数如
,节点1聚合后的特征为
注意:之所以用的是
是为因为我们前面说了为了方面手工计算假设的
类似的可以得到其余三个节点的最终输出,则最终的输出
至此GAT前向更新过程就全部结束,接下来是代码部分
✅ 关键优势:权重由特征决定,而非图结构!
二. 矩阵化推导
在写代码的时候计算注意力系数的时候不要用两个for循环去扫,效率太低,跟上一篇GCN一样可以矩阵化实现,先从数学推导开始
设节点1和2的特征分别为
那么节点1和节点2的拼接注意力系数为
注意到红色部分只与即源节点有关,蓝色部分只与
即目标节点有关那
不就等于
吗
把拆分成两个部分
分别对应代码self.a_left = nn.Parameter(torch.empty(out_features))和self.a_right = nn.Parameter(torch.empty(out_features))
那么用整个节点特征矩阵分别乘以
和
可以得到所有源节点系数和目标节点系数
分别对应代码source = Wh @ self.a_left.unsqueeze(1)和targrt= Wh @ self.a_right.unsqueeze(1)
令矩阵对应代码
attention_logits = source + target.T
这样,然后用邻间矩阵作为掩码矩阵就可以一个循环都不用
三. 为什么不用点乘注意力呢?
1. 点积注意力(Transformer 风格)
在学习GAT的时候我就在想为什么论文作者不像《Attention is All You Need》直接做样本间的注意力而要用拼接注意力呢?于是我按照Transformer 的风格写了一版(就叫点乘注意力好了)在cora上做了一下对比实验以深度理解原始 GAT 论文(Veličković et al., ICLR 2018)的动机。
按照《Attention is All You Need》我是这样设计Transformer 风格的注意力机制的:
即:
- 查询(Query):
- 键(Key):
- 得分:
再经 LeakyReLU + softmax 得到注意力权重。
❓ 疑问:既然 Transformer 成功了,为何 GAT 不直接用点积?
2. Cora 实验:拼接 vs 点积注意力
先把实验结果放前面再来分析其原因
我们在标准引文网络 Cora 上训练两种 GAT 变体(两层,统一使用邻接矩阵表示,公平对比):
实验设置
- 输入维度:1433(词袋)
- 隐藏维度:8
- Dropout: 0.6
- 优化器:Adam, lr=0.005
- 训练/验证/测试划分:标准 140/500/1000
结果对比
| 指标 | 拼接注意力(GAT 原版) | 点积注意力 |
|---|---|---|
| 最佳验证准确率 | 0.8985 | 0.8875 |
| 测试准确率 | 0.8672 | 0.8616 |
| 参数数量 | 11,550 | 46,080 |
| 训练时间(秒) | 53.56 | 135.58 |
| 收敛轮次 | 149 | 186 |
分析
- 性能:拼接注意力略优(+1.6% 测试准确率),说明其更适合图结构;
-
效率:点积版本参数更多(需独立
),训练慢接近 3 倍;
- 收敛:拼接注意力更快达到最优,优化更稳定。
💡 结论:尽管点积在 NLP 中成功,但在局部邻居聚合任务中,显式建模节点对交互的拼接机制更有效。
从结果反推为什么拼接注意力(GAT 原版)仅用
的参数无论是在准确率还是训练时间都完爆点积注意力呢?
3. 实验结果分析
| 对比项 | 拼接注意力 | 点积注意力 |
|---|---|---|
| 交互建模 | 显式建模 |
隐式依赖相似性 |
| 对称性 | 天然支持非对称(可通过 |
点积对称: |
| 表达能力 | 更强(双线性+偏置) | 较弱(仅内积) |
| 归纳偏置 | “任意关系都可学” | “相似即重要” |
📌 在图中,“相似即重要”不一定成立。例如:预测“导师-学生”关系时,差异大的节点反而更相关。
四、注意力可视化

随机选择6个节点,把训练后GAT的注意力得分用三种方法可视化
1. 注意力分布条形图(前6个子图)
展示每个选定节点对其邻居的注意力权重分布:
- Y轴:邻居节点(N0、N1等)
- X轴:注意力权重值(0-1)
- 颜色:表示邻居节点的真实类别
- 观察:可以看出不同节点关注不同的邻居模式
2. 注意力矩阵热力图(左下角第7个子图)
展示选定节点间的相互注意力:
- 颜色深浅:表示注意力权重大小(颜色越深权重越高)
- 对角线:自注意力(节点对自身的关注)
- 非对角线:节点间的相互注意力(由于GAT只关注邻居,非邻居位置为0)
- 重要发现:热力图大部分为0,这是因为GAT的局部性约束——只关注直接邻居
3. 图结构可视化(右下角第8个子图)
直观展示节点连接关系和注意力强度:
- 节点大小:中心节点(选定)较大
- 节点颜色:表示节点类别
- 边粗细:边越粗表示注意力权重越高
- 边颜色:统一为灰色
- 观察:可以直观看出哪些连接对模型更重要
五、GraphAttentionLayer 完整代码实现
1. 拼接注意力(原始 GAT)
💡注意:无论是什么图网络,成熟库的实现都是用的边索引edge_index而不是邻间矩阵adj_matrix
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
def __init__(self, in_features, out_features, negative_slope=0.2):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.negative_slope = negative_slope
# 线性变换
self.linear = nn.Linear(in_features, out_features, bias=False)
# 注意力向量
self.a_left = nn.Parameter(torch.empty(out_features))
self.a_right = nn.Parameter(torch.empty(out_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.linear.weight)
nn.init.xavier_uniform_(self.a_left.unsqueeze(0))
nn.init.xavier_uniform_(self.a_right.unsqueeze(0))
def forward(self, x, adj, add_self_loop=True):
N = x.size(0)
device = x.device
# 添加自环
if add_self_loop:
adj = adj + torch.eye(N, device=device)
# 线性变换
Wh = self.linear(x) # [N, F']
# 计算注意力
source = Wh @ self.a_left.unsqueeze(1) # [N, 1]
target = Wh @ self.a_right.unsqueeze(1) # [N, 1]
attention_logits = source + target.T # [N, N]
# LeakyReLU
attention_logits = F.leaky_relu(attention_logits, negative_slope=self.negative_slope)
# 掩码:只保留邻居
mask = -1e9 * (1 - adj)
attention_logits = attention_logits + mask
# Softmax 归一化
attention = F.softmax(attention_logits, dim=1) # [N, N]
# 聚合
h_prime = attention @ Wh # [N, F']
# 激活
h_prime = F.elu(h_prime)
return h_prime
2. 点乘注意力(Transformer风格)
class GraphDotAttentionLayer(nn.Module):
"""点积注意力:基于邻接矩阵"""
def __init__(self, in_features, out_features, negative_slope=0.2):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.negative_slope = negative_slope
# 投影层(Q/K/V 共享或分离)
self.q_proj = nn.Linear(in_features, out_features, bias=False)
self.k_proj = nn.Linear(in_features, out_features, bias=False)
self.v_proj = nn.Linear(in_features, out_features, bias=False)
self.linear_out = nn.Linear(in_features, out_features, bias=False)
self.reset_parameters()
def reset_parameters(self):
for layer in [self.q_proj, self.k_proj, self.v_proj, self.linear_out]:
nn.init.xavier_uniform_(layer.weight)
def forward(self, x, adj):
N = x.size(0)
Q = self.q_proj(x) # [N, out_features]
K = self.k_proj(x) # [N, out_features]
scores = (Q @ K.T)/x.size(1) # [N, N]
scores = F.leaky_relu(scores, negative_slope=self.negative_slope)
mask = -1e9 * (1 - adj)
scores = scores + mask
attention = F.softmax(scores, dim=1) # [N, N]
aggregated = attention @ x # [N, in_features]
out = self.linear_out(aggregated + x) # [N, out_features]
out = F.elu(out)
return out