Task02:消息传递范式

一、消息传递范式介绍

消息传递范式是一种聚合邻接结点信息来更新中心结点信息的范式,实现了图与神经网络的连接。遵循消息传递范式的图神经网络被称为消息传递图神经网络。

图中黄色方框部分展示的是一次邻居结点信息传递到中心结点的过程:

B结点的邻接结点(A,C)的信息经过变换后聚合到B结点,接着B结点信息邻居结点聚合信息一起经过变换得到B结点的新的结点信息

邻居结点信息传递到中心结点的过程会进行多次。

消息传递图神经网络遵循聚合邻接结点信息来更新中心结点信息的过程,来生成结点表征

消息图神经网络可以描述为

\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right)

未经过训练的图神经网络生成的结点表征不是好的结点表征。

通过监督学习对图神经网络做很好的训练,图神经网络可以生成好的结点表征,用于衡量结点之间的相似性。

二、MessagePassing基类初步分析

Pytorch Geometric(PyG)提供了MessagePassing基类,封装了“消息传递”的运行流程。通过继承MessagePassing基类,可以方便地构造消息传递图神经网络。

MessagePassing.__init__(aggr="add", flow="source_to_target", node_dim=-2):

aggr:定义要使用的聚合方案;

flow:定义消息传递的流向;

node_dim:定义沿着哪个维度传播,默认值为-2。

MessagePassing.propagate(edge_index, size=None, **kwargs):

以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数。

propagate()不仅限于基于形状为[N, N]的对称邻接矩阵进行“消息传递过程”。基于非对称的邻接矩阵进行消息传递(当图为二部图时),需要传递参数size=(N, M)。

设置size=None,则认为邻接矩阵是对称的。

MessagePassing.aggregate(...):

将从源节点传递过来的消息聚合在目标结点上,一般可选的聚合方式有sum, mean和max。

MessagePassing.message_and_aggregate(...):

在一些场景里,邻接结点信息变换和邻接结点信息聚合这两项操作可以融合在一起。

三、MessagePassing子类实例

以继承MessagePassing基类的GCNConv类为例,学习如何通过继承MessagePassing基类来实现一个简单的图神经网络。

GCNConv的数学定义为:

\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),

1.向邻接矩阵添加自环边;

2.对节点表征做线性转换;

3.计算归一化系数;

4.归一化邻接节点的节点表征;

5.将相邻节点表征相加。

实现代码如下:

import torch

from torch_geometric.nn import MessagePassing

from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels):

        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

        self.lin = torch.nn.Linear(in_channels, out_channels)


    def forward(self, x, edge_index):

        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        row, col = edge_index

        deg = degree(col, x.size(0), dtype=x.dtype)

        deg_inv_sqrt = deg.pow(-0.5)

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm)


    def message(self, x_j, norm):

        return norm.view(-1, 1) * x_j

创建一个仅包含一次“消息传递过程”的图神经网络

实现代码如下:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='dataset/Cora', name='Cora')

data = dataset[0]

net = GCNConv(data.num_features, 64)

h_nodes = net(data.x, data.edge_index)

print(h_nodes.shape)

四、总结

消息传递范式是一种聚合邻接结点信息来更新中心结点信息的范式,将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。

该范式包含这样三个步骤:

(1)邻接结点信息变换;

(2)邻接结点信息聚合到中心结点;

(3)聚合信息变换

五、作业

总结MessagePassing基类的运行流程

\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right)

定义函数update(),aggregate()函数,message() ;

定义MessagePassing类, 继承nn.Module, 实现forward函数

forward中调用消息传递的起始函数propagate(),  propagate()检查edge_index, 判断是执行 message_and_aggregate() 还是 message(), aggregate() 

message()负责定义结点消息传递的方向[‘source_to_target’, ‘target_to_source’] 

aggregate()负责定义结点聚合使用的聚合方式 [‘max’, ‘add’, ‘sum’]

update()负责更新结点表征


复现一个一层的图神经网络的构造,总结通过继承MessagePassing基类来构造自己的图神经网络类的规范

import torch.nn as nn

from torch_geometric.nn import MessagePassing

from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels, aggr="add"):

        super(GCNConv, self).__init__(aggr=aggr)

        self.lin = nn.Linear(in_channels, out_channels)


    def forward(self, x, edge_index):

        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        row, col = edge_index

        deg = degree(col, x.size(0), dtype=x.dtype)

        deg_inv_sqrt = deg.pow(-0.5)

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm)


    def message(self, edge_index, x_j, norm):

        return norm.view(-1, 1) * x_j



DataWhale开源学习资料:

https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容