pytorch geometric 自定义数据集

相对tensorflow(1.0), pytorch确实要更容易使用。由于课题和图神经网络相关,最近也在学习使用一些图深度建模的工具,比如tensorflow的Deep Graph Library以及pytorch的 pytorch geometric. 因为还在学习tensorflow 2.0,目前以pytorch geometric为主

我在网上找到一篇介绍这个的博客,写得很不错: https://towardsdatascience.com/hands-on-graph-neural-networks-with-pytorch-pytorch-geometric-359487e221a8

我尝试自己总结一些关键点和困扰一段时间的问题, 并且写了一个简单的example, 在我之前出错的一些点上有注释,用于参考和复习:https://github.com/GQ93/Pytorch-geometric-notes

图数据结构

图说起来也很简单,就是两个核心点,一个是图节点(nodes/vertics),一个是边(edges/links)表示节点之间的连接关系,基本概念可以参考wiki的界面https://en.wikipedia.org/wiki/Graph_theory

总体而言,图可以是不规整的(irregular),对比而言,平时我们看到的图片都是规整的(regular),可以表示成矩阵或者向量。问题在于设计数据结构如何储存图,一般有两种方案

一. 矩阵表示

参考: https://en.wikipedia.org/wiki/Graph_(abstract_data_type)#Representations

又细分成两种

(1) 用邻接矩阵(adjacency matrix:https://en.wikipedia.org/wiki/Adjacency_matrix),度矩阵(degree matrix:https://en.wikipedia.org/wiki/Degree_matrix), 拉普拉斯矩阵(Laplacian matrix:https://en.wikipedia.org/wiki/Laplacian_matrix)去表示, 衍生出各种对拉普拉斯矩阵的操作,比如图傅里叶变换(graph fourier transform), 也有稀疏邻接矩阵

(2) 用关联矩阵(incidence matrix:https://en.wikipedia.org/wiki/Incidence_matrix)表示,行表示节点,列表示边, 和这个相关的例如:超图(hypergraph)

这(两)种方式的缺点在于使用内存大, 矩阵维度和节点数目N挂钩。但是图的连接常常是稀疏的(sparse),也就是邻接矩阵中很多元素都是0(两个node没有连接关系),这些0元素会占据大量存储空间,使效率很低下。尤其是大型网络图,都不会把图完整的表示成一个矩阵。PS:吐槽一下,学数学的倒是特别喜欢


二. 邻接表

邻接表(Adjacency listhttps://en.wikipedia.org/wiki/Adjacency_list),也是数据领域常用的存储图方式,比如将边表示成节点对,成为一个2*N_edges的matrix,第一行表示source node, 第二行表示target node。这样的好处在于可以只储存有边存在的,对稀疏结构友好。总体而言,如果图是dense的,可以考虑矩阵表示,如果是稀疏的,最好使用稀疏邻接矩阵或邻接表


Pytorch Geometric 

一. torch_geometric.data.Data

pytorch Geometric Data使用邻接表去表示图,同时也表示了node特征x, 边属性edge_attr等, 需要注意的是, Data只表示一张图(single graph)

Data作为一个数据结构,需要填充几个属性

Data(x=Noneedge_index=Noneedge_attr=Noney=None)

x: 表示节点特征,可选,shape: [num_nodes, num_node_features] 有的图只有结构没有节点特征

edge_index: 表示边,也就是邻接表, shape: [2, num_edges] 

注意因为能表示有向图, 对于无向图,一条边要存入两次,也就是位于节点1和节点2的边,需要写成[[1,2][2,1]]而不能只写入[[1],[2]]; node的编号和edge要对应,也就是 max_num_edges = num_nodes*num_nodes 而不是num_nodes*num_nodes /2

edge_attr: 表示边属性(e.g. , 权重,类型),shape: [num_edges, num_edge_features] 

y: 是label,官方文档中说  Graph or node targets with arbitrary shape,所以shape可以是[num_nodes, nodes_label_dimension],或者是[graph_label_dimesnion]

二. 构建Dataset

pytorch geometric 构建数据集分两种, 一种继承InMemoryDataset,一次性加载所有数据到内存;另一种继承Dataset, 分次加载到内存

A. 继承InMemoryDataset

import torch

from torch_geometric.data import InMemoryDataset

class MyOwnDataset(InMemoryDataset):

    def __init__(self, root, transform=None, pre_transform=None):

        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property

    def raw_file_names(self):

        return ['some_file_1', 'some_file_2', ...]

    @property

    def processed_file_names(self):

        return ['data.pt']

    def download(self):

        # Download to `self.raw_dir`.

    def process(self):

        # Read data into huge `Data` list.

        data_list = [...]

        if self.pre_filter is not None:

            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:

            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        torch.save((data, slices), self.processed_paths[0])

几个关键点容易被忽视

1. 如果需要在initial里面初始化一些参数,如定义mask,需要在super前继承参数,不然会失败无法传递到子函数里面,原因我也不清楚。 举例:


继承参数

self.num_train_per_class 要卸载super(NodeDatasetInMem, self)这一行前面

2. 我们主要需要编辑def processed_file_names(self) 和 def process(self), 

processed_file_names只需要申明把处理好的dataset存在哪里(路径加文件名)

process就是写一个函数,处理数据成torch_geometric.data.Data的形式,如果是图分类,还需要把多个图存成一个list

要注意x一般是float tensor, y 是 long tensor, mask 是boolean tensor, edge_index是long tensor

而且当y是graph label时, 不能是0-dimension tensor, 也就是说

y = torch.tensor(0, dtype=torch.long)

会报错, 要写成

y = torch.tensor([0], dtype=torch.long) 

3. 其余函数作用

data, slices = self.collate(data_list)

torch.save((data, slices), self.processed_paths[0])

这个是官方代码里面的,作用就是通过self.collate把数据划分成不同slices去保存读取 (大数据块切成小块)

所以即使只有一个graph写成了data, 在调用self.collate时,也要写成list:

data, slices = self.collate([data])

B. 继承Dataset

直接继承torch_geometric.data.Dataset,除了和InMemoryDataset相似的函数以外,需要多写两个函数

torch_geometric.data.Dataset.len():

因为Dataset相对于InMemoryDataset,不会一次加载所有函数,而是分批,所有会把数据保存成好几个小数据包(.pt 文件),len() 就是说明有几个数据包,官方的指导:

def len(self):

        return len(self.processed_file_names)

可以完全照搬,只需要改变processed_file_names的返回值,例如


有几个数据包就写几个数据名

还有一个get() 函数

torch_geometric.data.Dataset.get():

这个函数需要返回值时一个data,single graph: Implements the logic to load a single graph

def get(self, idx):

        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))

        return data

注意, 这里的load里面的函数名要和processed_file_name()返回的函数名一致, idx就是数据包的遍历下标

几个容易出问题的地方

1. 继承InMemoryDataset时,在super继承之后,有一个读取数据的命令


load数据

由于继承Dataset, 有get函数load数据,所以写继承Dataset时不需要这条命令,否则会报错

2. 不再调用self.collate() 去划分数据包, 也就没有data_list. 直接把一个个小数据包按照下标储存就好


不再调用self.collate(),直接torch.save



以后看情况补足raw_file_names()和download()相关,不过本地数据可以不需要填充这两个函数

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,132评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,802评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,566评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,858评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,867评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,695评论 1 282
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,064评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,705评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,915评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,677评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,796评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,432评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,041评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,992评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,223评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,185评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,535评论 2 343