2026-01-25

从零理解图注意力网络(GAT):理论推导与 PyTorch 实现

最近在看论文需要用到图神经网络,学习过程干脆就把算法推导和代码实现记录下来,下面连接是我理解图神经网络GAT的过程链接地址
图卷积网络(GCN)通过归一化邻接矩阵实现了高效的邻居信息聚合,但其聚合权重完全由图结构(度数)决定,缺乏灵活性。图注意力网络(GAT)则引入可学习的注意力机制,让模型动态决定“应该关注哪些邻居”,从而在异配图、噪声边等场景下表现更优。

本文将:

  • 从数学原理出发,手算一个 4 节点小图的 GAT 注意力过程;
  • 对比原始 GAT 的拼接注意力与 Transformer 风格的点积注意力
  • 在 Cora 数据集上进行公平实验,分析性能差异;
  • 提供两种注意力机制的完整 PyTorch 实现。

一、GAT 理论笔记

1. 基本符号与数据示例

节点特征矩阵 X

X = \begin{pmatrix} 18 & 85 \\ 19 & 92 \\ 20 & 78 \\ 18 & 88 \end{pmatrix}

其中第一列是年龄,第二列是分数。

邻接矩阵 A

A = \begin{pmatrix} 0 & 1 & 0 & 1 \\ 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 0 \\ 1 & 0 & 0 & 0 \end{pmatrix}

A_{ij} = 1表示节点 ij 相连(例如是朋友)。
另外约定如果h 是向量,h_{i}表示向量h的第i个元素,如果h 是矩阵,h_{i}表示向量h的第i行往往表示第i个样本或者第i个节点

💡 注意:数据与 GCN 示例略有不同,单纯是草稿本上誊抄的时候写错了。


2. GAT 核心思想:可学习的邻居权重

GCN 的聚合权重为 \frac{1}{\sqrt{d_i d_j}} ,仅由度数决定。
而 GAT 认为:即使两个邻居度数相同,若特征与中心节点更“匹配”,就应赋予更高权重

为此,GAT 引入注意力机制,动态计算每条边的重要性。


3. GAT核心公式分步拆解

Step1. 进行线性变换

首先对输入的节点特征矩阵进行线性变换:
h^{l+1} = W h^{l}
其中h^{l}是第l层输入,W是可以学习的权重参数。

Step2. 计算邻居间的注意力系数

然后拼接并用可学习向量 \mathbf{a} 打分:

e_{ij} = LeakyRelu(\mathbf{a}^\top [Wh_i^{l} \| Wh_j^{l}])
其中\mathbf{a}^\top是可以学习的长度为2f的参数向量, \|表示把节点i和节点j的特征向量拼接到一起,e_{ij}表示节点i对节点j的注意力系数

下面是手动计算,我们只计算第一层,当l=1h^1=X,为了方便手工计算假设W=I,即使Wh=X,且\mathbf{a}^\top=(0.3,0.2,0.3,0.2)

(1)先给邻间矩阵A添加自环得到\tilde{A}
\tilde{A} =A+I= \begin{pmatrix} 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 0 \\ 1 & 0 & 0 & 1 \end{pmatrix}
(2)由于节点1和自身是连接的,把他们拼接到一起得到
v_{11} = (18, 85,18,85)
节点1和节点2,节点1和节点4也是好朋友,同样拼接得到
v_{12} = (18, 85,19,92)\\ v_{14} = (18,85, 18,88)
(3)计算节点1和其邻居的注意力系数
e_{11} = 0.3 \times 18 + 0.2 \times 85 + 0.3 \times 18 + 0.2 \times 85 = 44.8 \\ e_{13} = 0.3 \times 18 + 0.2 \times 85 + 0.3 \times 19 + 0.2 \times 92 = 46.5 \\ e_{14} = 0.3 \times 18 + 0.2 \times 85 + 0.3 \times 18 + 0.2 \times 88 = 45.4

Step3. 计算注意力得分公式为

\alpha_{ij}=\frac{exp(e_{ij})}{\sum_{k}exp(e_{ik})}

接着上面的例子

\alpha_{11}=\frac{exp(44.8)}{exp(44.8)+exp(46.5)+exp(45.4)}=0.12
\alpha_{12}=\frac{exp(46.5)}{exp(44.8)+exp(46.5)+exp(45.4)}=0.66
\alpha_{14}=\frac{exp(45.4)}{exp(44.8)+exp(46.5)+exp(45.4)}=0.22

Step4. 加权聚合其公式为

h^{new} = \sigma(\sum_{j}\alpha_{ij}Wh_j)
其中\sigma是非线性激活函数如ELU,节点1聚合后的特征为

h_1^{\text{new}} = \alpha_{11} \cdot X_1 + \alpha_{12} \cdot X_2 + \alpha_{14} \cdot X_4\\ =0.12 \cdot (18,85) + 0.66 \cdot (19,92) + 0.22 \cdot (18,88)=(18.66,90.28)

注意:之所以用的是X_i是为因为我们前面说了为了方面手工计算假设的Wh=X

类似的可以得到其余三个节点的最终输出h_2^{\text{new}},h_3^{\text{new}},h_4^{\text{new}},则最终的输出
h^{\text{new}} =\begin{pmatrix} h_1^{\text{new}} \\ h_2^{\text{new}} \\ h_3^{\text{new}} \\ h_4^{\text{new}} \end{pmatrix}
至此GAT前向更新过程就全部结束,接下来是代码部分

关键优势:权重由特征决定,而非图结构!


二. 矩阵化推导

在写代码的时候计算注意力系数e_{ij}的时候不要用两个for循环去扫,效率太低,跟上一篇GCN一样可以矩阵化实现,先从数学推导开始
设节点1和2的特征分别为h_1=(x_{11},x_{12}),h_2=(x_{21},x_{22}),\mathbf{a}^\top=(a_1,a_2,a_3,a_4)
那么节点1和节点2的拼接注意力系数为
e_{12} = \textcolor{red}{(a_1 x_{11} + a_2 x_{12})} + \textcolor{blue}{(a_3 x_{21} + a_4 x_{22})}
注意到红色部分只与h_1即源节点有关,蓝色部分只与h_2即目标节点有关那e_{12}不就等于e_{1}+e_{2}
\mathbf{a}^\top拆分成两个部分
\mathbf{a}^\top_{left}=(a_1,a_2)\\ \mathbf{a}^\top_{right}=(a_3,a_4)
分别对应代码self.a_left = nn.Parameter(torch.empty(out_features))self.a_right = nn.Parameter(torch.empty(out_features))

那么用整个节点特征矩阵h分别乘以\mathbf{a}^\top_{left}\mathbf{a}^\top_{right}可以得到所有源节点系数和目标节点系数
e_{source }=h\cdot\mathbf{a}^\top_{left}\\ e_{target}=h\cdot\mathbf{a}^\top_{right}
分别对应代码source = Wh @ self.a_left.unsqueeze(1)targrt= Wh @ self.a_right.unsqueeze(1)
令矩阵E=e_{source }\cdot1^\top+1\cdot e_{target}^\top对应代码attention_logits = source + target.T
这样e_{ij}=E[i,j],然后用邻间矩阵作为掩码矩阵就可以一个循环都不用

三. 为什么不用点乘注意力呢?

1. 点积注意力(Transformer 风格)

在学习GAT的时候我就在想为什么论文作者不像《Attention is All You Need》直接做样本间的注意力而要用拼接注意力呢?于是我按照Transformer 的风格写了一版(就叫点乘注意力好了)在cora上做了一下对比实验以深度理解原始 GAT 论文(Veličković et al., ICLR 2018)的动机。

按照《Attention is All You Need》我是这样设计Transformer 风格的注意力机制的:
e_{ij} = (W_q h_i)^\top (W_k h_j)
即:

  • 查询(Query): Q = W_q X
  • 键(Key): K = W_k X
  • 得分: E = Q K^\top

再经 LeakyReLU + softmax 得到注意力权重。

疑问:既然 Transformer 成功了,为何 GAT 不直接用点积?


2. Cora 实验:拼接 vs 点积注意力

先把实验结果放前面再来分析其原因
我们在标准引文网络 Cora 上训练两种 GAT 变体(两层,统一使用邻接矩阵表示,公平对比):

实验设置

  • 输入维度:1433(词袋)
  • 隐藏维度:8
  • Dropout: 0.6
  • 优化器:Adam, lr=0.005
  • 训练/验证/测试划分:标准 140/500/1000

结果对比

指标 拼接注意力(GAT 原版) 点积注意力
最佳验证准确率 0.8985 0.8875
测试准确率 0.8672 0.8616
参数数量 11,550 46,080
训练时间(秒) 53.56 135.58
收敛轮次 149 186

分析

  1. 性能:拼接注意力略优(+1.6% 测试准确率),说明其更适合图结构;
  2. 效率:点积版本参数更多(需独立 W_q, W_k, W_v ),训练慢接近 3 倍;
  3. 收敛:拼接注意力更快达到最优,优化更稳定。

💡 结论:尽管点积在 NLP 中成功,但在局部邻居聚合任务中,显式建模节点对交互的拼接机制更有效

从结果反推为什么拼接注意力(GAT 原版)仅用\frac{1}{4}的参数无论是在准确率还是训练时间都完爆点积注意力呢?

3. 实验结果分析

对比项 拼接注意力 点积注意力
交互建模 显式建模 (i,j) 对的联合特征 隐式依赖相似性
对称性 天然支持非对称(可通过 \mathbf{a}_{\text{left}}, \mathbf{a}_{\text{right}} 点积对称:e_{ij} = e_{ji}
表达能力 更强(双线性+偏置) 较弱(仅内积)
归纳偏置 “任意关系都可学” “相似即重要”

📌 在图中,“相似即重要”不一定成立。例如:预测“导师-学生”关系时,差异大的节点反而更相关。


四、注意力可视化

Figure 2026-01-25 001846.png

随机选择6个节点,把训练后GAT的注意力得分用三种方法可视化

1. 注意力分布条形图(前6个子图)

展示每个选定节点对其邻居的注意力权重分布:

  • Y轴:邻居节点(N0、N1等)
  • X轴:注意力权重值(0-1)
  • 颜色:表示邻居节点的真实类别
  • 观察:可以看出不同节点关注不同的邻居模式

2. 注意力矩阵热力图(左下角第7个子图)

展示选定节点间的相互注意力:

  • 颜色深浅:表示注意力权重大小(颜色越深权重越高)
  • 对角线:自注意力(节点对自身的关注)
  • 非对角线:节点间的相互注意力(由于GAT只关注邻居,非邻居位置为0)
  • 重要发现:热力图大部分为0,这是因为GAT的局部性约束——只关注直接邻居

3. 图结构可视化(右下角第8个子图)

直观展示节点连接关系和注意力强度:

  • 节点大小:中心节点(选定)较大
  • 节点颜色:表示节点类别
  • 边粗细:边越粗表示注意力权重越高
  • 边颜色:统一为灰色
  • 观察:可以直观看出哪些连接对模型更重要

五、GraphAttentionLayer 完整代码实现

1. 拼接注意力(原始 GAT)

💡注意:无论是什么图网络,成熟库的实现都是用的边索引edge_index而不是邻间矩阵adj_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, negative_slope=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.negative_slope = negative_slope

        # 线性变换
        self.linear = nn.Linear(in_features, out_features, bias=False)
        # 注意力向量 
        self.a_left = nn.Parameter(torch.empty(out_features))
        self.a_right = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.xavier_uniform_(self.a_left.unsqueeze(0))
        nn.init.xavier_uniform_(self.a_right.unsqueeze(0))

    def forward(self, x, adj, add_self_loop=True):
        N = x.size(0)
        device = x.device
        # 添加自环
        if add_self_loop:
            adj = adj + torch.eye(N, device=device)
        # 线性变换
        Wh = self.linear(x)  # [N, F']
        # 计算注意力 
        source = Wh @ self.a_left.unsqueeze(1)      # [N, 1]
        target = Wh @ self.a_right.unsqueeze(1)     # [N, 1]
        attention_logits = source + target.T        # [N, N]
        # LeakyReLU
        attention_logits = F.leaky_relu(attention_logits, negative_slope=self.negative_slope)
        # 掩码:只保留邻居
        mask = -1e9 * (1 - adj)
        attention_logits = attention_logits + mask
        # Softmax 归一化
        attention = F.softmax(attention_logits, dim=1)  # [N, N]
        # 聚合
        h_prime = attention @ Wh  # [N, F']
        # 激活
        h_prime = F.elu(h_prime)
        return h_prime

2. 点乘注意力(Transformer风格)

class GraphDotAttentionLayer(nn.Module):
    """点积注意力:基于邻接矩阵"""
    def __init__(self, in_features, out_features, negative_slope=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.negative_slope = negative_slope
        # 投影层(Q/K/V 共享或分离)
        self.q_proj = nn.Linear(in_features, out_features, bias=False)
        self.k_proj = nn.Linear(in_features, out_features, bias=False)
        self.v_proj = nn.Linear(in_features, out_features, bias=False)
        self.linear_out = nn.Linear(in_features, out_features, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        for layer in [self.q_proj, self.k_proj, self.v_proj, self.linear_out]:
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x, adj):
        N = x.size(0)
        
        Q = self.q_proj(x)  # [N, out_features]
        K = self.k_proj(x)  # [N, out_features]
        scores = (Q @ K.T)/x.size(1)  # [N, N]
        scores = F.leaky_relu(scores, negative_slope=self.negative_slope)      
        mask = -1e9 * (1 - adj)
        scores = scores + mask
        attention = F.softmax(scores, dim=1)  # [N, N]
        aggregated = attention @ x  # [N, in_features]
        out = self.linear_out(aggregated + x)  # [N, out_features]
        out = F.elu(out)
        
        return out 
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容