基于PGL的图嵌入算法metapath2vec源码解读

Metapath2vec是一种基于深度学习的网络表示学习方法,用于学习复杂网络中节点的向量表示。它是在元路径(metapath)的基础上进行的。元路径是指网络中的一条特定类型的节点序列,例如在社交网络中,用户之间的关系可以用元路径“用户-群组-用户”来表示。Metapath2vec利用Skip-gram模型,从元路径中提取出节点序列,然后将这些节点序列作为输入,训练一个神经网络来学习节点向量表示(embedding)。通过这种方式,Metapath2vec可以在复杂网络中学习节点之间的关系,得到节点embedding后,再加一层全连接层和softmax就可以做节点分类,通过计算节点间相似度就可以做推荐。

目的

  生成节点embedding

metapath2vec算法

  metapath是一个异构图的随机游走算法,比如,A-P-A表示一篇论文有共同的作者,A1-P1-C1-P2-A3表示一个会议上有两个个不同作者发表。它是一个对称的结构。当达到最大长度或者找不到合适的节点才结束游走。


metapath

  整体的框架如下,当完成元路径随机游走后,我们会得到一些元路径,这个路径像是NLP中的句子,NLP中有通过skip-gram来预测词。我们的训练数据需要得到一个pair对,比如下图A4P3,他们的label为1;而A4P5是一个负样本,他们的label为0,这个过程像是NLP中的二分类,训练完成后,就会得到一个模型,然后根据这个模型推理所有节点的embeding。


metapath整体框架

  接下来的工作我们就解读下基于PGL的metapath2vec算法代码,metapath2vec源代码。大部分的工作就是将图转化为skip-gram算法所需要的中心词和周围词的关系,不明白的skip-gram的可以参考上一遍文章。

datasets文件夹
├── dataset.py
├── helper.py
├── node.py
├── pair.py
├── sampling.py
└── walk.py
数据生成文件调用关系.png
#config.yaml
task_name: distributed_metapath2vec

# ---------------------------数据配置-------------------------------------------------#
# for data preprocessing
data_path: ./data/net_aminer
author_label_file: ./data/label/googlescholar.8area.author.label.txt
venue_label_file: ./data/label/googlescholar.8area.venue.label.txt
processed_path: ./graph_data

# for pgl graph engine
etype2files: "p2a:./graph_data/paper2author_edges.txt,p2c:./graph_data/paper2conf_edges.txt"
ntype2files: "p:./graph_data/node_types.txt,a:./graph_data/node_types.txt,c:./graph_data/node_types.txt"
#表示无向图,会生成两条数据
symmetry: True
#metapath是对成的
meta_path: "c2p-p2a-a2p-p2c"
first_node_type: "c"

shard_num: 100

# walk游走的最大长度
walk_len: 24
#skip-gram 中的skip大小
win_size: 3
#负采样的个数
neg_num: 5
#游走最大的度
walk_times: 20


# ---------------------------模型参数配置---------------------------------------------#
model_type: SkipGramModel
warm_start_from: null
num_nodes: 5000000
embed_size: 64
sparse_embed: False

# ---------------------------训练参数配置---------------------------------------------#
epochs: 1
num_workers: 4
lr: 0.001
lazy_mode: False
batch_node_size: 200
batch_pair_size: 1000
pair_stream_shuffle_size: 100000
log_dir: ./logs
output_dir: ./outputs
save_dir: ./checkpoints
log_steps: 1000

dropbox文件不好下载, 现已经上传到百度云盘链接: net_aminer 数据集 提取码: s9iv

处理数据集

python data_preprocess.py --config config.yaml

node_types.txt 格式node_type"\t" node_id

c       0
c       1
c       2
c       3
a       3885
a       3886
p       4891796
p       4891797

paper2author_edges.txt 格式paper_id"\t"author_id

1738139 1105483
1963494 1629565
2128630 418483
2509017 841304
3536281 1611393

paper2conf_edges.txt 格式paper_id"\t"conf_id

2090976 1108
4666445 2808
4704329 2055
1951251 3195
3680120 779
# dataset.py
class TrainPairDataset(StreamDataset):
    def __init__(self, config, ip_list_file, mode="train"):
        self.config = config
        self.ip_list_file = ip_list_file
        self.mode = mode

    def __iter__(self):
        client_id = os.getpid()
        self.graph = DistGraphClient(self.config, self.config.shard_num,
                                     self.ip_list_file, client_id)

        self.generator = PairGenerator(
            self.config,
            self.graph,
            mode=self.mode,
            rank=self._worker_info.fid,
            nrank=self._worker_info.num_workers)

        for data in self.generator():
            yield data

class CollateFn(object):
    def __init__(self):
        pass

    def __call__(self, batch_data):
        src_list = []
        pos_list = []
        for src, pos in batch_data:
            src_list.append(src)
            pos_list.append(pos)
        #model获取这里的数据
        src_list = np.array(src_list, dtype="int64").reshape(-1, 1)
        pos_list = np.array(pos_list, dtype="int64").reshape(-1, 1)
        return {'src': src_list, 'pos': pos_list}
# pair.py
class PairGenerator(object):
    #...
    def __call__(self):
        iterval = 20000000 * 24 // self.config.walk_len
        pair_count = 0
        for walks in self.walk_generator():
            try:
                for walk in walks:
                    index = np.arange(0, len(walk), dtype="int64")
                    batch_s, batch_p = skip_gram_gen_pair(index,
                                                          self.config.win_size)
                    for s, p in zip(batch_s, batch_p):
                        # 返回给CollateFn
                        yield walk[s], walk[p]
                        pair_count += 1
                        if pair_count % iterval == 0 and self.rank == 0:
                            log.info("[%s] pairs have been loaded in rank [%s]" \
                                    % (pair_count, self.rank))

            except Exception as e:
                log.exception(e)

        log.info("total [%s] pairs in rank [%s]" % (pair_count, self.rank))

异构图的随机游走,返回metapath节点路径

#sampling.py
def metapath_randomwalk_with_walktimes(graph,
                                       start_nodes,
                                       metapath,
                                       walk_length,
                                       walk_times=10,
                                       alias_name=None,
                                       events_name=None):
    """Implementation of metapath random walk in heterogeneous graph.

    Args:
        graph: instance of pgl heterogeneous graph
        start_nodes: start nodes to generate walk
        metapath: meta path for sample nodes.
            e.g: "c2p-p2a-a2p-p2c"
        walk_length: the walk length

    Return:
        a list of metapath walks.

    """

    edge_types = metapath.split('-')
    walk = []
    cur_nodes = []
    # start_nodes size=200
    neighbors = graph.sample_successor(
        np.array(
            start_nodes, dtype="uint64"),
        max_degree=walk_times,
        edge_type=edge_types[0])
    # 将开始节点和继承节点加入到返回的walk中,walk 的size=200*20
    for neigh, walk_id in zip(neighbors, start_nodes):
        for node_id in neigh:
            walk.append([walk_id, node_id])
            cur_nodes.append(node_id)
  
    if len(walk) == 0:
        return walk

    cur_walk_ids = np.arange(0, len(walk))
    cur_nodes = np.array(cur_nodes, dtype="uint64")
    #  if np.random.random() - 0.02 < 0:
    #      sys.stderr.write("length of walks %s\n" % (len(walk)))

    mp_len = len(edge_types)
    for i in range(1, walk_length - 1):
        cur_succs = graph.sample_successor(
            cur_nodes, max_degree=1, edge_type=edge_types[i % mp_len])
        mask = np.array([len(succ) > 0 for succ in cur_succs], dtype="bool")
        # mask: array([ True,  True,  True, ...,  True,  True,  True])
        # np.any()是或操作,任意一个元素为True,输出为True
        # 所有的节点都没有出节点的时候才结束
        if np.any(mask):
            # 取出为True的节点
            cur_walk_ids = cur_walk_ids[mask]
            cur_nodes = cur_nodes[mask]
            cur_succs = np.array(cur_succs, dtype="object")[mask]
        else:
            # stop when all nodes have no successor
            break
        #walk[0] 就是一个完整的metapath
        nxt_cur_nodes = []
        for s, walk_id in zip(cur_succs, cur_walk_ids):
            walk[walk_id].append(s[0])
            nxt_cur_nodes.append(s[0])
        cur_nodes = np.array(nxt_cur_nodes, dtype="uint64")
    return walk

# model.py
class SkipGramModel(nn.Layer):
    #...
    def forward(self, feed_dict):
        src_embed = self.embedding(feed_dict['src'])
        pos_embed = self.embedding(feed_dict['pos'])

        # batch neg sample
        # 负采样在这里生成
        batch_size = feed_dict['pos'].shape[0]
        neg_idx = paddle.randint(
            low=0, high=batch_size, shape=[batch_size, self.neg_num])

        negs = []
        for i in range(self.neg_num):
            tmp = paddle.gather(pos_embed, neg_idx[:, i])
            tmp = paddle.reshape(tmp, [-1, 1, self.embed_size])
            negs.append(tmp)

        neg_embed = paddle.concat(negs, axis=1)
        src_embed = paddle.reshape(src_embed, [-1, 1, self.embed_size])
        pos_embed = paddle.reshape(pos_embed, [-1, 1, self.embed_size])

        # [batch_size, 1, 1]
        pos_logits = paddle.matmul(src_embed, pos_embed, transpose_y=True)
        # [batch_size, 1, neg_num]
        neg_logits = paddle.matmul(src_embed, neg_embed, transpose_y=True)

        ones_label = paddle.ones_like(pos_logits)
        pos_loss = self.loss_fn(pos_logits, ones_label)

        zeros_label = paddle.zeros_like(neg_logits)
        neg_loss = self.loss_fn(neg_logits, zeros_label)

        loss = (pos_loss + neg_loss) / 2
        return loss
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,189评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,577评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,857评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,703评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,705评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,620评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,995评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,656评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,898评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,639评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,720评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,395评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,982评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,953评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,195评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,907评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,472评论 2 342

推荐阅读更多精彩内容