Metapath2vec is a deep-learning-based network representation learning method that learns vector representations for the nodes of a complex network. It is built on the notion of a metapath: a sequence of specific node types in the network. In a social network, for instance, the relation between users can be expressed by the metapath "user-group-user". Metapath2vec extracts node sequences via metapath-guided random walks and feeds them to a skip-gram model, training a neural network to learn node embeddings. In this way it captures the relations between nodes in a complex network. Once the embeddings are obtained, stacking a fully connected layer and a softmax on top yields a node classifier, and computing similarities between node embeddings yields a recommender.
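As a quick illustration of that last point, here is a minimal similarity-based recommendation sketch; the embedding matrix is random stand-in data and recommend is a hypothetical helper, not part of metapath2vec itself.

import numpy as np

# stand-in for an embedding table learned by metapath2vec: 5 nodes, dim 4
embeddings = np.random.rand(5, 4).astype("float32")

def recommend(src_id, topk=3):
    """Rank all other nodes by cosine similarity to node `src_id`."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = normed @ normed[src_id]
    scores[src_id] = -np.inf  # never recommend the query node itself
    return np.argsort(-scores)[:topk]

print(recommend(0))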
Goal
Generate node embeddings.
The metapath2vec algorithm
A metapath defines a random-walk scheme on a heterogeneous graph. For example, A-P-A denotes two authors sharing a paper, and A1-P1-C1-P2-A3 denotes two different authors publishing at the same venue. A metapath has a symmetric structure. A walk only terminates when it reaches the maximum length or when no suitable next node can be found.
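To make the walk rule concrete, here is a toy pure-Python sketch; the adjacency dict, node names, and edge types below are invented for illustration.

import random

# toy heterogeneous graph: node -> {edge_type: [successors]}
succ = {
    "A1": {"a2p": ["P1", "P2"]},
    "P1": {"p2a": ["A1", "A2"]},
    "P2": {"p2a": ["A1"]},
    "A2": {"a2p": ["P1"]},
}

def metapath_walk(start, edge_types, walk_len):
    walk = [start]
    step = 0
    while len(walk) < walk_len:
        etype = edge_types[step % len(edge_types)]
        nbrs = succ.get(walk[-1], {}).get(etype, [])
        if not nbrs:  # no suitable next node: stop early
            break
        walk.append(random.choice(nbrs))
        step += 1
    return walk

print(metapath_walk("A1", ["a2p", "p2a"], walk_len=5))  # e.g. A1-P1-A2-P1-A1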
The overall framework is as follows. Once the metapath random walks finish, we have a set of paths; each path is like a sentence in NLP, where skip-gram is used to predict context words from a center word. The training data consists of node pairs: in the figure below, (A4, P3) is a positive pair with label 1, while (A4, P5) is a negative sample with label 0. The process is essentially the binary classification familiar from NLP. After training we obtain a model and use it to infer the embedding of every node.
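In toy form, the labeled training data could be built like this; the walk, the candidate node list, and the window size are made up for illustration, and the real pipeline actually draws negatives in-batch inside the model (see model.py below).

import random

walk = ["A4", "P3", "C2", "P5", "A5"]
all_nodes = ["A1", "A2", "A4", "A5", "P3", "P5", "C2"]
win_size = 1

data = []
for i, center in enumerate(walk):
    for j in range(max(0, i - win_size), min(len(walk), i + win_size + 1)):
        if j != i:
            data.append((center, walk[j], 1))                   # positive pair
            data.append((center, random.choice(all_nodes), 0))  # random negative
print(data[:4])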
In what follows we walk through the PGL-based metapath2vec implementation (see the metapath2vec source code). Most of the work is turning the graph into the (center word, context word) pairs that the skip-gram algorithm needs; if skip-gram is unfamiliar, see the previous article.
The datasets folder:
├── dataset.py
├── helper.py
├── node.py
├── pair.py
├── sampling.py
└── walk.py
#config.yaml
task_name: distributed_metapath2vec
# ---------------------------data config---------------------------------------#
# for data preprocessing
data_path: ./data/net_aminer
author_label_file: ./data/label/googlescholar.8area.author.label.txt
venue_label_file: ./data/label/googlescholar.8area.venue.label.txt
processed_path: ./graph_data
# for pgl graph engine
etype2files: "p2a:./graph_data/paper2author_edges.txt,p2c:./graph_data/paper2conf_edges.txt"
ntype2files: "p:./graph_data/node_types.txt,a:./graph_data/node_types.txt,c:./graph_data/node_types.txt"
# undirected graph: each edge is stored in both directions
symmetry: True
# the metapath is symmetric
meta_path: "c2p-p2a-a2p-p2c"
first_node_type: "c"
shard_num: 100
# maximum length of a random walk
walk_len: 24
# skip-gram window size
win_size: 3
# number of negative samples
neg_num: 5
# number of walks per start node (used as max_degree when sampling the first hop)
walk_times: 20
# ---------------------------model config--------------------------------------#
model_type: SkipGramModel
warm_start_from: null
num_nodes: 5000000
embed_size: 64
sparse_embed: False
# ---------------------------training config-----------------------------------#
epochs: 1
num_workers: 4
lr: 0.001
lazy_mode: False
batch_node_size: 200
batch_pair_size: 1000
pair_stream_shuffle_size: 100000
log_dir: ./logs
output_dir: ./outputs
save_dir: ./checkpoints
log_steps: 1000
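The training scripts read this file at startup. A minimal sketch of loading it, assuming PyYAML (the repo may use its own config helper instead of SimpleNamespace):

import yaml
from types import SimpleNamespace

with open("config.yaml") as f:
    config = SimpleNamespace(**yaml.safe_load(f))

print(config.meta_path, config.walk_len, config.neg_num)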
The Dropbox file is hard to download, so it has been uploaded to Baidu Netdisk: net_aminer dataset, extraction code: s9iv.
Preprocessing the dataset:
python data_preprocess.py --config config.yaml
node_types.txt format: node_type "\t" node_id
c 0
c 1
c 2
c 3
a 3885
a 3886
p 4891796
p 4891797
paper2author_edges.txt format: paper_id "\t" author_id
1738139 1105483
1963494 1629565
2128630 418483
2509017 841304
3536281 1611393
paper2conf_edges.txt format: paper_id "\t" conf_id
2090976 1108
4666445 2808
4704329 2055
1951251 3195
3680120 779
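For reference, a minimal reader for these tab-separated files (the path is one of the files generated above):

def load_edges(path):
    """Read whitespace-separated integer (src, dst) pairs, one per line."""
    with open(path) as f:
        return [tuple(map(int, line.split())) for line in f if line.strip()]

edges = load_edges("./graph_data/paper2author_edges.txt")
print(len(edges), edges[:3])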
# dataset.py
import os

import numpy as np
from pgl.utils.data import StreamDataset
from pgl.distributed import DistGraphClient

from pair import PairGenerator  # this repo's pair.py

class TrainPairDataset(StreamDataset):
def __init__(self, config, ip_list_file, mode="train"):
self.config = config
self.ip_list_file = ip_list_file
self.mode = mode
def __iter__(self):
client_id = os.getpid()
self.graph = DistGraphClient(self.config, self.config.shard_num,
self.ip_list_file, client_id)
self.generator = PairGenerator(
self.config,
self.graph,
mode=self.mode,
rank=self._worker_info.fid,
nrank=self._worker_info.num_workers)
for data in self.generator():
yield data
class CollateFn(object):
def __init__(self):
pass
def __call__(self, batch_data):
src_list = []
pos_list = []
for src, pos in batch_data:
src_list.append(src)
pos_list.append(pos)
        # this dict is the feed_dict the model receives
src_list = np.array(src_list, dtype="int64").reshape(-1, 1)
pos_list = np.array(pos_list, dtype="int64").reshape(-1, 1)
return {'src': src_list, 'pos': pos_list}
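TrainPairDataset and CollateFn are typically wired together with PGL's Dataloader. A sketch under that assumption (the keyword arguments are taken from pgl.utils.data.Dataloader, and ip_list.txt is a stand-in for the file listing the running graph-engine servers):

from pgl.utils.data import Dataloader

train_ds = TrainPairDataset(config, ip_list_file="./ip_list.txt")
loader = Dataloader(
    train_ds,
    batch_size=config.batch_pair_size,
    num_workers=config.num_workers,
    stream_shuffle_size=config.pair_stream_shuffle_size,
    collate_fn=CollateFn())

for batch in loader:
    print(batch["src"].shape, batch["pos"].shape)  # (batch_pair_size, 1)
    break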
# pair.py
import numpy as np
from pgl.utils.logger import log

from helper import skip_gram_gen_pair  # this repo's helper.py

class PairGenerator(object):
    #...
    def __call__(self):
        # log progress roughly every (20000000 * 24 / walk_len) pairs
        interval = 20000000 * 24 // self.config.walk_len
        pair_count = 0
for walks in self.walk_generator():
try:
for walk in walks:
index = np.arange(0, len(walk), dtype="int64")
batch_s, batch_p = skip_gram_gen_pair(index,
self.config.win_size)
for s, p in zip(batch_s, batch_p):
                        # yielded back to CollateFn for batching
yield walk[s], walk[p]
pair_count += 1
                        if pair_count % interval == 0 and self.rank == 0:
log.info("[%s] pairs have been loaded in rank [%s]" \
% (pair_count, self.rank))
except Exception as e:
log.exception(e)
log.info("total [%s] pairs in rank [%s]" % (pair_count, self.rank))
The heterogeneous-graph random walk, which returns metapath node sequences:
#sampling.py
import numpy as np

def metapath_randomwalk_with_walktimes(graph,
start_nodes,
metapath,
walk_length,
walk_times=10,
alias_name=None,
events_name=None):
"""Implementation of metapath random walk in heterogeneous graph.
Args:
graph: instance of pgl heterogeneous graph
start_nodes: start nodes to generate walk
metapath: meta path for sample nodes.
e.g: "c2p-p2a-a2p-p2c"
        walk_length: the walk length
        walk_times: number of walks started from each start node
Return:
a list of metapath walks.
"""
edge_types = metapath.split('-')
walk = []
cur_nodes = []
    # start_nodes holds one batch of nodes (batch_node_size=200 in this config)
neighbors = graph.sample_successor(
np.array(
start_nodes, dtype="uint64"),
max_degree=walk_times,
edge_type=edge_types[0])
    # append each (start node, sampled successor) pair to walk; with 200 start
    # nodes and walk_times=20, walk holds up to 200*20 = 4000 walks
for neigh, walk_id in zip(neighbors, start_nodes):
for node_id in neigh:
walk.append([walk_id, node_id])
cur_nodes.append(node_id)
if len(walk) == 0:
return walk
cur_walk_ids = np.arange(0, len(walk))
cur_nodes = np.array(cur_nodes, dtype="uint64")
mp_len = len(edge_types)
for i in range(1, walk_length - 1):
cur_succs = graph.sample_successor(
cur_nodes, max_degree=1, edge_type=edge_types[i % mp_len])
mask = np.array([len(succ) > 0 for succ in cur_succs], dtype="bool")
# mask: array([ True, True, True, ..., True, True, True])
        # np.any() is a logical OR: it is True if any element of mask is True,
        # so the walk batch only terminates when *no* walk can be extended
if np.any(mask):
            # keep only the walks whose current node still has successors
cur_walk_ids = cur_walk_ids[mask]
cur_nodes = cur_nodes[mask]
cur_succs = np.array(cur_succs, dtype="object")[mask]
else:
# stop when all nodes have no successor
break
        # each element of walk, e.g. walk[0], is one complete metapath walk
nxt_cur_nodes = []
for s, walk_id in zip(cur_succs, cur_walk_ids):
walk[walk_id].append(s[0])
nxt_cur_nodes.append(s[0])
cur_nodes = np.array(nxt_cur_nodes, dtype="uint64")
return walk
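The key trick above is compacting the still-active walks with a boolean mask at every step. In isolation, with toy numbers:

import numpy as np

cur_walk_ids = np.arange(5)
cur_nodes = np.array([10, 11, 12, 13, 14], dtype="uint64")
cur_succs = [[7], [], [9], [], [3]]  # walks 1 and 3 found no successor

mask = np.array([len(s) > 0 for s in cur_succs], dtype="bool")
cur_walk_ids = cur_walk_ids[mask]                      # -> [0 2 4]
cur_nodes = cur_nodes[mask]                            # -> [10 12 14]
cur_succs = np.array(cur_succs, dtype="object")[mask]  # -> [[7] [9] [3]]
print(cur_walk_ids, cur_nodes, list(cur_succs))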
# model.py
import paddle
import paddle.nn as nn

class SkipGramModel(nn.Layer):
#...
def forward(self, feed_dict):
src_embed = self.embedding(feed_dict['src'])
pos_embed = self.embedding(feed_dict['pos'])
        # in-batch negative sampling: negatives are drawn by reusing the
        # positive embeddings of other pairs in the same batch
batch_size = feed_dict['pos'].shape[0]
neg_idx = paddle.randint(
low=0, high=batch_size, shape=[batch_size, self.neg_num])
negs = []
for i in range(self.neg_num):
tmp = paddle.gather(pos_embed, neg_idx[:, i])
tmp = paddle.reshape(tmp, [-1, 1, self.embed_size])
negs.append(tmp)
neg_embed = paddle.concat(negs, axis=1)
src_embed = paddle.reshape(src_embed, [-1, 1, self.embed_size])
pos_embed = paddle.reshape(pos_embed, [-1, 1, self.embed_size])
# [batch_size, 1, 1]
pos_logits = paddle.matmul(src_embed, pos_embed, transpose_y=True)
# [batch_size, 1, neg_num]
neg_logits = paddle.matmul(src_embed, neg_embed, transpose_y=True)
ones_label = paddle.ones_like(pos_logits)
pos_loss = self.loss_fn(pos_logits, ones_label)
zeros_label = paddle.zeros_like(neg_logits)
neg_loss = self.loss_fn(neg_logits, zeros_label)
loss = (pos_loss + neg_loss) / 2
return loss
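The #... above elides the constructor. Below is a minimal self-contained version under my own assumptions: the embedding table is a plain nn.Embedding and loss_fn is BCEWithLogitsLoss, which matches the 0/1 labels used in the forward pass.

import paddle
import paddle.nn as nn

class MiniSkipGram(nn.Layer):
    """Minimal sketch of SkipGramModel; the __init__ here is an assumption."""

    def __init__(self, num_nodes, embed_size=64, neg_num=5):
        super().__init__()
        self.neg_num = neg_num
        self.embed_size = embed_size
        self.embedding = nn.Embedding(num_nodes, embed_size)
        self.loss_fn = nn.BCEWithLogitsLoss()  # assumed loss

    def forward(self, feed_dict):
        src_embed = self.embedding(feed_dict['src'])  # [B, 1, D]
        pos_embed = self.embedding(feed_dict['pos'])  # [B, 1, D]
        batch_size = feed_dict['pos'].shape[0]
        # in-batch negatives: reuse other pairs' positive embeddings
        neg_idx = paddle.randint(0, batch_size, [batch_size, self.neg_num])
        neg_embed = paddle.gather(
            pos_embed.reshape([-1, self.embed_size]),
            neg_idx.reshape([-1])).reshape(
                [batch_size, self.neg_num, self.embed_size])
        pos_logits = paddle.matmul(src_embed, pos_embed, transpose_y=True)
        neg_logits = paddle.matmul(src_embed, neg_embed, transpose_y=True)
        pos_loss = self.loss_fn(pos_logits, paddle.ones_like(pos_logits))
        neg_loss = self.loss_fn(neg_logits, paddle.zeros_like(neg_logits))
        return (pos_loss + neg_loss) / 2

# smoke test with random ids
model = MiniSkipGram(num_nodes=100)
batch = {'src': paddle.randint(0, 100, [8, 1]),
         'pos': paddle.randint(0, 100, [8, 1])}
print(model(batch))  # scalar loss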