昇腾AI4S图机器学习:DGL图构建接口的PyG替换

背景介绍

DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在API设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾NPU对PyG图机器学习库的支持亲和度更高,因此有些时候需要做DGL接口的PyG替换。

SE3Transformer在RFdiffusion蛋白质设计模型中(GitHub - RosettaCommons/RFdiffusion: Code for running RFdiffusion)作为核心组件,负责处理蛋白质结构的几何信息。其架构基于图神经网络,通过SE(3)等变性实现对三维旋转和平移的不变性特征提取。本系列以RFDiffusion模型中的SE3Transformer为例,讲解如何将DGL中的接口替换2为PyG实现。在本文中,主要展示图构建结构的替换。

DGL图构建接口的PyG替换(make_full_graph和make_topk_graph)

make_full_graph 函数

位置: 

- rfdiffusion/util_module.py 

输入:

- xyz: 蛋白质骨架坐标,形状为(B, L, 3)或(B, L, 3, 3) 

- pair: 成对特征,形状为(B, L, L, E) 

- idx:残基索引 

输出: 

- G : DGL图 

- edge_feats:边特征 

调用DGL函数: 

- dgl.graph:创建图结构 

数学逻辑: 

1. 提取氨基酸相对位置 

2. 构建完全连接图

3. 设置边特征和节点特征 

PyG实现代码: 

make_topk_graph

位置: 

- rfdiffusion/util_module.py

输入和输出:

- 与 make_full_graph 类似,但构建k近邻图而非完全图

调用DGL函数:

- dgl.graph:创建图结构

数学逻辑:

1. 计算氨基酸之间距离

2. 选择top-k最近邻居

3. 确保每个节点至少有kmin个邻居

优化方案:

- 使用PyG的knn_graph函数简化实现

- 利用PyG的批处理机制处理多图

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容