This article first gives a brief overview of the GNN node-classification task, and then walks through how to use Deep Graph Library (DGL) + PyTorch to implement a simple two-layer GNN model for node classification on the Cora citation dataset. For the complete model code, follow the WeChat account 【AI机器学习与知识图谱】 and reply: **DGL第一讲完整代码**.
**Overview of GNN Node Classification**

Node classification is a common learning task on graph data: the model predicts the class of every node in the graph. Before GNNs were proposed, widely used models such as DeepWalk and Node2Vec made predictions using sequence properties (e.g., random-walk sequences) together with each node's own features, but graph data clearly lacks the sequential dependency that text has in NLP. GNN-family models, by contrast, exploit a node's neighborhood subgraph: they first obtain a node representation by aggregating over that subgraph, and then predict the node's class. For example, the GCN model proposed by Kipf et al. in 2017 treats node classification on a graph as a semi-supervised learning task: using the labels of only a small fraction of the nodes, the model can accurately predict the classes of the remaining nodes.

The following experiment builds a GCN model and trains it for semi-supervised node classification on the Cora dataset. Cora is a citation network in which each node represents a paper and each edge represents a citation between two papers.
The Cora citation network contains 2708 nodes and 10556 edges. Each node has a 1433-dimensional feature vector, where each dimension corresponds to a word in the vocabulary: the value is 1 if the paper contains that word and 0 otherwise. The data split uses 140 nodes for training, 500 for validation, and 1000 for testing, which probes the model's ability to predict in a semi-supervised setting with very few labels. The nodes fall into seven classes, so node classification here is a seven-way classification problem.
**Implementing GNN Node Classification with DGL**

Next, we use the DGL framework to implement a GNN model for the node-classification task, explaining the code step by step.

```python
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
```

The four lines above load the dgl and pytorch libraries that the rest of the code needs.

```python
import dgl.data

dataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
g = dataset[0]
```

The dgl.data.CoraGraphDataset() call loads the Cora dataset object that DGL provides. A DGL Dataset may contain multiple graphs, so the loaded dataset object acts as a list in which each element is one graph of the dataset; since Cora consists of a single graph, dataset[0] retrieves it directly.

```python
print('Node features: ', g.ndata)
print('Edge features: ', g.edata)
```
Looking at the two print statements above: in DGL, a Graph object stores its node features and edge features in dictionary form. The first line's g.ndata is the dictionary holding the node feature information, and the second line's g.edata is the dictionary holding the edge feature information. For the Cora graph, the node features contain the following five entries:

1. train_mask: a boolean tensor indicating whether a node belongs to the training set
2. val_mask: a boolean tensor indicating whether a node belongs to the validation set
3. test_mask: a boolean tensor indicating whether a node belongs to the test set
4. label: the ground-truth class of each node
5. feat: the node's own feature vector
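As a quick sanity check, the short sketch below (an addition to the original walkthrough; it uses only the masks listed above) counts the nodes in each split, which should match the 140/500/1000 division described earlier.

```python
# Count the nodes in each split using the boolean masks in g.ndata.
# Expected for Cora: 140 train, 500 validation, 1000 test.
print('train:', int(g.ndata['train_mask'].sum()))
print('val:  ', int(g.ndata['val_mask'].sum()))
print('test: ', int(g.ndata['test_mask'].sum()))
print('nodes:', g.num_nodes(), 'edges:', g.num_edges())
```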
Next, we define the two-layer GCN model:

```python
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        # g carries the Cora graph information, essentially the normalized adjacency matrix
        # in_feat is the node representation, i.e. the initial node feature matrix
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# Create a GCN model with the given dimensions: the hidden size is set to 16,
# while the input and output sizes are determined by the dataset.
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
```
The code above uses DGL's dgl.nn.GraphConv module to build a two-layer GCN network, in which each layer updates a node's representation by aggregating information from its neighbors. Each layer also comes with a change of dimensionality: the first layer maps (in_feats, h_feats) and the second maps (h_feats, num_classes); since there are only two layers, the second one maps directly onto the final number of classes.

It is worth emphasizing the two arguments g and in_feat of the forward method. The argument g carries the Cora graph information, essentially the normalized adjacency matrix used by each GCN layer:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}}\,\tilde{A}\,\tilde{D}^{-\frac{1}{2}}\,H^{(l)}\,W^{(l)}\right), \qquad \tilde{A} = A + I$$

where $A$ is the adjacency matrix, $I$ is the identity matrix, and $\tilde{D}$ is the degree matrix of $\tilde{A}$. The argument in_feat is the node representation, i.e. the initial node feature matrix.
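To make the normalization concrete, here is a minimal dense-matrix sketch of one GCN layer (an illustrative addition: dgl.nn.GraphConv realizes the same rule via sparse message passing rather than dense matrices):

```python
import torch

# One GCN layer computed with dense matrices:
#   relu( D^{-1/2} (A + I) D^{-1/2} H W )
def dense_gcn_layer(A, H, W):
    A_tilde = A + torch.eye(A.shape[0])     # add self-loops: A + I
    deg = A_tilde.sum(dim=1)                # node degrees of A_tilde
    D_inv_sqrt = torch.diag(deg.pow(-0.5))  # D^{-1/2}
    return torch.relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)

# Toy example: 3-node path graph, 2-dim features mapped to 4 dims.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
H = torch.randn(3, 2)
W = torch.randn(2, 4)
print(dense_gcn_layer(A, H, W).shape)  # torch.Size([3, 4])
```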
```python
def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    for e in range(100):
        # Forward pass
        logits = model(g, features)

        # Compute predictions
        pred = logits.argmax(1)

        # Compute the loss -- note that it is computed only on the nodes
        # of the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on the training/validation/test sets
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Track the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
```

Above is the model's training function; it follows the same pattern as any PyTorch training loop. The training output is shown below:

```
In epoch 0, loss: 1.947, val acc: 0.070 (best 0.070), test acc: 0.064 (best 0.064)
In epoch 5, loss: 1.905, val acc: 0.428 (best 0.428), test acc: 0.426 (best 0.426)
In epoch 10, loss: 1.835, val acc: 0.608 (best 0.608), test acc: 0.646 (best 0.646)
In epoch 15, loss: 1.739, val acc: 0.590 (best 0.630), test acc: 0.623 (best 0.648)
In epoch 20, loss: 1.618, val acc: 0.644 (best 0.644), test acc: 0.670 (best 0.670)
In epoch 25, loss: 1.475, val acc: 0.698 (best 0.698), test acc: 0.737 (best 0.737)
In epoch 30, loss: 1.316, val acc: 0.720 (best 0.724), test acc: 0.731 (best 0.731)
In epoch 35, loss: 1.148, val acc: 0.726 (best 0.726), test acc: 0.728 (best 0.728)
In epoch 40, loss: 0.981, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.747)
In epoch 45, loss: 0.822, val acc: 0.750 (best 0.750), test acc: 0.764 (best 0.764)
In epoch 50, loss: 0.678, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.766)
In epoch 55, loss: 0.552, val acc: 0.770 (best 0.770), test acc: 0.766 (best 0.766)
In epoch 60, loss: 0.447, val acc: 0.774 (best 0.774), test acc: 0.764 (best 0.764)
In epoch 65, loss: 0.361, val acc: 0.778 (best 0.778), test acc: 0.772 (best 0.772)
In epoch 70, loss: 0.292, val acc: 0.782 (best 0.782), test acc: 0.771 (best 0.771)
In epoch 75, loss: 0.238, val acc: 0.778 (best 0.782), test acc: 0.775 (best 0.771)
In epoch 80, loss: 0.196, val acc: 0.776 (best 0.782), test acc: 0.778 (best 0.771)
In epoch 85, loss: 0.162, val acc: 0.774 (best 0.782), test acc: 0.778 (best 0.771)
In epoch 90, loss: 0.136, val acc: 0.774 (best 0.782), test acc: 0.777 (best 0.771)
In epoch 95, loss: 0.115, val acc: 0.770 (best 0.782), test acc: 0.776 (best 0.771)
```
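Once training finishes, the model can be used for prediction with a plain forward pass. The sketch below (an addition for completeness; it reuses the model and g objects defined above) switches to evaluation mode and recomputes the test accuracy:

```python
# Inference sketch: evaluate the trained model on the test split.
model.eval()
with torch.no_grad():
    logits = model(g, g.ndata['feat'])
    pred = logits.argmax(1)
    test_mask = g.ndata['test_mask']
    test_acc = (pred[test_mask] == g.ndata['label'][test_mask]).float().mean()
    print('Final test accuracy: {:.3f}'.format(test_acc.item()))
```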