label propagation算法介绍
标签传播算法(label propagation)的核心思想非常简单:相似的数据应该具有相同的label。LP算法包括两大步骤:1)构造关系网;2)标签传播。
算法具体步骤如下:
1、初始时,给每个节点一个唯一的标签;
2、每个节点使用其邻居节点的标签中最多的标签来更新自身的标签。
3、反复执行步骤2,直到每个节点的标签都不再发生变化为止。一次迭代过程中一个节点标签的更新可以分为同步和异步两种。所谓同步更新,即节点z在第t次迭代的label依据于它的邻居节点在第t-1次迭代时所得的label;异步更新,即节点z在第t次迭代的label依据于第t次迭代已经更新过label的节点和第t次迭代未更新过label的节点在第t-1次迭代时的label。
graphX自带LP算法的缺陷
1、边权重信息不参与计算过程;
2、标签传播结果存在震荡的问题(震荡问题是所有基于BSP模式的框架普遍存在的问题)
关于graphx及BSP可见我另一篇文章 https://www.jianshu.com/p/7190123ad329
边权重与无向图支持的改造
- 基于pregel接口,重新实现了一套传播sendMessage和mergeMessage方法
def sendMessage(e: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, Map[VertexId, Long])] = {
Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
}
def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long])
: Map[VertexId, Long] = {
(count1.keySet ++ count2.keySet).map { i =>
val count1Val = count1.getOrElse(i, 0L)
val count2Val = count2.getOrElse(i, 0L)
i -> (count1Val + count2Val)
}(collection.breakOut)
}
标签传播震荡问题改造
1、初始化每个节点属性信息,先给每个节点分配不重复标签。如,节点1对应标签1,节点i对应标签i;
2、N个节点,同步找到对应节点邻居,获取此节点邻居标签,找到出现权重最高的标签,若权重最高的标签不止一个,则选择标签值较大的标签赋值给当前节点;
3、若本轮标签重标记后,节点标签不再变化(或者达到设定的最大迭代次数),则迭代停止,否则重复第2步。迭代结果即为RS0;
4、当第3步结束后,以其结果RS0作为节点初始化信息,重新初始化每个节点属性信息,并从第2步开始,再分别迭代1轮、2轮、3轮,结果分别存为 RS1、RS2和RS3;
5、综合RS0、RS1、RS2和RS3的结果,得到最终每个节点的标签结果。如,节点i在RS0、RS1、RS2和RS3中的标签信息分别为(a、b、c、d),选择其中计数最多的标签作为节点i的最终结果,若计数最多的标签不止一个,则选择标签值最大的标签作为节点i最终的标签。
6、至此,label propagation算法结束,每个节点获得的标签即为其最终归属的cluster的id,聚类结束。
效果对比(demo数据)
graphx自带label propagation
-
demo数据展示(边权重表示点之间的亲密度)
- 期望的聚类结果
- graphx自带LPA聚类结果(共分成4个cluster,不同颜色标注)
-
改进算法的聚类结果
效果对比(通过wifi连接获取的关系数据)
1、外卖标签,数据集中该标签占比0.3965。数据集共23137人。训练集16195人,其中带标签6451人;测试集6942人。其中带标签2723人。
a、graphx自带lp:召回率0.0823,精确率0.5450
b、pregel实现改进版lp:召回率0.2281,精确率0.4909
2、学前教育,数据集中该标签占比0.0281。数据集共23137人。训练集16195人,其中带标签462人;测试集6942人。其中带标签188人。
a、graphx自带lp:召回率0.0,精确率0.0
b、pregel实现改进版lp:召回率0.0426,精确率0.0952
3、炒股,数据集中该标签占比0.2192。数据集共23137人。训练集16195人,其中带标签3499人;测试集6942人。其中带标签1572人。
a、graphx自带lp:召回率0.0204,精确率0.3721
b、pregel实现改进版lp:召回率0.1501,精确率0.3940
4、游戏付费意愿用户,数据集中该标签占比0.1312。数据集共23137人。训练集16195人,其中带标签2137人;测试集6942人。其中带标签898人。
a、graphx自带lp:召回率0.0267,精确率0.2857
b、pregel实现改进版lp:召回率0.1292,精确率0.2736
5、35岁+标签,数据集中该标签占比0.3227。数据集共23137人。训练集16195人,其中带标签5204人;测试集6942人。其中带标签2262人。
a、graphx自带lp:召回率0.0469,精确率0.4953
b、pregel实现改进版lp:召回率0.2604,精确率0.5285
完整代码如下(scala)
package Graph.LPA
import org.apache.spark.graphx._
import org.apache.spark._
import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.LongType
object LPARevolution {
def sendMessage(e: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, Map[VertexId, Long])] = {
Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
}
def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long])
: Map[VertexId, Long] = {
(count1.keySet ++ count2.keySet).map { i =>
val count1Val = count1.getOrElse(i, 0L)
val count2Val = count2.getOrElse(i, 0L)
i -> (count1Val + count2Val)
}(collection.breakOut)
}
// 更新点属性
def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = {
if(message.isEmpty){
attr
}
else{
// print(vid)
// println(" 接收到的消息: ")
// println(message)
// println("最终选择的是:")
// println(message.maxBy(_._2)._1)
message.maxBy(_._2)._1 // 按照计数排序,然后取第一个
}
}
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
val sc = new SparkContext("yarn","lpa-revolution",conf)
val hql = new HiveContext(sc)
// 获取边数据gid,usertags,wifimac,ssid,geohash,day-int
val edges = hql.sql("select cast(src as bigint), cast(dst as bigint)," +
" cast(weight as int) from yangy.graph_edge_table_3day_zoom_weight_hz").rdd.
map(row => Edge(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long], row(2).asInstanceOf[Int]))
// 获取点数据id_2_label_table
val users = hql.sql("select cast(id as bigint), user_tags from yangy.id_2_label_table_3day_zoom_weight_hz").
rdd.map(row => (row(0).asInstanceOf[Long], row(1)))
// val edges = sc.textFile("file:///home/yangy/data/xh_edge_20190530_8day_1_0.txt").
// map{line =>
// val fields = line.split(" ")
// (Edge(fields(0).toLong, fields(1).toLong, fields(2).toInt))
// }
//
// val users = sc.textFile("file:///home/yangy/data/xh_vertex_with_label_1_0.txt").
// map { line =>
// val fields = line.split(" ")
// (fields(0).toLong, fields(1).toLong)
// }
val graph = Graph(vertices = users, edges = edges)
// 图初始化
val initGraph = graph.mapVertices { case (vid, attr) => vid }
// 初始化msg
val initialMessage = Map[VertexId, Long]()
println("迭代结果:")
// 分水岭,开始解决社区震荡&孤立点问题
// ---------------------------------- 迭代多轮 -------------------------------------
val cluster1 = Pregel(initGraph, initialMessage, maxIterations = 100, activeDirection = EdgeDirection.Either)(
vprog = vertexProgram,
sendMsg = sendMessage,
mergeMsg = mergeMessage)
// =====================================================================
// 优雅代码的核心部分,基于前面的结果初始化新的图
// 利用前面迭代结果重新初始化图
// 以此结果作为基础,后续在此基础上继续迭代
val users_trans = cluster1.vertices
val graph_trans = Graph(vertices = users_trans, edges = edges)
val initGraph_trans = graph_trans.mapVertices { case (vid, attr) => attr}
// ======================================================================
// 在基础数据上,额外迭代的轮数
val cluster2 = Pregel(initGraph_trans, initialMessage, maxIterations = 1, activeDirection = EdgeDirection.Either)(
vprog = vertexProgram,
sendMsg = sendMessage,
mergeMsg = mergeMessage)
val cluster3 = Pregel(initGraph_trans, initialMessage, maxIterations = 2, activeDirection = EdgeDirection.Either)(
vprog = vertexProgram,
sendMsg = sendMessage,
mergeMsg = mergeMessage)
val cluster4 = Pregel(initGraph_trans, initialMessage, maxIterations = 3, activeDirection = EdgeDirection.Either)(
vprog = vertexProgram,
sendMsg = sendMessage,
mergeMsg = mergeMessage)
// 构建label propagation结果dataframe
val colNames = "id,group_id"
val schema = StructType(colNames.split(",").map(column => StructField(column, LongType)))
// 获取每个id的分组信息,字段名是id, group_id
val groupDf1 = hql.createDataFrame(cluster1.vertices.map(x=> Row(x._1, x._2)), schema)
val groupDf2 = hql.createDataFrame(cluster2.vertices.map(x=> Row(x._1, x._2)), schema)
val groupDf3 = hql.createDataFrame(cluster3.vertices.map(x=> Row(x._1, x._2)), schema)
val groupDf4 = hql.createDataFrame(cluster4.vertices.map(x=> Row(x._1, x._2)), schema)
val group_union_df = groupDf1.unionAll(groupDf2).unionAll(groupDf3).unionAll(groupDf4)
// 选取合适的group_id,避免社区震荡
group_union_df.registerTempTable("group_union_table")
// 获取不震荡的group归属信息
val group_no_swing = hql.sql(
"""
|select t2.id as id,
| t2.group_id as group_id
|from
|(
| select t1.id as id,
| t1.group_id as group_id,
| rank() over (partition by t1.id order by t1.cnt, t1.group_id desc) as rank
| from
| (
| select id,
| group_id,
| count(1) as cnt
| from group_union_table
| group by id,
| group_id
| ) t1
|) t2
|where t2.rank = 1
|
""".stripMargin)
group_no_swing.write.mode("overwrite").
saveAsTable("yangy.graphx_cluster_zoom_no_swing_hz_100_table")
group_no_swing.show(20)
sc.stop()
}
}
** 原创内容,若要转载请联系本人 **