Spark系列——关于 mapPartitions的误区

前言

今天 Review 了一下同事的代码,
发现其代码中有非常多的 mapPartitions,
问其原因,他说性能比 map 更好。
我说为什么性能好呢?
于是就有了这篇文章

网上推崇 mapPartitions 的原因

按照某些文章的原话来说
一次函数调用会处理一个partition所有的数据,而不是一次函数调用处理一条,性能相对来说会高一些。
又比如说:
如果是普通的map,比如一个partition中有1万条数据;
那么你的function要执行和计算1万次。
但是,使用MapPartitions操作之后,
一个task仅仅会执行一次function,
function一次接收所有的partition数据。
只要执行一次就可以了,性能比较高
这种说法如果按照上面的方式来理解其实也是那么一回事,
但是也很容易让一些新人理解为:
map要执行1万次,而 MapPartitions 只需要一次,这速度杠杠的提升了啊
实际上,你使用MapPartitions迭代的时候,
还是是一条条数据处理的,这个次数其实完全没变。

其实这个问题我们可以来看看源码
map算子源码

  def map[U: ClassTag](f: T => U): RDD[U] = withScope {
    val cleanF = sc.clean(f)
    new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
  }

接受用户传入的一个函数,
new 一个 MapPartitionsRDD 对象,
我们的函数是作用在 MapPartitionsRDD 的迭代器 iter 上。

mapPartition算子源码

  def mapPartitions[U: ClassTag](
      f: Iterator[T] => Iterator[U],
      preservesPartitioning: Boolean = false): RDD[U] = withScope {
    val cleanedF = sc.clean(f)
    new MapPartitionsRDD(
      this,
      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
      preservesPartitioning)
  }

接受一个迭代器,
new 一个 MapPartitionsRDD 对象,
传入的迭代器是作为 MapPartitionsRDD 的迭代器。

说白了,这个两算子真没什么差,
map 算子可以理解为 mapPartitions 的一个高级封装而已。

mapPartitions 带来的问题

其实就我个人经验来看,
mapPartitions 的正确使用其实并不会造成什么大的问题,
当然我也没看出普通场景 mapPartitions 比 map 有什么优势,
所以 完全没必要刻意使用 mapPartitions
反而,mapPartitions 会带来一些问题。

  1. 使用起来并不是很方便,这个写过代码的人应该都知道。
    当然这个问题并不是不能解决,我们可以写类似下面的代码,
    确实也变的和 map 简洁性也差不太多,
    恩,我不会告诉你可以尝试在生产环境中用用噢。
    //抽象出一个函数,以后所有的 mapPartitions 都可以用
    def mapFunc[T, U](iterator: Iterator[T], f: T => U) = {
      iterator.map(x => {
        f(x)
      })
    }
    //使用    
    rdd.mapPartitions(x => {
        mapFunc(x, line => {
            s"${line}转换数据"
        })
      })
    
    
  2. 容易造成 OOM,这个也是很多博客提到的问题,
    他们大致会写出如下的代码来做测试,
    rdd.mapPartitions(x => {
        xxxx操作
       while (x.hasNext){
         val next = x.next()
       }
        xxx操作
      })
    
    如果你的代码是上面那样,那OOM也就不足为奇了,
    不知道你注意到了没有,mapPartitions 是接受一个迭代器,
    再返回一个迭代器的,
    如果你这么写代码,就完全没有使用到迭代器的懒执行特性。
    将数据都堆积到了内存,
    真就变成了一次处理一个partition的数据了,
    在某种程度上已经破坏了 Spark Pipeline 的计算模式了。

mapPartitions 到底该怎么用

一对一的普通使用

存在即是道理,
虽然上面一直在吐槽,
但是其确实有存在的理由。
其一个分区只会被调用一次的特性,
在一些写数据库的时候确实很有帮助,
因为我们的 Spark 是分布式执行的,
所以连接数据库的操作必须放到算子内部才能正确的被Executor执行,
那么 mapPartitions 就显示比 map 要有优势的多了。
比如下面这段伪代码

rdd.mapPartitions(x => {
        println("连接数据库")
        val res = x.map(line=>{
          print("写入数据:" + line)
          line
        })
        res
      })

这样我就一个分区只要连接一次数据库,
而如果是 map 算子,那可能就要连接 n 多次了。
不过上面这种就没法关闭数据库连接了,
所以可以换另外一种方式:

rdd1.mapPartitions(x => {
      println("连接数据库")
      new Iterator[Any] {
        override def hasNext: Boolean = {
          if (x.hasNext) {
            true
          } else {
            println("关闭数据库")
            false
          }
        }
        override def next(): Any = "写入数据:" + x.next()
      }
    })

自定义一个迭代器,
这样虽然麻烦了一点,
但是无疑才是正确的。
当然还有一些复杂的处理,
比如类似 flatMap那种要输出多条怎么办?
这个时候可以去参考下 Iterator 的源码是怎么实现的,
同样不难,这里就不赘述了。

一对多的的高级使用

本来是想偷点懒的,
不过既然有人问起这个,
这里就补充说下输出多条的方式。

思路其实很简单,
我们可以查看迭代器的源码,
他是有一个 flatMap 算子的,
我们仿照一下就ok啦。
下面我们来解读下 Iterator.flatMap算子这段的源码吧。

        // f 函数是 传入每一条数据都需要返回一个迭代器
        // 也就是说一条记录可以返回多个值
        def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
          //定义当前的迭代器是空的
          private var cur: Iterator[B] = empty
        //这是源码,为了方便理解,我稍微改写了下
//          def hasNext: Boolean =
//            cur.hasNext || self.hasNext && {
//              cur = f(self.next).toIterator;
//              hasNext
//           }
        def hasNext: Boolean ={
          if(cur.hasNext){
            //如果当前迭代器还有值,
            //则返回true
            return true
          }
          if(self.hasNext){
            //如果cur已经没有值了
            //但是本身的迭代器还有值
            //则我们把本身迭代器的一个值拿出来
            //通过 f函数 构造一个迭代器放到当前的迭代器
            cur = f(self.next).toIterator;
            //再递归一次本函数来看是否还有值
            return hasNext
          }
        }
          //这个就没什么好说的了
          def next(): B = (if (hasNext) cur else empty).next()
        }

上面的代码为了方便理解,
我修改了下,并加了注释,
应该是很好理解了。

这里如果你如果要做伸手党的话,
我也给出一个实例代码

 val conf = new SparkConf()
      .setMaster("local[1]")
      .setAppName("test")
    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")
    sc.parallelize(Seq("a,a,a,a,a"))
      .mapPartitions(iter => {

        new AbstractIterator[String] {
          def myF(data: String): Iterable[String] = {
            println(data)
            data.split(",").toIterable
          }

          var cur: Iterator[String] = Iterator.empty

          override def hasNext: Boolean = {
            cur.hasNext || iter.hasNext && {
              cur = myF(iter.next).toIterator
              hasNext
            }
          }

          override def next(): String = (if (hasNext) cur else Iterator.empty).next()
        }
      })
      .foreach(println)

这里捎带提一下就是,
其实迭代器本身就有 Map flatMap 等算子,
之所以还要去自定义,
就是因为自定义提供了更加自由的一些操作,
比如开启和关闭数据库等,
但是大部分情况下,
还是能不自定义,谁想去折腾呢?

其他

另外一点就是 mapPartitions 提供给了我们更加强大的数据控制力,
怎么理解呢?我们可以一次拿到一个分区的数据,
那么我们就可以对一个分区的数据进行统一处理,
虽然会加大内存的开销,但是在某些场景下还是很有用的,
比如一些矩阵的乘法。

后记

不管你要使用哪个算子,其实都是可以的,
但是大多数时候,我还是推荐你使用 map 算子,
当然遇到一些map算子不合适的场景,
那就没办法了...
不过就算你是真的要使用 mapPartitions,
那么请记得充分发挥一下 迭代器的 懒执行特性。

最后,如果本文对你有帮助,帮忙点个赞呗

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容