GNN学习第9天

首先感谢datawhale 的GNN课程,非常精彩。
GNN/Markdown版本/6-1-数据完整存于内存的数据集类.md 

Task04 数据完整存储与内存的数据集类+节点预测与边预测任务实践

1 知识梳理

1.1 使用数据集的一般过程

从网络上下载数据原始文件;

对数据原始文件做处理,为每一个图样本生成一个**Data对象**;

对每一个Data对象执行数据处理,使其转换成新的Data对象;

过滤Data对象

保存Data对象到文件

获取Data对象,在每一次获取Data对象时,都先对Data对象做数据变换(于是获取到的是数据变换后的Data对象)。

1.2 边预测任务

思路:生成负样本,使得正负样本数量平衡

使用train_test_split_edges函数,采样得到负样本,并将正负样本分成训练集、验证集和测试集

2 实战练习

2.1 PlanetoidPubMed数据集类的构造 (CORA数据集训练滴)dataset = Planetoid(root='./tmp/cora', name='Cora')

dataset = Planetoid(root='./tmp/cora', name='Cora')

print('数据类别个数:', dataset.num_classes)

print('节点数:', dataset[0].num_nodes)

print('边数:', dataset[0].num_edges)

print('节点特征维度:', dataset[0].num_features)

importos.pathasospimporttorchfromtorch_geometric.dataimport(InMemoryDataset,download_url)fromtorch_geometric.ioimportread_planetoid_dataclassPlanetoidPubMed(InMemoryDataset):r""" 节点代表文章,边代表引文关系。

                训练、验证和测试的划分通过二进制掩码给出。

    参数:

        root (string): 存储数据集的文件夹的路径

        transform (callable, optional): 数据转换函数,每一次获取数据时被调用。

        pre_transform (callable, optional): 数据转换函数,数据保存到文件前被调用。

    """#    url = 'https://github.com/kimiyoung/planetoid/raw/master/data'url='https://gitee.com/jiajiewu/planetoid/raw/master/data'def__init__(self,root,transform=None,pre_transform=None):super(PlanetoidPubMed,self).__init__(root,transform,pre_transform)self.data,self.slices=torch.load(self.processed_paths[0])@propertydefraw_dir(self):returnosp.join(self.root,'raw')@propertydefprocessed_dir(self):returnosp.join(self.root,'processed')@propertydefraw_file_names(self):names=['x','tx','allx','y','ty','ally','graph','test.index']return['ind.pubmed.{}'.format(name)fornameinnames]@propertydefprocessed_file_names(self):return'data.pt'defdownload(self):fornameinself.raw_file_names:download_url('{}/{}'.format(self.url,name),self.raw_dir)defprocess(self):data=read_planetoid_data(self.raw_dir,'pubmed')data=dataifself.pre_transformisNoneelseself.pre_transform(data)torch.save(self.collate([data]),self.processed_paths[0])def__repr__(self):return'{}()'.format(self.name)Copy to clipboardErrorCopied

程序运行流程:

检查数据原始文件是否已经下载

检查数据是否经过处理:检查数据变换的方法、检查样本过滤的方法、检查是否处理好数据

dataset=PlanetoidPubMed('dataset/PlanetoidPubMed')print('数据类别个数:',dataset.num_classes)print('节点数:',dataset[0].num_nodes)print('边数:',dataset[0].num_edges)print('节点特征维度:',dataset[0].num_features)Copy to clipboardErrorCopied

数据类别个数: 3

节点数: 19717

边数: 88648

节点特征维度: 500Copy to clipboardErrorCopied

2.2 使用GAT图神经网络进行节点预测

fromtorch_geometric.nnimportGATConv,Sequentialfromtorch.nnimportLinear,ReLUimporttorch.nn.functionalasFclassGAT(torch.nn.Module):def__init__(self,num_features,hidden_channels_list,num_classes):super(GAT,self).__init__()torch.manual_seed(12345)hns=[num_features]+hidden_channels_list        conv_list=[]foridxinrange(len(hidden_channels_list)):conv_list.append((GATConv(hns[idx],hns[idx+1]),'x, edge_index -> x'))conv_list.append(ReLU(inplace=True),)self.convseq=Sequential('x, edge_index',conv_list)self.linear=Linear(hidden_channels_list[-1],num_classes)defforward(self,x,edge_index):x=self.convseq(x,edge_index)x=F.dropout(x,p=0.5,training=self.training)x=self.linear(x)returnxCopy to clipboardErrorCopied

deftrain():model.train()optimizer.zero_grad()# Clear gradients.out=model(data.x,data.edge_index)# Perform a single forward pass.# Compute the loss solely based on the training nodes.loss=criterion(out[data.train_mask],data.y[data.train_mask])loss.backward()# Derive gradients.optimizer.step()# Update parameters based on gradients.returnlossdeftest():model.eval()out=model(data.x,data.edge_index)pred=out.argmax(dim=1)# Use the class with highest probability.test_correct=pred[data.test_mask]==data.y[data.test_mask]# Check against ground-truth labels.test_acc=int(test_correct.sum())/int(data.test_mask.sum())# Derive ratio of correct predictions.returntest_accCopy to clipboardErrorCopied

importmatplotlib.pyplotaspltfromsklearn.manifoldimportTSNE%matplotlib inlinedefvisualize(h,color):z=TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())plt.figure(figsize=(10,10))plt.xticks([])plt.yticks([])plt.scatter(z[:,0],z[:,1],s=70,c=color.cpu(),cmap="Set2")plt.show()Copy to clipboardErrorCopied

fromtorch_geometric.transformsimportNormalizeFeaturesdataset=PlanetoidPubMed(root='dataset/PlanetoidPubMed/',transform=NormalizeFeatures())print('dataset.num_features:',dataset.num_features)device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')data=dataset[0].to(device)model=GAT(num_features=dataset.num_features,hidden_channels_list=[200,100],num_classes=dataset.num_classes).to(device)print(model)optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)criterion=torch.nn.CrossEntropyLoss()forepochinrange(1,201):loss=train()ifepoch%10==0:print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')test_acc=test()print(f'Test Accuracy: {test_acc:.4f}')model.eval()out=model(data.x,data.edge_index)visualize(out,color=data.y)Copy to clipboardErrorCopied

dataset.num_features: 500

GAT(

  (convseq): Sequential(

    (0): GATConv(500, 200, heads=1)

    (1): ReLU(inplace=True)

    (2): GATConv(200, 100, heads=1)

    (3): ReLU(inplace=True)

  )

  (linear): Linear(in_features=100, out_features=3, bias=True)

)

dataset.num_features: 1433

GAT(

  (convseq): Sequential(

    (0): GATConv(1433, 200, heads=1)

    (1): ReLU(inplace=True)

    (2): GATConv(200, 100, heads=1)

    (3): ReLU(inplace=True)

  )

  (linear): Linear(in_features=100, out_features=7, bias=True)

)

Epoch: 010, Loss: 1.7378

Epoch: 020, Loss: 0.7310

Epoch: 030, Loss: 0.2087

Epoch: 040, Loss: 0.0610

Epoch: 050, Loss: 0.0477

Epoch: 060, Loss: 0.0368

Epoch: 070, Loss: 0.0360

Epoch: 080, Loss: 0.0354

Epoch: 090, Loss: 0.0310

Epoch: 100, Loss: 0.0279

Epoch: 110, Loss: 0.0263

Epoch: 120, Loss: 0.0281

Epoch: 130, Loss: 0.0349

Epoch: 140, Loss: 0.0246

Epoch: 150, Loss: 0.0298

Epoch: 160, Loss: 0.0218

Epoch: 170, Loss: 0.0328

Epoch: 180, Loss: 0.0199

Epoch: 190, Loss: 0.0223

Epoch: 200, Loss: 0.0330

Test Accuracy: 0.7510



2.3 使用两层GCNConv神经网络进行边预测

fromtorch_geometric.datasetsimportPlanetoidfromtorch_geometric.utilsimporttrain_test_split_edgesimporttorch_geometric.transformsasTdevice=torch.device('cuda'iftorch.cuda.is_available()else'cpu')dataset='Cora'path=osp.join('dataset',dataset)# 读取Cora数据集dataset=Planetoid(path,dataset,transform=T.NormalizeFeatures())data=dataset[0]ground_truth_edge_index=data.edge_index.to(device)data.train_mask=data.val_mask=data.test_mask=data.y=None# 划分数据集data=train_test_split_edges(data)data=data.to(device)Copy to clipboardErrorCopied

fromtorch_geometric.nnimportGCNConv# 构建神经网络classNet(torch.nn.Module):def__init__(self,in_channels,out_channels):super(Net,self).__init__()self.conv1=GCNConv(in_channels,128)self.conv2=GCNConv(128,out_channels)defencode(self,x,edge_index):x=self.conv1(x,edge_index)x=x.relu()returnself.conv2(x,edge_index)defdecode(self,z,pos_edge_index,neg_edge_index):edge_index=torch.cat([pos_edge_index,neg_edge_index],dim=-1)return(z[edge_index[0]]*z[edge_index[1]]).sum(dim=-1)defdecode_all(self,z):prob_adj=z @ z.t()return(prob_adj>0).nonzero(as_tuple=False).t()Copy to clipboardErrorCopied

fromtorch_geometric.utilsimportnegative_samplingimporttorch.nn.functionalasF# 得到边的类别{0,1}defget_link_labels(pos_edge_index,neg_edge_index):num_links=pos_edge_index.size(1)+neg_edge_index.size(1)link_labels=torch.zeros(num_links,dtype=torch.float)link_labels[:pos_edge_index.size(1)]=1.returnlink_labelsdeftrain(data,model,optimizer):model.train()# 进行负采样,使得样本数一致neg_edge_index=negative_sampling(edge_index=data.train_pos_edge_index,num_nodes=data.num_nodes,num_neg_samples=data.train_pos_edge_index.size(1))optimizer.zero_grad()z=model.encode(data.x,data.train_pos_edge_index)link_logits=model.decode(z,data.train_pos_edge_index,neg_edge_index)link_labels=get_link_labels(data.train_pos_edge_index,neg_edge_index).to(data.x.device)loss=F.binary_cross_entropy_with_logits(link_logits,link_labels)loss.backward()optimizer.step()returnlossCopy to clipboardErrorCopied

fromsklearn.metricsimportroc_auc_score@torch.no_grad()deftest(data,model):model.eval()z=model.encode(data.x,data.train_pos_edge_index)results=[]forprefixin['val','test']:pos_edge_index=data[f'{prefix}_pos_edge_index']neg_edge_index=data[f'{prefix}_neg_edge_index']link_logits=model.decode(z,pos_edge_index,neg_edge_index)# 得到正负类别概率link_probs=link_logits.sigmoid()link_labels=get_link_labels(pos_edge_index,neg_edge_index)results.append(roc_auc_score(link_labels.cpu(),link_probs.cpu()))returnresultsCopy to clipboardErrorCopied

model=Net(dataset.num_features,64).to(device)optimizer=torch.optim.Adam(params=model.parameters(),lr=0.01)best_val_auc=test_auc=0forepochinrange(1,101):loss=train(data,model,optimizer)val_auc,tmp_test_auc=test(data,model)ifval_auc>best_val_auc:best_val_auc=val_auc        test_auc=tmp_test_aucifepoch%10==0:print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, 'f'Test: {test_auc:.4f}')z=model.encode(data.x,data.train_pos_edge_index)final_edge_index=model.decode_all(z)print('ground truth edge shape:',ground_truth_edge_index.shape)print('final edge shape:',final_edge_index.shape)Copy to clipboardErrorCopied

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,132评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,802评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,566评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,858评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,867评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,695评论 1 282
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,064评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,705评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,915评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,677评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,796评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,432评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,041评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,992评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,223评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,185评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,535评论 2 343

推荐阅读更多精彩内容