超大规模数据集类的创建

引言

本文是datawhale开源社区GNN组队学习的笔记,绝大部分内容出自其中。
torch_geometric.data.InMemoryDataset类的使用是让数据可全部储存于内存的数据集,这些数据集对应的数据集类在创建对象时就将所有数据都加载到内存。然而在一些应用场景中,数据集规模超级大,我们很难有足够大的内存完全存下所有数据。因此需要一个按需加载样本到内存的数据集类。

Dataset基类

在我们将学习为一个包含上千万个图样本的数据集构建一个数据集类。
在PyG中,我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。
继承torch_geometric.data.InMemoryDataset基类要实现的方法(raw_file_names(), processed_file_names(),download(),process()),继承此基类同样要实现,此外还需要实现以下方法:

  1. len():返回数据集中的样本的数量。
  2. get():实现加载单个图的操作。注意:在内部,getitem()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。

对于无需下载数据集原文件的情况,我们不重写download方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process方法即可跳过预处理。

合并小图组成大图

图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像和序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:


合并后的大图

此方法有以下关键的优势:

  • 依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
  • 没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。
    小图中的属性拼接
    将小图存储到大图中时需要对小图的属性做一些修改,一个最显著的例子就是要对节点序号增值。在最一般的形式中,PyTorch Geometric的DataLoader类会自动对edge_index张量增值,增加的值为当前被处理图的前面的图的累积节点数量。比方说,现在对第个图的edge_index张量做增值,前面个图的累积节点数量为,那么对第个图的edge_index张量的增值。增值后,对所有图的edge_index张量(其形状为[2, num_edges])在第二维中连接起来。

然而,有一些特殊的场景中(如下所述),基于需求我们希望能修改这一行为。PyTorch Geometric允许我们通过覆盖torch_geometric.data.__inc__()torch_geometric.data.__cat_dim__()函数来实现我们希望的行为。

from torch_geometric.data import Data, DataLoader
import torch

class PairData(Data):
    #将两个图,一个源图G_s和一个目标图G_t,存储在一个Data类中
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t
    
    #c重写__inc__()两个连续的图的属性之间的增量大小
    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)

# 定义边索引矩阵
edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
#5个节点,16个特征
x_s = torch.randn(5, 16)  

edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
#4个节点 16个特征
x_t = torch.randn(4, 16)

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]

loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))


print(batch)
#对应小图 未成功将bacth映射成小图
print(batch.edge_index_s)
print(batch.edge_index_t)
'''
Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_t=[8, 16])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
        [1, 2, 3, 4, 6, 7, 8, 9]])
tensor([[0, 0, 0, 4, 4, 4],
        [1, 2, 3, 5, 6, 7]])
'''

# 利用follow_batch属性
loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))

print(batch)
# Batch(edge_index_s=[2, 8], x_s=[10, 16], x_s_batch=[10], edge_index_t=[2, 6], x_t=[8, 16], x_t_batch=[8])
print(batch.x_s_batch)
# tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

print(batch.x_t_batch)
# tensor([0, 0, 0, 0, 1, 1, 1, 1])

二部图的增值
一般来说,不同类型的节点数量不需要一致,于是二部图的邻接矩阵A \in \{0,1\}^{N \times M}可能不是平方矩阵,即可能有N \neq M。对二部图的封装成批过程中,edge_index 中边的源节点与目标节点做的增值操作应是不同的。

class BipartiteData(Data):
    def __init__(self, edge_index, x_s, x_t):
        super(BipartiteData, self).__init__()
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t

def __inc__(self, key, value):
    if key == 'edge_index':
        #源、目标节点增值
        return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
    else:
        return super().__inc__(key, value)

edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# Batch(edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16])

#边的源节点增值为2,目标节点增值为3
print(batch.edge_index)
# tensor([[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 3, 4, 4, 5]])

超大图数据实践

import os
import os.path as osp

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil
from torch_geometric.data import DataLoader
from tqdm import tqdm

RDLogger.DisableLog('rdApp.*')

class MyPCQM4MDataset(Dataset):

    def __init__(self, root):
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        super(MyPCQM4MDataset, self).__init__(root)

        filepath = osp.join(root, 'raw/data.csv.gz')
        data_df = pd.read_csv(filepath)
        self.smiles_list = data_df['smiles']
        self.homolumogap_list = data_df['homolumogap']

    @property
    def raw_file_names(self):
        return 'data.csv.gz'

    def download(self):
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))

    def len(self):
        return len(self.smiles_list)

    def get(self, idx):
        smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
        graph = smiles2graph(smiles)
        assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert(len(graph['node_feat']) == graph['num_nodes'])

        x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        y = torch.Tensor([homolumogap])
        num_nodes = int(graph['num_nodes'])
        data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
        return data

    # 获取数据集划分
    def get_idx_split(self):
        split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
        return split_dict


dataset = MyPCQM4MDataset('dataset2')
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
for batch in tqdm(dataloader):
    pass
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容