def mask_test_edges(adj):
# Function to build test set with 10% positive links
# NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper.
# TODO: Clean up.
# Remove diagonal elements
adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj.eliminate_zeros()
# Check that diag is zero: 检查对角线元素是否为0
assert np.diag(adj.todense()).sum() == 0
adj_triu = sp.triu(adj) # triu 取出稀疏矩阵的上三角部分的非零元素 元素值及其坐标
adj_tuple = sparse_to_tuple(adj_triu) # 将上三角转为tuple 返回三元组形式
edges = adj_tuple[0] # 返回坐标值 coords
# train_edges = adj_tuple[0]
edges_all = sparse_to_tuple(adj)[0] # 这是从整个邻接矩阵得到的边的坐标
# num_test = int(np.floor(edges.shape[0] / 10.)) # 10%数量的边作为测试集
num_val = int(np.floor(edges.shape[0] / 10.)) # 5%数量的边作为验证集
all_edge_idx = list(range(edges.shape[0])) # edges应该是一个两位数组 每一行是一个坐标 列数就是所有边的总个数
np.random.shuffle(all_edge_idx) # 哇塞 通过打乱索引 来进行shuffle 而不是直接shuffle原数据
val_edge_idx = all_edge_idx[:num_val] # 验证集边的索引
# test_edge_idx = all_edge_idx[num_val:(num_val + num_test)] # 测试集边的索引
# test_edges = edges[test_edge_idx] # 通过索引指定对应的测试集的边
val_edges = edges[val_edge_idx] # 通过索引指定对应的验证集的边
train_edges = np.delete(edges, np.hstack([val_edge_idx]), axis=0) # 把test和val删掉就是训练集的边
### !!! 注意 因为adj确认了没有0 所以所有的test val 和train edge都是正例!
def ismember(a, b, tol=5):
rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
return np.any(rows_close)
# np.all 测试沿给定轴的所有数组元素是否都计算为True
# 这里axis=-1 就是沿着纵向找,如果这一行的元素都不为0,则返回True,否则返回False
# np.any 测试沿给定轴的所有数组元素是否有计算为True
# 这里axis=-1 就是沿着纵向找,如果这一行的元素有一个不为0,则返回True,否则返回False
# 这个函数的作用是 如果坐标a是坐标集合b的其中一个,则返回True 也就是is member
test_edges_false = []
# 如果不满足while后面的条件 则跳出循环
# 这个应该是生成和正样本数量相等的负样本(也就是没有连接的边)
# continue是跳过本次循环 但是循环继续 循环没有终止
# break是循环终止
while len(test_edges_false) < len(test_edges):
idx_i = np.random.randint(0, adj.shape[0])
idx_j = np.random.randint(0, adj.shape[0])
if idx_i == idx_j:
continue
if ismember([idx_i, idx_j], edges_all):
# 这个循环是为了排除掉所有正样本 包括上三角和下三角 因为我们最终是要得到负样本的
continue
if test_edges_false: # 这个判别是 只要test_edges_false这个数组非空 就进入到下面这个循环
if ismember([idx_j, idx_i], np.array(test_edges_false)):
continue
if ismember([idx_i, idx_j], np.array(test_edges_false)):
continue
test_edges_false.append([idx_i, idx_j])
# 如果上述的所有判别条件都没有让这个跳出这次循环 则把这个符合规则的样本加入到负样本集里面
i = 0
val_edges_false = []
while len(val_edges_false) < len(val_edges):
idx_i = np.random.randint(0, adj.shape[0])
idx_j = np.random.randint(0, adj.shape[0])
if idx_i == idx_j:
continue
if ismember([idx_i, idx_j], train_edges):
continue
if ismember([idx_j, idx_i], train_edges):
continue
if ismember([idx_i, idx_j], val_edges):
continue
if ismember([idx_j, idx_i], val_edges):
continue
if val_edges_false:
if ismember([idx_j, idx_i], np.array(val_edges_false)):
continue
if ismember([idx_i, idx_j], np.array(val_edges_false)):
continue
val_edges_false.append([idx_i, idx_j])
print(i)
i +=1
# 如果上述的所有判别条件都没有让这个跳出这次循环 则把这个符合规则的样本加入到负样本集里面
# ~是取反的意思
# 以下五句话分别是确认:
# 为测试集、验证集生成的负样本的边坐标不在所有正样本边集合里面
# 验证集正样本和测试集正样本都不在训练集里面
# 验证集测试集正样本木有重叠
# 但是用上下面的五句话往往会造成内存爆炸
# assert ~ismember(test_edges_false, edges_all)
# assert ~ismember(val_edges_false, edges_all)
# assert ~ismember(val_edges, train_edges)
# assert ~ismember(test_edges, train_edges)
# assert ~ismember(val_edges, test_edges)
data = np.ones(train_edges.shape[0])
# data为训练集正样本个数
# Re-build adj matrix
# 这个很好理解 就是根据之前切好的训练集正样本 重构邻接矩阵 只有训练集样本对应的位置为1
adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
adj_train = adj_train + adj_train.T
# 第一行的adj_train是上三角矩阵 加上转置之后的(下三角矩阵)变成完整的重构邻接矩阵
# NOTE: these edge lists only contain single direction of edge!
return adj_train, train_edges, val_edges, val_edges_false
# return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false
make_test_edge(将图转为链路预测二分类问题)
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
相关阅读更多精彩内容
- 【蝴蝶效应】 蝴蝶效应:上个世纪70年代,美国一个名叫洛伦兹的气象学家在解释空气系统理论时说,亚马逊雨林一只蝴蝶...
- ☆☆☆☆☆逻辑回归 (LR)(分类/解决二分类问题) 一. sigmoid函数(逻辑回归函数) 1.t 就是线性回...
- 【火炉炼AI】机器学习008-简单线性分类器解决二分类问题 (本文所使用的Python库和版本号: Python ...
- PaddlePaddle 飞桨 FAQ合集 - 训练问题20 Question: 预测时,当先后加载检测和分类两个...