[GNN Framework Series] DGL Part 1: Implementing Node Classification with a GNN

<p style="text-indent:0pt"><span style="font-size:16px"><span>本文先简单概述GNN节点分类任务,然后详细介绍如何使用Deep </span><span>Graph Library + Pytorch</span><span>实现一个简单的两层GNN模型在Cora引文数据上实现节点分类任务。若需获取模型的完整代码,可关注公众号【AI机器学习与知识图谱】后回复:</span></span><strong><strong>DGL第一讲完整代码</strong></strong></p>
Overview of GNN Node Classification

Node classification is one of the most common learning tasks on graph data: the model predicts a category for every node in the graph. Before GNNs were proposed, popular models such as DeepWalk and Node2Vec made predictions from random-walk sequences and the nodes' own features, but graph data clearly lacks the sequential dependency that text has in NLP. GNN-style models instead exploit each node's neighborhood subgraph, aggregating over that subgraph to obtain a node representation and then predicting the node's category. For example, the GCN model proposed by Kipf et al. in 2017 casts node classification as a semi-supervised learning task: using labels on only a small fraction of the nodes, the model can accurately predict the categories of the remaining nodes.

The experiment below builds a GCN model and trains it for semi-supervised node classification on the Cora dataset. Cora is a citation network in which each node represents a paper and each edge represents a citation between two papers.

The Cora citation network contains 2,708 nodes and 10,556 edges. Each node has a 1,433-dimensional feature vector, where each dimension corresponds to a word in the vocabulary: the feature is 1 if the paper contains that word and 0 otherwise. The data split uses 140 nodes for training, 500 for validation, and 1,000 for testing, which exercises the model's ability to learn from few labels in a semi-supervised setting. The nodes fall into seven categories, so node classification on Cora is a seven-way classification problem.
DGL Implementation of GNN Node Classification

Next we use the DGL framework to implement the GNN model for node classification, explaining the code step by step.

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

The four lines above import the DGL and PyTorch libraries we will use.

import dgl.data

dataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
g = dataset[0]

The dgl.data.CoraGraphDataset() call loads the Cora dataset object that DGL ships with. A DGL Dataset may contain multiple graphs, so the loaded dataset object behaves like a list in which each element is one graph. Cora consists of a single graph, so dataset[0] retrieves it directly.

print('Node features: ', g.ndata)
print('Edge features: ', g.edata)

As the two lines above show, a DGL Graph object stores its node features and edge features in dictionary form: g.ndata is a dict-like structure holding node feature information, and g.edata is a dict-like structure holding edge feature information. For the Cora graph, g.ndata contains the following five entries:

1. train_mask: a boolean tensor indicating whether a node is in the training set
2. val_mask: a boolean tensor indicating whether a node is in the validation set
3. test_mask: a boolean tensor indicating whether a node is in the test set
4. label: the ground-truth category of each node
5. feat: the node's own feature vector
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        # g holds the Cora graph structure, which in effect supplies
        # the normalized adjacency matrix
        # in_feat is the node representation, i.e. the initial node features
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# Create a GCN model with the given dimensions: the hidden size is set to 16,
# while the input and output sizes are determined by the dataset.
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
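As another quick check, added here rather than taken from the original post, a single untrained forward pass should map the [2708, 1433] feature matrix to [2708, 7] class logits:

# Added sketch: one untrained forward pass to confirm output shapes.
with torch.no_grad():
    out = model(g, g.ndata['feat'])
print(out.shape)  # expected: torch.Size([2708, 7])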
</span><span style="font-size:16px">上面代码使用dgl库中的dgl.nn.GraphConv模块构建了一个两层GCN网络,每层都通过汇聚邻居节点信息来更新节点表征,每层GCN网络都便随着维度的变化,第一层维度映射(in_feats, h_feats),第二层维度映射(h_feats, num_classes),总共两层网络因此第二层直接映射到最终分类类别维度上。</span><p style="text-indent:0pt"><span style="font-size:16px">这里需要强调上面代码第九行中g, in_feat两个参数,参数g代表的Cora数据Graph信息,一般就是经过归一化的邻接矩阵,如下所示,</span><span>其中</span><span><span>是邻接矩阵,</span><span><span>是单位矩阵,</span><span><span>是度矩阵</span><span style="font-size:16px">:</span>
</span></span></span></p><div class="image-package"><img src="https://upload-images.jianshu.io/upload_images/26011021-47a25ab05bf4e044.jpeg" img-data="{"format":"jpeg","size":6512,"height":60,"width":602}" class="uploaded-img" style="min-height:200px;min-width:200px;" width="auto" height="auto"/>
</div><span><span style="font-size:16px"/><span style="font-size:16px">参数in_feat表示的是node representation,即节点初始化特征信息。</span></span><span><span>
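For intuition, below is a minimal dense-matrix sketch of this normalization in plain PyTorch. It illustrates the formula above only; it is not DGL's actual implementation, which operates on sparse graphs inside GraphConv:

# Illustrative sketch (not DGL internals): symmetric GCN normalization
# on a small dense adjacency matrix.
import torch

def normalize_adjacency(A: torch.Tensor) -> torch.Tensor:
    A_tilde = A + torch.eye(A.shape[0])       # add self-loops: A~ = A + I
    deg = A_tilde.sum(dim=1)                  # degrees of A~
    D_inv_sqrt = torch.diag(deg.pow(-0.5))    # D~^{-1/2}
    return D_inv_sqrt @ A_tilde @ D_inv_sqrt  # D~^{-1/2} A~ D~^{-1/2}

# Example: a 3-node path graph 0-1-2
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
print(normalize_adjacency(A))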
def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

The code above is the model's training function, which follows the same pattern as any PyTorch training loop. The training output looks like this:

In epoch 0, loss: 1.947, val acc: 0.070 (best 0.070), test acc: 0.064 (best 0.064)
In epoch 5, loss: 1.905, val acc: 0.428 (best 0.428), test acc: 0.426 (best 0.426)
In epoch 10, loss: 1.835, val acc: 0.608 (best 0.608), test acc: 0.646 (best 0.646)
In epoch 15, loss: 1.739, val acc: 0.590 (best 0.630), test acc: 0.623 (best 0.648)
In epoch 20, loss: 1.618, val acc: 0.644 (best 0.644), test acc: 0.670 (best 0.670)
In epoch 25, loss: 1.475, val acc: 0.698 (best 0.698), test acc: 0.737 (best 0.737)
In epoch 30, loss: 1.316, val acc: 0.720 (best 0.724), test acc: 0.731 (best 0.731)
In epoch 35, loss: 1.148, val acc: 0.726 (best 0.726), test acc: 0.728 (best 0.728)
In epoch 40, loss: 0.981, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.747)
In epoch 45, loss: 0.822, val acc: 0.750 (best 0.750), test acc: 0.764 (best 0.764)
In epoch 50, loss: 0.678, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.766)
In epoch 55, loss: 0.552, val acc: 0.770 (best 0.770), test acc: 0.766 (best 0.766)
In epoch 60, loss: 0.447, val acc: 0.774 (best 0.774), test acc: 0.764 (best 0.764)
In epoch 65, loss: 0.361, val acc: 0.778 (best 0.778), test acc: 0.772 (best 0.772)
In epoch 70, loss: 0.292, val acc: 0.782 (best 0.782), test acc: 0.771 (best 0.771)
In epoch 75, loss: 0.238, val acc: 0.778 (best 0.782), test acc: 0.775 (best 0.771)
In epoch 80, loss: 0.196, val acc: 0.776 (best 0.782), test acc: 0.778 (best 0.771)
In epoch 85, loss: 0.162, val acc: 0.774 (best 0.782), test acc: 0.778 (best 0.771)
In epoch 90, loss: 0.136, val acc: 0.774 (best 0.782), test acc: 0.777 (best 0.771)
In epoch 95, loss: 0.115, val acc: 0.770 (best 0.782), test acc: 0.776 (best 0.771)
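After training, one would normally switch the model to evaluation mode for inference. The snippet below is an added sketch, not part of the original post, showing how the final test predictions could be obtained:

# Added sketch: final inference on the test split after training.
model.eval()
with torch.no_grad():
    logits = model(g, g.ndata['feat'])
    pred = logits.argmax(1)
    test_mask = g.ndata['test_mask']
    test_acc = (pred[test_mask] == g.ndata['label'][test_mask]).float().mean()
print('Final test accuracy: {:.3f}'.format(test_acc))

If a GPU is available, the same code should run unchanged after moving both the graph and the model onto it, e.g. g = g.to('cuda') and model = model.to('cuda').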
Previous highlights

[Knowledge Graph Series] A generative pre-training model for knowledge graphs
[Knowledge Graph Series] Knowledge graph embeddings in real- and complex-valued spaces
[Knowledge Graph Series] Multi-hop knowledge graph reasoning with reinforcement learning
[Knowledge Graph Series] EvolveGCN: dynamic temporal knowledge graphs
[Machine Learning Series] The two schools of thought in machine learning
