网上graphx实现最短路径的代码比较多,但是都是scala版本,java版本的实现很少。
1.创建图数据
使用的方法是Graph.apply(),下面看一下scala的该方法的定义:
def apply[VD, ED](vertices : org.apache.spark.rdd.RDD[scala.Tuple2[org.apache.spark.graphx.VertexId, VD]], edges : org.apache.spark.rdd.RDD[org.apache.spark.graphx.Edge[ED]], defaultVertexAttr : VD = { /* compiled code */ }, edgeStorageLevel : org.apache.spark.storage.StorageLevel = { /* compiled code */ }, vertexStorageLevel : org.apache.spark.storage.StorageLevel = { /* compiled code */ })(implicit evidence$18 : scala.reflect.ClassTag[VD], evidence$19 : scala.reflect.ClassTag[ED]) : org.apache.spark.graphx.Graph[VD, ED] = { /* compiled code */ }
其中VD是顶点的属性的类型(可以是list或者Tuple),ED是边的属性的类型(可以是list或者Tuple)。
vertices是顶点的rdd,其中rdd中的元素结构是Tuple2<点id,VD>
edges是边的rdd,其中rdd中的元素结构是Tuple2<边id,VD>
defaultVertexAttr是点属性的默认值,假设创建一条边,1->2,但是id为2的点我没有创建,只创建了1的点,那么这时候就会自动生产一个id为2的点,点的属性就是这个默认值。
edgeStorageLevel点的存储等级
vertexStorageLevel边的存储等级
点属性的classTag
边属性的classTag
其中classTag,包含实际运行时的类的类型。
创建的图数据如下:
代码:
private static Graph<Tuple2<String, Integer>, Tuple2<Integer, Integer>> createGraph(
JavaSparkContext javaSparkContext) {
// 必须是Tuple2<Object, Tuple2<String, Integer>>,不能是Tuple2<Long, Tuple2<String, Integer>>
List<Tuple2<Object, Tuple2<String, Integer>>> vertexList = new ArrayList<>();
vertexList.add(new Tuple2<>(1L, new Tuple2<>("mar_1", 21)));
vertexList.add(new Tuple2<>(2L, new Tuple2<>("mar_2", 22)));
vertexList.add(new Tuple2<>(3L, new Tuple2<>("mar_3", 23)));
vertexList.add(new Tuple2<>(4L, new Tuple2<>("mar_4", 24)));
vertexList.add(new Tuple2<>(5L, new Tuple2<>("mar_5", 25)));
vertexList.add(new Tuple2<>(6L, new Tuple2<>("mar_6", 26)));
vertexList.add(new Tuple2<>(7L, new Tuple2<>("mar_7", 27)));
vertexList.add(new Tuple2<>(8L, new Tuple2<>("mar_8", 28)));
vertexList.add(new Tuple2<>(9L, new Tuple2<>("mar_9", 29)));
JavaRDD<Tuple2<Object, Tuple2<String, Integer>>> vertexRdd = javaSparkContext
.parallelize(vertexList);
List<Edge<Tuple2<Integer, Integer>>> edgeList = new ArrayList<>();
edgeList.add(new Edge<>(1, 2, new Tuple2<>(1, 1)));
edgeList.add(new Edge<>(2, 3, new Tuple2<>(2, 2)));
edgeList.add(new Edge<>(1, 4, new Tuple2<>(3, 3)));
edgeList.add(new Edge<>(1, 5, new Tuple2<>(4, 4)));
edgeList.add(new Edge<>(1, 6, new Tuple2<>(5, 5)));
edgeList.add(new Edge<>(4, 7, new Tuple2<>(6, 6)));
edgeList.add(new Edge<>(7, 8, new Tuple2<>(7, 7)));
edgeList.add(new Edge<>(5, 8, new Tuple2<>(8, 8)));
edgeList.add(new Edge<>(8, 9, new Tuple2<>(9, 9)));
edgeList.add(new Edge<>(6, 9, new Tuple2<>(10, 10)));
edgeList.add(new Edge<>(3, 9, new Tuple2<>(11, 11)));
JavaRDD<Edge<Tuple2<Integer, Integer>>> edgeRdd = javaSparkContext.parallelize(edgeList);
Tuple2<String, Integer> defaultVertex = new Tuple2<>("default", -1);
// ClassTag$.MODULE$.apply(Tuple2.class)所有用到的都改为ClassTag$.MODULE$.apply(Object.class)否则报错
Graph<Tuple2<String, Integer>, Tuple2<Integer, Integer>> graph = Graph
.apply(vertexRdd.rdd(), edgeRdd.rdd(), defaultVertex, StorageLevels.MEMORY_ONLY,
StorageLevels.MEMORY_ONLY, ClassTag$.MODULE$.apply(Object.class),
ClassTag$.MODULE$.apply(Object.class));
graph.vertices().toJavaRDD()
.foreach(x -> System.out.println("vertex id:: " + x._1 + " , attr:: " + x._2));
graph.edges().toJavaRDD().foreach(
x -> System.out.println(
"edge id:: " + x.attr._1 + " , src:: " + x.srcId() + " , dest:: " + x
.dstId() + " weigh:: " + x.attr._2));
return graph;
}
注意其中,点的id要用Object类型的否则会报错;还有classTag也用Object.class的否则也报错,目前不知道什么原因。
2.最短路径
这里用的是Pregel,这里简单介绍一下详细了解可以自行百度,Pregel框架是有谷歌提出,图并行技术框架,以顶点为中心不断的进行算法的迭代和数据同步。
Pregel的迭代过程如下:
- 最开始,图中的所有顶点都会收到一个默认的消息,这个默认值就是方法的第一个参数。
- 各个顶点收到消息后调用vprog函数,生产新的消息。
- 调用sendMsg函数发送消息给下一轮迭代的顶点,这个函数将决定将消息发送给谁。
- 接收到消息的顶点,调用vprog函数(如果收到多个消息,先调用mergeMsg),生产新的消息。其中最开始时是所有顶点都能收到消息的。
- 如果没有顶点收到消息,或者到底迭代次数maxIterations退出计算,完成。
下面看下用到的方法的几个参数:
def pregel[A](initialMsg : A, maxIterations : scala.Int = { /* compiled code */ }, activeDirection : org.apache.spark.graphx.EdgeDirection = { /* compiled code */ })(vprog : scala.Function3[org.apache.spark.graphx.VertexId, VD, A, VD], sendMsg : scala.Function1[org.apache.spark.graphx.EdgeTriplet[VD, ED], scala.Iterator[scala.Tuple2[org.apache.spark.graphx.VertexId, A]]], mergeMsg : scala.Function2[A, A, A])(implicit evidence$6 : scala.reflect.ClassTag[A]) : org.apache.spark.graphx.Graph[VD, ED] = { /* compiled code */ }
initialMsg:
第一轮迭代计算时,所有顶点收到的消息。(A表示消息类型)
maxIterations:
最大迭代次数(整型)
activeDirection:
沿着边迭代的方向。
vprog : scala.Function3[org.apache.spark.graphx.VertexId, VD, A, VD]:
在步骤2中调用的函数,接收消息,然后生产顶点的新的属性。
可以看到这个函数是scala.Function3类型的,3个入参,1个返回值。
第一个入参是点id,第二个参数是点的原来的属性,第三个参数是接收到的消息,返回值是点的新的属性。
scala.Function1[org.apache.spark.graphx.EdgeTriplet[VD, ED], scala.Iterator[scala.Tuple2[org.apache.spark.graphx.VertexId, A]]]:
发送消息的函数,1个入参,1个返回值。
入参是EdgeTriplet类型,这是保存边的信息的一个类,这个类包括源点的属性、目的点的属性、源点的id、目的点的id、以及边的属性。例如,如果a点收到消息,那么这个入参就是以a为源的边,a->b。返回值是发送消息到的点的id(例如b的id),和发送的消息。
mergeMsg : scala.Function2[A, A, A]:
合并函数,每个点可能收到多个消息,需要对消息进行合并。合并后在将消息作为入参调用vprog。
求最短路径(指定起始点)的思路:
- 先对图的点进行加工,即mapVertices,如果是起点,那么该点的属性为0,否则为整型最大值(这个值要大于图中的最长路径),这个属性的意思就是路径的长度,起始点路径长度是0,其他的点路径长度是一个很大的值。
2.然后就开始用Pregel进行迭代,第一次所有点都收到消息(代码里面设置的是整型最大值),收到消息后点的属性和消息取最小值,结果作为点的新的属性。所以起始点的属性为0,其他的为整型最大值。
3.发送消息给下一轮迭代的顶点,函数中判断srcAttr + 1 < dstAttr,如果满足这发送,不满足则不发送,这样只有和起始顶点直接连接的顶点才能收到消息。发送的消息是srcAttr + 1,这正好表示路径的长度。
4.顶点收到消息后,进行merge操作,取最小的,这就模拟了最短路径,例如到该点有两条路径,一条长度为3,一条为长度为2,那么把2作为该点的新的属性。
5.满足迭代条件后结束计算,最后生产的点,包含id和一个属性,这个属性就是起始点到该点的路径长度。
整体代码:
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.graphx.Edge;
import org.apache.spark.graphx.EdgeDirection;
import org.apache.spark.graphx.EdgeTriplet;
import org.apache.spark.graphx.Graph;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.AbstractFunction2;
import scala.runtime.AbstractFunction3;
public class ShortestPath {
public static void main(String[] args) {
SparkConf sparkConf = new SparkConf().setAppName("shortest_path").setMaster("local[2]");
JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
Graph<Tuple2<String, Integer>, Tuple2<Integer, Integer>> graph = createGraph(
javaSparkContext);
Object obj = Predef1.reflexivity();
scala.Predef.$eq$colon$eq<Tuple2<String, Integer>, Long> ev = (scala.Predef.$eq$colon$eq<Tuple2<String, Integer>, Long>) obj;
Graph<Long, Tuple2<Integer, Integer>> initGraph = graph
.mapVertices(new MapVerticesFunction(), ClassTag$.MODULE$.apply(Object.class), ev);
// initGraph.vertices().toJavaRDD().foreach(x -> System.out.println(x));
Graph<Long, Tuple2<Integer, Integer>> sssp = initGraph.ops()
.pregel((long)Integer.MAX_VALUE, 1, EdgeDirection.Out(), new VertexProgram(), new SendMsgFunction(),
new MergeMsgFunction(), ClassTag$.MODULE$.apply(Object.class));
sssp.vertices().toJavaRDD().foreach(x -> System.out.println(x));
}
static class VertexProgram extends AbstractFunction3<Object, Long, Long, Long> implements
Serializable {
@Override
public Long apply(Object id, Long vertexAttr, Long newVertexAttr) {
Long min = Math.min(vertexAttr, newVertexAttr);
System.out.println("id:: "+id+" :: "+vertexAttr+" :: "+newVertexAttr+" min "+min);
return min;
}
}
static class SendMsgFunction extends
AbstractFunction1<EdgeTriplet<Long, Tuple2<Integer, Integer>>, Iterator<Tuple2<Object, Long>>> implements
Serializable {
@Override
public Iterator<Tuple2<Object, Long>> apply(
EdgeTriplet<Long, Tuple2<Integer, Integer>> triplet) {
long srcAttr = triplet.srcAttr();
long dstAttr = triplet.dstAttr();
// System.out.println(srcAttr+" "+attr+" "+dstAttr);
System.out.println("srcid: "+triplet.srcId()+" destid: "+triplet.dstId()+" srca: "+srcAttr+" desa: "+dstAttr);
if (srcAttr + 1 < dstAttr) {
return JavaConversions.asScalaIterator(
Collections.singletonList(
new Tuple2<Object, Long>(triplet.dstId(), srcAttr + 1))
.iterator());
} else {
return JavaConversions.asScalaIterator(Collections.emptyIterator());
}
}
}
static class MergeMsgFunction extends AbstractFunction2<Long, Long, Long>
implements Serializable {
@Override
public Long apply(Long a, Long b) {
// System.out.println("marge ->" + Math.min((long) a, (long) b) + " -> " + a + " " + b);
return Math.min(a, b);
}
}
public static class Predef1 {
static public <T> scala.Predef.$eq$colon$eq<T, T> reflexivity() {
return scala.Predef.$eq$colon$eq$.MODULE$.tpEquals();
}
}
static class MapVerticesFunction extends
AbstractFunction2<Object, Tuple2<String, Integer>, Long> implements Serializable {
@Override
public Long apply(Object v1, Tuple2<String, Integer> v2) {
if ((long) v1 == 1) {
return 0L;
} else {
return (long) Integer.MAX_VALUE;
}
}
}
private static Graph<Tuple2<String, Integer>, Tuple2<Integer, Integer>> createGraph(
JavaSparkContext javaSparkContext) {
// 必须是Tuple2<Object, Tuple2<String, Integer>>,不能是Tuple2<Long, Tuple2<String, Integer>>
List<Tuple2<Object, Tuple2<String, Integer>>> vertexList = new ArrayList<>();
vertexList.add(new Tuple2<>(1L, new Tuple2<>("mar_1", 21)));
vertexList.add(new Tuple2<>(2L, new Tuple2<>("mar_2", 22)));
vertexList.add(new Tuple2<>(3L, new Tuple2<>("mar_3", 23)));
vertexList.add(new Tuple2<>(4L, new Tuple2<>("mar_4", 24)));
vertexList.add(new Tuple2<>(5L, new Tuple2<>("mar_5", 25)));
vertexList.add(new Tuple2<>(6L, new Tuple2<>("mar_6", 26)));
vertexList.add(new Tuple2<>(7L, new Tuple2<>("mar_7", 27)));
vertexList.add(new Tuple2<>(8L, new Tuple2<>("mar_8", 28)));
vertexList.add(new Tuple2<>(9L, new Tuple2<>("mar_9", 29)));
JavaRDD<Tuple2<Object, Tuple2<String, Integer>>> vertexRdd = javaSparkContext
.parallelize(vertexList);
List<Edge<Tuple2<Integer, Integer>>> edgeList = new ArrayList<>();
edgeList.add(new Edge<>(1, 2, new Tuple2<>(1, 1)));
edgeList.add(new Edge<>(2, 3, new Tuple2<>(2, 2)));
edgeList.add(new Edge<>(1, 4, new Tuple2<>(3, 3)));
edgeList.add(new Edge<>(1, 5, new Tuple2<>(4, 4)));
edgeList.add(new Edge<>(1, 6, new Tuple2<>(5, 5)));
edgeList.add(new Edge<>(4, 7, new Tuple2<>(6, 6)));
edgeList.add(new Edge<>(7, 8, new Tuple2<>(7, 7)));
edgeList.add(new Edge<>(5, 8, new Tuple2<>(8, 8)));
edgeList.add(new Edge<>(8, 9, new Tuple2<>(9, 9)));
edgeList.add(new Edge<>(6, 9, new Tuple2<>(10, 10)));
edgeList.add(new Edge<>(3, 9, new Tuple2<>(11, 11)));
JavaRDD<Edge<Tuple2<Integer, Integer>>> edgeRdd = javaSparkContext.parallelize(edgeList);
Tuple2<String, Integer> defaultVertex = new Tuple2<>("default", -1);
// ClassTag$.MODULE$.apply(Tuple2.class)所有用到的都改为ClassTag$.MODULE$.apply(Object.class)否则报错
Graph<Tuple2<String, Integer>, Tuple2<Integer, Integer>> graph = Graph
.apply(vertexRdd.rdd(), edgeRdd.rdd(), defaultVertex, StorageLevels.MEMORY_ONLY,
StorageLevels.MEMORY_ONLY, ClassTag$.MODULE$.apply(Object.class),
ClassTag$.MODULE$.apply(Object.class));
graph.vertices().toJavaRDD()
.foreach(x -> System.out.println("vertex id:: " + x._1 + " , attr:: " + x._2));
graph.edges().toJavaRDD().foreach(
x -> System.out.println(
"edge id:: " + x.attr._1 + " , src:: " + x.srcId() + " , dest:: " + x
.dstId() + " weigh:: " + x.attr._2));
return graph;
}
}
看下输出结果:
(1,0)
(3,2147483647)
(7,2147483647)
(4,1)
(9,2147483647)
(5,1)
(6,1)
(8,2147483647)
(2,1)
上面的迭代次数只设置为1次,就是以起始点开始向外第一层(可以自行修改),对比图看结果,和1点直接连接的是2、4、5、6,结果中对应的长度为1,其他点为Integer.MAX_VALUE,如果想取固定的目的点,加过过滤即可。
注意:
由于java和scala的兼容性问题,如果你的idea爆红,不用理会,不影响编译和运行。
maven依赖:
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.3.0</version>
</dependency>
版权声明
1.以上文章为本人原创,首发于简书网,文责自负。
2.未经作者同意不得转载,如需转载请联系作者。感谢。