GAT(GRAPH ATTENTION NETWORKS)是一种使用了self attention机制图神经网络,该网络使用类似transformer里面self attention的方式计算图里面某个节点相对于每个邻接节点的注意力,将节点本身的特征和注意力特征concate起来作为该节点的特征,在此基础上进行节点的分类等任务。
下面是transformer self attention原理图:
GAT使用了类似的流程计算节点的self attention,首先计算当前节点和每个邻接节点的注意力score,然后使用该score乘以每个节点的特征,累加起来并经过一个非线性映射,作为当前节点的特征。
Attention score公式表示如下:
这里使用W矩阵将原始的特征映射到一个新的空间,a代表self attention的计算,如前面图2所示,这样计算出两个邻接节点的attention score,也就是Eij,然后对所有邻接节点的score进行softmax处理,得到归一化的attention score。
代码可以参考这个实现:https://github.com/gordicaleksa/pytorch-GAT
核心代码:
def forward(self, data):
in_nodes_features, connectivity_mask = data
num_of_nodes = in_nodes_features.shape[0]
in_nodes_features = self.dropout(in_nodes_features)
# V
nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)
nodes_features_proj = self.dropout(nodes_features_proj)
# Q、K
scores_source = torch.sum((nodes_features_proj * self.scoring_fn_source), dim=-1, keepdim=True)
scores_target = torch.sum((nodes_features_proj * self.scoring_fn_target), dim=-1, keepdim=True)
scores_source = scores_source.transpose(0, 1)
scores_target = scores_target.permute(1, 2, 0)
# Q * K
all_scores = self.leakyReLU(scores_source + scores_target)
all_attention_coefficients = self.softmax(all_scores + connectivity_mask)
# Q * K * V
out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj.transpose(0, 1))
out_nodes_features = out_nodes_features.permute(1, 0, 2)
# in_nodes_features + out_nodes_features(attention)
out_nodes_features = self.skip_concat_bias(all_attention_coefficients, in_nodes_features, out_nodes_features)
return (out_nodes_features, connectivity_mask)
该GAT的实现也包含在了PYG库中,这个库涵盖了各种常见的图神经网络方面的论文算法实现。