相对tensorflow(1.0), pytorch确实要更容易使用。由于课题和图神经网络相关,最近也在学习使用一些图深度建模的工具,比如tensorflow的Deep Graph Library以及pytorch的 pytorch geometric. 因为还在学习tensorflow 2.0,目前以pytorch geometric为主
我在网上找到一篇介绍这个的博客,写得很不错: https://towardsdatascience.com/hands-on-graph-neural-networks-with-pytorch-pytorch-geometric-359487e221a8
我尝试自己总结一些关键点和困扰一段时间的问题, 并且写了一个简单的example, 在我之前出错的一些点上有注释,用于参考和复习:https://github.com/GQ93/Pytorch-geometric-notes
图数据结构
图说起来也很简单,就是两个核心点,一个是图节点(nodes/vertics),一个是边(edges/links)表示节点之间的连接关系,基本概念可以参考wiki的界面https://en.wikipedia.org/wiki/Graph_theory
总体而言,图可以是不规整的(irregular),对比而言,平时我们看到的图片都是规整的(regular),可以表示成矩阵或者向量。问题在于设计数据结构如何储存图,一般有两种方案
一. 矩阵表示
参考: https://en.wikipedia.org/wiki/Graph_(abstract_data_type)#Representations
又细分成两种
(1) 用邻接矩阵(adjacency matrix:https://en.wikipedia.org/wiki/Adjacency_matrix),度矩阵(degree matrix:https://en.wikipedia.org/wiki/Degree_matrix), 拉普拉斯矩阵(Laplacian matrix:https://en.wikipedia.org/wiki/Laplacian_matrix)去表示, 衍生出各种对拉普拉斯矩阵的操作,比如图傅里叶变换(graph fourier transform), 也有稀疏邻接矩阵
(2) 用关联矩阵(incidence matrix:https://en.wikipedia.org/wiki/Incidence_matrix)表示,行表示节点,列表示边, 和这个相关的例如:超图(hypergraph)
这(两)种方式的缺点在于使用内存大, 矩阵维度和节点数目N挂钩。但是图的连接常常是稀疏的(sparse),也就是邻接矩阵中很多元素都是0(两个node没有连接关系),这些0元素会占据大量存储空间,使效率很低下。尤其是大型网络图,都不会把图完整的表示成一个矩阵。PS:吐槽一下,学数学的倒是特别喜欢
二. 邻接表
邻接表(Adjacency list:https://en.wikipedia.org/wiki/Adjacency_list),也是数据领域常用的存储图方式,比如将边表示成节点对,成为一个2*N_edges的matrix,第一行表示source node, 第二行表示target node。这样的好处在于可以只储存有边存在的,对稀疏结构友好。总体而言,如果图是dense的,可以考虑矩阵表示,如果是稀疏的,最好使用稀疏邻接矩阵或邻接表
Pytorch Geometric
一. torch_geometric.data.Data
pytorch Geometric Data使用邻接表去表示图,同时也表示了node特征x, 边属性edge_attr等, 需要注意的是, Data只表示一张图(single graph)
Data作为一个数据结构,需要填充几个属性
Data(x=None, edge_index=None, edge_attr=None, y=None)
x: 表示节点特征,可选,shape: [num_nodes, num_node_features] 有的图只有结构没有节点特征
edge_index: 表示边,也就是邻接表, shape: [2, num_edges]
注意,因为能表示有向图, 对于无向图,一条边要存入两次,也就是位于节点1和节点2的边,需要写成[[1,2][2,1]]而不能只写入[[1],[2]]; node的编号和edge要对应,也就是 max_num_edges = num_nodes*num_nodes 而不是num_nodes*num_nodes /2
edge_attr: 表示边属性(e.g. , 权重,类型),shape: [num_edges, num_edge_features]
y: 是label,官方文档中说 Graph or node targets with arbitrary shape,所以shape可以是[num_nodes, nodes_label_dimension],或者是[graph_label_dimesnion]
二. 构建Dataset
pytorch geometric 构建数据集分两种, 一种继承InMemoryDataset,一次性加载所有数据到内存;另一种继承Dataset, 分次加载到内存
A. 继承InMemoryDataset
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
几个关键点容易被忽视
1. 如果需要在initial里面初始化一些参数,如定义mask,需要在super前继承参数,不然会失败无法传递到子函数里面,原因我也不清楚。 举例:
self.num_train_per_class 要卸载super(NodeDatasetInMem, self)这一行前面
2. 我们主要需要编辑def processed_file_names(self) 和 def process(self),
processed_file_names只需要申明把处理好的dataset存在哪里(路径加文件名)
process就是写一个函数,处理数据成torch_geometric.data.Data的形式,如果是图分类,还需要把多个图存成一个list
要注意x一般是float tensor, y 是 long tensor, mask 是boolean tensor, edge_index是long tensor
而且当y是graph label时, 不能是0-dimension tensor, 也就是说
y = torch.tensor(0, dtype=torch.long)
会报错, 要写成
y = torch.tensor([0], dtype=torch.long)
3. 其余函数作用
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
这个是官方代码里面的,作用就是通过self.collate把数据划分成不同slices去保存读取 (大数据块切成小块)
所以即使只有一个graph写成了data, 在调用self.collate时,也要写成list:
data, slices = self.collate([data])
B. 继承Dataset
直接继承torch_geometric.data.Dataset,除了和InMemoryDataset相似的函数以外,需要多写两个函数
torch_geometric.data.Dataset.len():
因为Dataset相对于InMemoryDataset,不会一次加载所有函数,而是分批,所有会把数据保存成好几个小数据包(.pt 文件),len() 就是说明有几个数据包,官方的指导:
def len(self):
return len(self.processed_file_names)
可以完全照搬,只需要改变processed_file_names的返回值,例如
还有一个get() 函数
torch_geometric.data.Dataset.get():
这个函数需要返回值时一个data,single graph: Implements the logic to load a single graph
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
注意, 这里的load里面的函数名要和processed_file_name()返回的函数名一致, idx就是数据包的遍历下标
几个容易出问题的地方
1. 继承InMemoryDataset时,在super继承之后,有一个读取数据的命令
由于继承Dataset, 有get函数load数据,所以写继承Dataset时不需要这条命令,否则会报错
2. 不再调用self.collate() 去划分数据包, 也就没有data_list. 直接把一个个小数据包按照下标储存就好
以后看情况补足raw_file_names()和download()相关,不过本地数据可以不需要填充这两个函数