GAT(Graph Attention Network)
GitHub项目(GAT[keras版] GAT[pytotch版] GAT[tensorflow版])
该项目做的任务仍是图中节点分类问题,语料仍是Cora
1.下载代码,并上传到服务器解压
unzip pyGAT-master.zip
2.选择或安装运行该程序需要的环境
pyGAT relies on Python 3.5 and PyTorch 0.4.1 (due to torch.sparse_coo_tensor).
激活环境 source activate pt_env
3.进入pyGAT-master目录,运行:Python main.py
以上操作,运行成功!!!
开始代码解剖
1.超参设置
2.加载数据
idx_features_labels [0]是节点id [1-1433]是节点的one-hot特征向量 [1434]是节点的label标签。 这个数据是从文件data/cora/cora.content文件中读出来的。
将刚才加载的idx_features_labels数据,取出features部分,用稀疏矩阵的形式存储;取出labels部分,转换成one-hot多分类向量。
从data/cora/cora.cites里读入数据,构建整个大图的邻接矩阵。
cora.cites里的数据格式如图,点对形式
3.搭建GAT模型
GAT(Graph Attention Network)
GAT整个模型,初始有8个注意力层
GraphAttentionLayer层代码
模型训练,输入数据转换过程,数据形状