昇腾AI4S图机器学习:DGL消息传递接口的PyG替换

背景介绍

DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在API设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾NPU对PyG图机器学习库的支持亲和度更高,因此有些时候需要做DGL接口的PyG替换。

SE3Transformer在RFdiffusion蛋白质设计模型中(GitHub - RosettaCommons/RFdiffusion: Code for running RFdiffusion)作为核心组件,负责处理蛋白质结构的几何信息。其架构基于图神经网络,通过SE(3)等变性实现对三维旋转和平移的不变性特征提取。本系列以RFDiffusion模型中的SE3Transformer为例,讲解如何将DGL中的接口替换为PyG实现。

在本文中,主要展示消息传递接口的PyG替换。

消息传递接口 

一、边-节点消息传递 (EdgeSoftmax + Aggregation)

位置: 

rfdiffusion/modules/equivariant_attention/modules.py 中的 TransformerLayer 

输入: 

- 节点特征: x , 形状为(N, F) 

- 边特征: edge_attr , 形状为(E, F') 

- 图结构: graph 

输出: 

- 更新的节点特征: 形状为(N, F_out) 

DGL函数:

- dgl.nn.EdgeSoftmax:对边特征进行归一化 

- dgl.function.copy_edge:复制边特征

- dgl.function.sum:聚合消息 

数学逻辑:

1. 计算注意力分数: a_{ij}=\mathrm{softmax}_j(e_{ij})

2. 消息聚合: h_i^{\prime}=\sum_{j\in\mathcal{N}(i)}a_{ij}\cdot h_j

PyG实现:


二、矢量特征消息传递

位置: 

rfdiffusion/modules/equivariant_attention/modules.py 中的 AttentionBlockSE3

输入:

- 标量特征: feat_scalar , 形状为(N, F_s) 

- 矢量特征: feat_vector , 形状为(N, F_v, 3) 

- 图结构: graph 

输出: 

- 更新的标量和矢量特征 

DGL函数: 

- dgl.nn.EdgeSoftmax:边特征softmax 

- g.send_and_recv:消息传递与聚合

数学逻辑:

1. m_{ij}=f_\mathrm{att}(h_i^s,h_j^s,h_i^v,h_j^v)

2. 矢量特征旋转: h_j^v\cdot R_{ij} ,其中 R_{ij}R_{ij}是相对方向 

PyG实现关键点: 

- 需要自定义消息传递函数

- 实现等变性旋转操作

- 处理批处理边索引

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容