数据算法 Hadoop/Spark大数据处理---第四章

本章主要介绍左外连接(LEFT JOIN)


  1. 用传统的mapreduce()方法
  2. 用传统的spark方法
  3. 用spark的leftOuterJoin()方法
  4. 用传统的Scala实现
  5. 用Scala的leftOuterJoin()方法
  6. 用Scala高效的DataFrame实现


  • 输入user表(user_id,location_id)
  • 输入transaction表(transaction_id,product_id,user_id,quantity,amount)
  • 输出为(product_id,{distinct<location_id> as L, L.size})


  • SELECT product_id,location_id FROM transactions LEFT OUTER JOIN users ON transactions.user_ID = users.user_ID
  • SELECT product_id,count(distinct location_id) FROM transactions LEFT OUTER JOIN users ON transactions.user_ID = users.user_ID group by product_id


类名 类描述
LeftJoinDriver 提交阶段1作业的驱动器
LeftJoinReducer 左连接归约器
LeftJoinTransactionMapper 左连接交易映射器
LeftJoinUserMapper 左连接用户映射器
SecondarySortPartitioner 对自然键分区
SecondarySortGroupComparator 对自然键分组
LocationCountDriver 提交阶段2作业的驱动器
LocationCountMapper 定义map()完成地址统计
LocationCountReducer 定义reduce()完成地址统计
 public void map(LongWritable key, Text value, Context context) 
      throws java.io.IOException, InterruptedException {
      String[] tokens = StringUtils.split(value.toString(), "\t");
      if (tokens.length == 2) {
         // tokens[0] = user_id
         // tokens[1] = location_id
         // to make sure location arrives before products
         outputKey.set(tokens[0], "1");    // set user_id
         outputValue.set("L", tokens[1]);  // set location_id
         context.write(outputKey, outputValue);
 public void map(LongWritable key, Text value, Context context) 
      throws java.io.IOException, InterruptedException {
      String[] tokens = StringUtils.split(value.toString(), "\t");
      String productID = tokens[1];
      String userID = tokens[2];
      // make sure products arrive at a reducer after location
      outputKey.set(userID, "2");
      outputValue.set("P", productID);
      context.write(outputKey, outputValue);

    public int compare(PairOfStrings first, PairOfStrings second) {
       return first.getLeftElement().compareTo(second.getLeftElement());
    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2 ) {
        DataInputBuffer buffer = new DataInputBuffer();
        PairOfStrings a = new PairOfStrings();
        PairOfStrings b = new PairOfStrings();
        try {
            buffer.reset(b1, s1, l1);
            buffer.reset(b2, s2, l2);
            return compare(a,b);  
        catch(Exception ex) {
            return -1;


   public void reduce(PairOfStrings key, Iterable<PairOfStrings> values, Context context) 
      throws java.io.IOException, InterruptedException {
      Iterator<PairOfStrings> iterator = values.iterator();
      if (iterator.hasNext()) {
         // firstPair must be location pair
         PairOfStrings firstPair = iterator.next(); 
         if (firstPair.getLeftElement().equals("L")) {
      while (iterator.hasNext()) {
         // the remaining elements must be product pair
         PairOfStrings productPair = iterator.next(); 
         context.write(productID, locationID);
    public void reduce(Text productID, Iterable<Text> locations, Context context)
        throws  IOException, InterruptedException {
        Set<String> set = new HashSet<String>();
        for (Text location: locations) {
        context.write(productID, new LongWritable(set.size()));


      // Register both input paths on the same job, each bound to its own mapper
      // class: transactions go to LeftJoinTransactionMapper, users to LeftJoinUserMapper.
      MultipleInputs.addInputPath(job, transactions, TextInputFormat.class, LeftJoinTransactionMapper.class);
      MultipleInputs.addInputPath(job, users, TextInputFormat.class, LeftJoinUserMapper.class);



public static void main(String[] args) throws Exception {
    if (args.length < 2) {
       System.err.println("Usage: SparkLeftOuterJoin <users> <transactions>");
    String usersInputFile = args[0];
    String transactionsInputFile = args[1];
    System.out.println("users="+ usersInputFile);
    System.out.println("transactions="+ transactionsInputFile);
    JavaSparkContext ctx = new JavaSparkContext();
    JavaRDD<String> users = ctx.textFile(usersInputFile, 1);

    // 从一个RDD转换成一个RDD类似(user_id,("L",location))
    // PairFunction<T, K, V>    
    // T => Tuple2<K, V>
    JavaPairRDD<String,Tuple2<String,String>> usersRDD = 
          users.mapToPair(new PairFunction<
                                           String,                // T 
                                           String,                // K
                                           Tuple2<String,String>  // V
                                          >() {
      public Tuple2<String,Tuple2<String,String>> call(String s) {
        String[] userRecord = s.split("\t");
        Tuple2<String,String> location = new Tuple2<String,String>("L", userRecord[1]);
        return new Tuple2<String,Tuple2<String,String>>(userRecord[0], location);
    JavaRDD<String> transactions = ctx.textFile(transactionsInputFile, 1);

    // PairFunction<T, K, V>    
    // T => Tuple2<K, V>
    JavaPairRDD<String,Tuple2<String,String>> transactionsRDD = 
          transactions.mapToPair(new PairFunction<String, String, Tuple2<String,String>>() {
      public Tuple2<String,Tuple2<String,String>> call(String s) {
        String[] transactionRecord = s.split("\t");
        Tuple2<String,String> product = new Tuple2<String,String>("P", transactionRecord[1]);
        return new Tuple2<String,Tuple2<String,String>>(transactionRecord[2], product);
    // union() 函数合并两个RDD
    JavaPairRDD<String,Tuple2<String,String>> allRDD = transactionsRDD.union(usersRDD);
    // 对userID进行排序
    // 变成专业<userID, List[T2("L", location), T2("P", p1), T2("P", p2), T2("P", p3), 
    JavaPairRDD<String, Iterable<Tuple2<String,String>>> groupedRDD = allRDD.groupByKey(); 
    // PairFlatMapFunction<T, K, V> 
    // T => Iterable<Tuple2<K, V>>
    JavaPairRDD<String,String> productLocationsRDD = 
         //                                               T                                                K       V 
         groupedRDD.flatMapToPair(new PairFlatMapFunction<Tuple2<String, Iterable<Tuple2<String,String>>>, String, String>() {
      public Iterator<Tuple2<String,String>> call(Tuple2<String, Iterable<Tuple2<String,String>>> s) {
        // String userID = s._1;  // NOT Needed
        Iterable<Tuple2<String,String>> pairs = s._2;
        String location = "UNKNOWN";
        List<String> products = new ArrayList<String>();
        for (Tuple2<String,String> t2 : pairs) {
            if (t2._1.equals("L")) {
                location = t2._2;
            else {
                // t2._1.equals("P")
        // now emit (K, V) pairs
        List<Tuple2<String,String>> kvList = new ArrayList<Tuple2<String,String>>();
        for (String product : products) {
            kvList.add(new Tuple2<String, String>(product, location));
        return kvList.iterator();
    // 发射过来的是一个个的{product, location}需根据product分组
    JavaPairRDD<String, Iterable<String>> productByLocations = productLocationsRDD.groupByKey();    
    // debug3
    List<Tuple2<String, Iterable<String>>> debug3 = productByLocations.collect();
    System.out.println("--- debug3 begin ---");
    for (Tuple2<String, Iterable<String>> t2 : debug3) {
      System.out.println("debug3 t2._1="+t2._1);
      System.out.println("debug3 t2._2="+t2._2);
    System.out.println("--- debug3 end ---");
    JavaPairRDD<String, Tuple2<Set<String>, Integer>> productByUniqueLocations = 
          productByLocations.mapValues(new Function< Iterable<String>,                   // input
                                                     Tuple2<Set<String>, Integer>        // output
                                                   >() {
      public Tuple2<Set<String>, Integer> call(Iterable<String> s) {
        Set<String> uniqueLocations = new HashSet<String>();
        for (String location : s) {
        return new Tuple2<Set<String>, Integer>(uniqueLocations, uniqueLocations.size());
     // 打印最终的结果
    System.out.println("=== Unique Locations and Counts ===");
    List<Tuple2<String, Tuple2<Set<String>, Integer>>>  debug4 = productByUniqueLocations.collect();
    System.out.println("--- debug4 begin ---");
    for (Tuple2<String, Tuple2<Set<String>, Integer>> t2 : debug4) {
      System.out.println("debug4 t2._1="+t2._1);
      System.out.println("debug4 t2._2="+t2._2);
    System.out.println("--- debug4 end ---");


  • 对users和transactions使用JavaPairRDD.union()操作,开销太大
  • 需要额外引入定制标识"L"和"P"来区分两类记录


JavaPairRDD<String,Tuple2<String,Optional<String>>> joined = transactionsRDD.leftOuterJoin(usersRDD);

JavaPairRDD<String,String> products = joined.mapToPair(new PairFunctions<Tuple2<String,Tuple2<String,Optional<String>>>>,String,String)(){
    public Tuple2<String,String> call(Tuple2<String,Tuple2<String,Optional<String>>> t){
        Tuple2<String,Optional<String>> list = t._2;
        return new Tuple2<String,String>(list._1,list._2.get())



// createCombiner: turns a V into a C (e.g. creates a one-element collection)
// mergeValue:     merges a V into an existing C
// mergeCombiners: merges two C's into one
public <C> JavaPairRDD<K,C> combineByKey(Function<V,C> createCombiner,Function2<C,V,C> mergeValue,Function2<C,C,C> mergeCombiners)

那么在本实例中,我们的目标为各个键创建一个Set<String>,即从String -> Set<String>

Function<String,Set<String>> createCombiner = new Function<String,Set<String>>{
    public Set<String> call(String x){
        Set<String> set = new HashSet<String>();
        return set;

Function2<Set<String>,String,Set<String>> mergerValue = new Function2<Set<String>,String,Set<String>>{
    public Set<String> call(Set<String> set,String x){
        return set;

Function2<Set<String>,Set<String>,Set<String>> mergerCombiners = new Function2<Set<String>,Set<String>,Set<String>>{
    public Set<String> call(Set<String> a,Set<String> b){
        return a;

// Build one Set<String> of locations per product, then bring the result to the driver as a Map.
JavaPairRDD<String,Set<String>> productUniqueLocations = products.combineByKey(createCombiner,mergerValue,mergerCombiners);
Map<String,Set<String>> productMap = productUniqueLocations.collectAsMap();


/**
 * Left outer join of transactions with users using union + groupByKey,
 * writing (product, number of distinct locations) to the output path.
 *
 * @param args args(0) = users path (user_id TAB location_id),
 *             args(1) = transactions path (product_id in column 1, user_id in column 2),
 *             args(2) = output path
 */
def main(args: Array[String]): Unit = {
    if (args.size < 3) {
      println("Usage: LeftOuterJoin <users-data-path> <transactions-data-path> <output-path>")
      // without this the code below fails on args(0..2)
      sys.exit(1)
    }

    val sparkConf = new SparkConf().setAppName("LeftOuterJoin")
    val sc = new SparkContext(sparkConf)

    val usersInputFile = args(0)
    val transactionsInputFile = args(1)
    val output = args(2)

    val usersRaw = sc.textFile(usersInputFile)
    val transactionsRaw = sc.textFile(transactionsInputFile)
    val users = usersRaw.map { line =>
      val tokens = line.split("\t")
      (tokens(0), ("L", tokens(1))) // Tagging Locations with L
    }

    val transactions = transactionsRaw.map { line =>
      val tokens = line.split("\t")
      (tokens(2), ("P", tokens(1))) // Tagging Products with P
    }

    val all = users union transactions
    val grouped = all.groupByKey()

    val productLocations = grouped.flatMap {
      case (_, tagged) =>
        // groupByKey gives no ordering guarantee, so the "L" record is not
        // necessarily first; partition (unlike span) finds it wherever it is.
        val (locations, products) = tagged.partition(_._1 == "L")
        val loc = locations.headOption.map(_._2).getOrElse("UNKNOWN")
        products.map(p => (p._2, loc)).toSet
    }
    val productByLocations = productLocations.groupByKey()
    // The same location can arrive from several users, so de-duplicate before
    // counting. Return (product, distinct location count) tuple.
    val result = productByLocations.map(t => (t._1, t._2.toSet.size))

    result.saveAsTextFile(output) // Saves output to the file.

    // done
    sc.stop()
}


/**
 * Left outer join of transactions with users using RDD.leftOuterJoin,
 * writing (product, number of distinct locations) to the output path.
 *
 * @param args args(0) = users path (user_id TAB location_id),
 *             args(1) = transactions path (product_id in column 1, user_id in column 2),
 *             args(2) = output path
 */
def main(args: Array[String]): Unit = {
    if (args.size < 3) {
      println("Usage: SparkLeftOuterJoin <users> <transactions> <output>")
      // without this the code below fails on args(0..2)
      sys.exit(1)
    }
    val sparkConf = new SparkConf().setAppName("SparkLeftOuterJoin")
    val sc = new SparkContext(sparkConf)

    val usersInputFile = args(0)
    val transactionsInputFile = args(1)
    val output = args(2)

    val usersRaw = sc.textFile(usersInputFile)
    val transactionsRaw = sc.textFile(transactionsInputFile)

    // (user_id, location_id)
    val users = usersRaw.map { line =>
      val tokens = line.split("\t")
      (tokens(0), tokens(1))
    }

    // (user_id, product_id)
    val transactions = transactionsRaw.map { line =>
      val tokens = line.split("\t")
      (tokens(2), tokens(1))
    }

    val joined = transactions leftOuterJoin users

    // Joined records look like (user_id, (product, Option[location])); only the
    // value part is needed. Unmatched users fall back to "unknown".
    val productLocations = joined.values.map(f => (f._1, f._2.getOrElse("unknown")))
    val productByLocations = productLocations.groupByKey()

    val productWithUniqueLocations = productByLocations.mapValues(_.toSet) // Converting toSet removes duplicates.
    val result = productWithUniqueLocations.map(t => (t._1, t._2.size)) // Return (product, location count) tuple.
    result.saveAsTextFile(output) // Saves output to the file.

    // done
    sc.stop()
}


/**
 * Left outer join of transactions with users using DataFrames, computed
 * twice: once with the DataFrame API (approach 1) and once with a SQL query
 * (approach 2). Both write (productId, distinct location count) under the
 * output path, in Parquet and in text format.
 *
 * Approach 1 uses a "left_outer" join and countDistinct so that it yields
 * the same rows as the SQL of approach 2; products bought only by unknown
 * users are kept with a count of 0.
 *
 * @param args args(0) = users path (userId TAB location),
 *             args(1) = transactions path (transactionId, productId, userId, quantity, price; tab-separated),
 *             args(2) = output path
 */
def main(args: Array[String]): Unit = {
    if (args.size < 3) {
      println("Usage: DataFrameLeftOuterJoin <users-data-path> <transactions-data-path> <output-path>")
      // without this the code below fails on args(0..2)
      sys.exit(1)
    }

    val usersInputFile = args(0)
    val transactionsInputFile = args(1)
    val output = args(2)

    val sparkConf = new SparkConf()

    // Use for Spark 1.6.2 or below
    // val sc = new SparkContext(sparkConf)
    // val spark = new SQLContext(sc)

    // Use below for Spark 2.0.0
    val spark = SparkSession
      .builder()
      .appName("DataFram LeftOuterJoin")
      .config(sparkConf)
      .getOrCreate()

    // Use below for Spark 2.0.0
    val sc = spark.sparkContext

    import spark.implicits._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions.countDistinct

    // Schema of the users data
    val userSchema = StructType(Seq(
      StructField("userId", StringType, false),
      StructField("location", StringType, false)))

    // Schema of the transactions data
    val transactionSchema = StructType(Seq(
      StructField("transactionId", StringType, false),
      StructField("productId", StringType, false),
      StructField("userId", StringType, false),
      StructField("quantity", IntegerType, false),
      StructField("price", DoubleType, false)))

    // Parses one tab-separated users line into a Row matching userSchema
    def userRows(line: String): Row = {
      val tokens = line.split("\t")
      Row(tokens(0), tokens(1))
    }

    // Parses one tab-separated transactions line into a Row matching transactionSchema
    def transactionRows(line: String): Row = {
      val tokens = line.split("\t")
      Row(tokens(0), tokens(1), tokens(2), tokens(3).toInt, tokens(4).toDouble)
    }

    val usersRaw = sc.textFile(usersInputFile) // Loading user data
    val userRDDRows = usersRaw.map(userRows(_)) // Converting to RDD[org.apache.spark.sql.Row]
    val users = spark.createDataFrame(userRDDRows, userSchema) // obtaining DataFrame from RDD

    val transactionsRaw = sc.textFile(transactionsInputFile) // Loading transactions data
    val transactionsRDDRows = transactionsRaw.map(transactionRows(_)) // Converting to  RDD[org.apache.spark.sql.Row]
    val transactions = spark.createDataFrame(transactionsRDDRows, transactionSchema) // obtaining DataFrame from RDD

    // Approach 1 using DataFrame API
    // Join the two tables on userId. The join type must be given explicitly:
    // join(right, expr) without a type is an INNER join and would drop
    // transactions whose user is unknown.
    val joined = transactions.join(users, transactions("userId") === users("userId"), "left_outer")
    joined.printSchema() //Prints schema on the console
    val product_location = joined.select(joined.col("productId"), joined.col("location")) // Selecting only productId and location
    // countDistinct ignores the null locations produced by unmatched users,
    // matching count(distinct location) in the SQL of approach 2.
    val products = product_location.groupBy("productId").agg(countDistinct("location").as("count"))
    products.show() // Print first 20 records on the console
    products.write.save(output + "/approach1") // Saves output in compressed Parquet format, recommended for large projects.
    products.rdd.saveAsTextFile(output + "/approach1_textFormat") // Converts DataFram to RDD[Row] and saves it to in text file. To see output use cat command, e.g. cat output/approach1_textFormat/part-00*
    // Approach 1 ends

    // Approach 2 using plain old SQL query
    // Use below for Spark 1.6.2 or below
    // users.registerTempTable("users") // Register as table (temporary) so that query can be performed on the table
    // transactions.registerTempTable("transactions") // Register as table (temporary) so that query can be performed on the table
    // Approach 2: register both DataFrames as temporary views
    users.createOrReplaceTempView("users") // Register as table (temporary) so that query can be performed on the table
    transactions.createOrReplaceTempView("transactions") // Register as table (temporary) so that query can be performed on the table

    import spark.sql

    // Run the query as SQL
    val sqlResult = sql("SELECT productId, count(distinct location) locCount FROM transactions LEFT OUTER JOIN users ON transactions.userId = users.userId group by productId")
    sqlResult.show() // Print first 20 records on the console
    sqlResult.write.save(output + "/approach2") // Saves output in compressed Parquet format, recommended for large projects.
    sqlResult.rdd.saveAsTextFile(output + "/approach2_textFormat") // Converts DataFram to RDD[Row] and saves it to in text file. To see output use cat command, e.g. cat output/approach2_textFormat/part-00*
    // Approach 2 ends

    // done
    spark.stop()
}

