A Brief Look at the Spark Source: The Shuffle Read Path

Shuffle Read

When a Task is executed, its runTask() method is invoked, and runTask() ends up calling RDD.getOrCompute() to produce the data for its partition:

private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
  val blockId = RDDBlockId(id, partition.index)
  var readCachedBlock = true
    
  SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
    readCachedBlock = false
    computeOrReadCheckpoint(partition, context)
  }) match {
    // ...
  }
}

computeOrReadCheckpoint() first checks whether this RDD has already been checkpointed and materialized; if not, it calls compute() to perform the actual computation.
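
Before going further, it may help to see the kind of user code that ends up on this path. A minimal REPL-style sketch (assuming an existing SparkContext named sc; not part of the quoted source):

// reduceByKey introduces a shuffle, so the tasks of the second stage read their
// input through the shuffle-read machinery analyzed below
val words = sc.parallelize(Seq("a", "b", "a", "c", "b"))
val counts = words.map(w => (w, 1)).reduceByKey(_ + _)   // builds a ShuffledRDD
counts.collect()   // running the job drives ShuffledRDD.compute() on the reduce side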

Class relationships

ShuffledRDD

Flow overview

To dissect shuffle read, we need to start from ShuffledRDD.compute():

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  // Obtain a Reader from the ShuffleManager
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    // Call the Reader's read() method
    .read()
    // Return an iterator over the fetched records
    .asInstanceOf[Iterator[(K, C)]]
}

Spark 2.2.3 uses SortShuffleManager as the default shuffle manager. The implementation of SortShuffleManager.getReader():

override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  // Instantiate a BlockStoreShuffleReader as the Reader
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

After obtaining a BlockStoreShuffleReader from the ShuffleManager, its read() method is called:

override def read(): Iterator[Product2[K, C]] = {
  // Instantiate the ShuffleBlockFetcherIterator
  val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    // pass in the communication client (shuffle client)
    blockManager.shuffleClient,
    blockManager,
    // metadata describing where this reduce task's input blocks live
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // maximum amount of map output to fetch at once (maxBytesInFlight)
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    // maximum number of blocks fetched concurrently from a single address
    // configured via spark.reducer.maxBlocksInFlightPerAddress
    // defaults to Int.MaxValue
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

  val serializerInstance = dep.serializer.newInstance()

  // turn the fetched streams into a key-value record iterator
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // ...

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  // aggregation
  // analyzed in detail below
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      // map-side combine already happened
      // treat the input as an iterator over already-combined key-value pairs
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      // merge the combiners
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C
      // no map-side combine happened
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // if key ordering is required, sort the aggregated data
  // and return a CompletionIterator
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }
}

Since this method does a lot, I will split the discussion into three parts: the initialization of ShuffleBlockFetcherIterator, the construction of the aggregatedIter aggregation iterator, and the construction of the sorted iterator.

1. Initialization of ShuffleBlockFetcherIterator

In this part we examine two of the arguments passed when ShuffleBlockFetcherIterator is instantiated: blockManager.shuffleClient and mapOutputTracker.getMapSizesByExecutorId(...):

// Instantiate the ShuffleBlockFetcherIterator
val wrappedStreams = new ShuffleBlockFetcherIterator(
  context,
  // pass in the communication client (shuffle client)
  blockManager.shuffleClient,
  blockManager,
  // metadata describing where this reduce task's input blocks live
  mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
  serializerManager.wrapStream,
  // maximum amount of map output to fetch at once (maxBytesInFlight)
  SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
  SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
  // maximum number of blocks fetched concurrently from a single address
  // configured via spark.reducer.maxBlocksInFlightPerAddress
  // defaults to Int.MaxValue
  SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
  SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
  SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

blockManager.shuffleClient:

// externalShuffleServiceEnabled defaults to false
// it can be changed by setting spark.shuffle.service.enabled in the conf
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
  val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
  new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
} else {
  // an instance of NettyBlockTransferService
  // which uses Netty as the transport framework
  blockTransferService
}

By default, then, ShuffleBlockFetcherIterator communicates through Netty (NettyBlockTransferService).
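
As a hedged sketch of how these two modes and the fetch limit mentioned above are configured (the application name is hypothetical; the config keys are the ones discussed in this section):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("shuffle-read-demo")                // hypothetical application name
  .set("spark.shuffle.service.enabled", "true")   // switch shuffleClient to ExternalShuffleClient
  .set("spark.reducer.maxSizeInFlight", "96m")    // maxBytesInFlight; the default is 48m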

MapOutputTracker class relationships

mapOutputTracker is an instance of MapOutputTrackerWorker on the executor side; the implementation of its getMapSizesByExecutorId(...) method:

override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
  // fetch the metadata describing the blocks to pull
  val statuses = getStatuses(shuffleId)
  try {
    // convert the map statuses into Seq[(BlockManagerId, Seq[(BlockId, Long)])]
    MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
  } catch {
    // ..
  }
}

The implementation of getStatuses():

private def getStatuses(shuffleId: Int): Array[MapStatus] = {
  // mapStatuses is a ConcurrentHashMap
  // try the local cache first
  val statuses = mapStatuses.get(shuffleId).orNull
  if (statuses == null) {
    // not in the local cache
    var fetchedStatuses: Array[MapStatus] = null

    fetching.synchronized {
      // fetching is a HashSet
      // it holds the shuffleIds whose map statuses are currently being fetched
      // i.e. some other thread is already fetching the metadata for the same shuffleId
      while (fetching.contains(shuffleId)) {
        // another thread is fetching the metadata, so wait for it
        try {
          fetching.wait()
        } catch {
          case e: InterruptedException =>
        }
      }

      // try the map cache again,
      // because this thread may have just been woken up by a thread that finished fetching
      fetchedStatuses = mapStatuses.get(shuffleId).orNull
      if (fetchedStatuses == null) {
        // still not there: add the shuffleId to the fetching set
        // to prevent duplicate fetches
        // which is exactly what fetching.synchronized is for
        fetching += shuffleId
      }
    }

    if (fetchedStatuses == null) {
      // actually fetch the metadata
      try {
        // send a request for the map output statuses to MapOutputTrackerMasterEndpoint
        // this is a synchronous (ask-style) request
        // handled by MapOutputTrackerMasterEndpoint.receiveAndReply()
        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
        // deserialize
        fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
        // put it into the cache
        mapStatuses.put(shuffleId, fetchedStatuses)
      } finally {
        fetching.synchronized {
          // remove the shuffleId from fetching and wake up waiting threads
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }

    if (fetchedStatuses != null) {
      fetchedStatuses
    } else {
      throw new MetadataFetchFailedException(
        shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
    }
  } else {
    statuses
  }
}
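
The synchronization here is easy to miss, so here is a standalone, simplified sketch (not Spark source; the class and method names are hypothetical) of the same fetch-once-and-cache idiom: one thread performs the expensive fetch for a given id while the others wait on the fetching set and then re-check the cache.

import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

class StatusCache[T >: Null](remoteFetch: Int => T) {
  private val cache = new ConcurrentHashMap[Int, T]()
  private val fetching = new mutable.HashSet[Int]()

  def get(id: Int): T = {
    var result = cache.get(id)
    if (result == null) {
      var fetched: T = null
      fetching.synchronized {
        while (fetching.contains(id)) fetching.wait()   // someone else is fetching this id
        fetched = cache.get(id)                         // it may have arrived while we waited
        if (fetched == null) fetching += id             // no: this thread becomes the fetcher
      }
      if (fetched == null) {
        try {
          fetched = remoteFetch(id)                     // the expensive remote call
          cache.put(id, fetched)
        } finally {
          fetching.synchronized { fetching -= id; fetching.notifyAll() }
        }
      }
      result = fetched
    }
    result
  }
}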

Now let's see what MapOutputTrackerMasterEndpoint does when it receives a GetMapOutputStatuses message:

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  // handle the GetMapOutputStatuses message
  case GetMapOutputStatuses(shuffleId: Int) =>
    val hostPort = context.senderAddress.hostPort
    // tracker is an instance of MapOutputTrackerMaster
    val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))

  case StopMapOutputTracker =>
    context.reply(true)
    stop()
}

tracker is an instance of MapOutputTrackerMaster; the implementation of its post() method:

def post(message: GetMapOutputMessage): Unit = {
  // mapOutputRequests is a LinkedBlockingQueue
  // the message is simply put on the queue
  // and handled by a background thread
  mapOutputRequests.offer(message)
}

MessageLoop is the background loop that processes the messages MapOutputTrackerMaster puts on that queue:

private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          // take the next message (blocks if the queue is empty)
          val data = mapOutputRequests.take()
          if (data == PoisonPill) {
            mapOutputRequests.offer(PoisonPill)
            return
          }
          // pull out the basic request information
          val context = data.context
          val shuffleId = data.shuffleId
          val hostPort = context.senderAddress.hostPort
          // shuffleStatuses is a ConcurrentHashMap
          // look up the ShuffleStatus cached for this shuffleId
          val shuffleStatus = shuffleStatuses.get(shuffleId).head
          context.reply(
            // serialize the map statuses and reply
            shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
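
The post()/MessageLoop pair is a standard producer-consumer loop over a LinkedBlockingQueue with a poison-pill shutdown. A standalone sketch (not Spark source; the names mirror the real ones but the logic is simplified):

import java.util.concurrent.LinkedBlockingQueue

object MessageLoopSketch {
  case class GetMapOutputMessage(shuffleId: Int)
  private val PoisonPill = GetMapOutputMessage(-99)
  private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]()

  def post(message: GetMapOutputMessage): Unit = mapOutputRequests.offer(message)

  private val loop = new Thread(new Runnable {
    override def run(): Unit = {
      var running = true
      while (running) {
        val data = mapOutputRequests.take()        // blocks until a request arrives
        if (data == PoisonPill) running = false    // poison pill: exit the loop
        else println(s"serving map statuses for shuffle ${data.shuffleId}")
      }
    }
  })

  def main(args: Array[String]): Unit = {
    loop.start()
    post(GetMapOutputMessage(0))   // what the RPC endpoint does for GetMapOutputStatuses
    post(PoisonPill)               // what a stop() would do
    loop.join()
  }
}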

At this point ShuffleBlockFetcherIterator has both a transport client and the metadata describing the blocks it needs to pull. Next, let's look at its initialization:

initialize()

private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())

  // Split local and remote blocks.
  // separate requests for local blocks from requests for remote blocks
  val remoteRequests = splitLocalRemoteBlocks()
  // randomize the remote fetch requests
  // to avoid hammering a single node (hot spots)
  fetchRequests ++= Utils.randomize(remoteRequests)
 
  // start remote fetches
  fetchUpToMaxBytes()

  val numFetches = remoteRequests.size - fetchRequests.size
    
  // fetch local blocks
  fetchLocalBlocks()

}

From here the analysis again splits into three parts: splitLocalRemoteBlocks(), fetchUpToMaxBytes() and fetchLocalBlocks():

  • The implementation of splitLocalRemoteBlocks():

    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      // cap each remote request at maxBytesInFlight / 5
      // maxBytesInFlight, mentioned above, is the maximum amount of data fetched at once
      // dividing by 5 increases parallelism instead of pulling everything from one node
      // this way we can fetch from up to 5 nodes at the same time, a small chunk from each
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    
      // buffer of remote fetch requests
      val remoteRequests = new ArrayBuffer[FetchRequest]
    
      // Tracks total number of blocks (including zero sized blocks)
      var totalBlocks = 0
      for ((address, blockInfos) <- blocksByAddress) {
        totalBlocks += blockInfos.size
        if (address.executorId == blockManager.blockManagerId.executorId) {
          // blocks that live on this executor, to be fetched locally
          // filter out zero-sized blocks
          localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
          numBlocksToFetch += localBlocks.size
        } else {
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocks += blockId
              numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            if (curRequestSize >= targetRequestSize ||
                curBlocks.size >= maxBlocksInFlightPerAddress) {
              // this batch is large enough, so turn it into a FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              curRequestSize = 0
            }
          }
          // wrap any remaining blocks into one last FetchRequest
          if (curBlocks.nonEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      remoteRequests
    }
    

    splitLocalRemoteBlocks() separates the blocks to fetch into local and remote groups so that each can be handled on its own (a simplified sketch of the request batching appears after this list).

  • The implementation of fetchUpToMaxBytes():

    private def fetchUpToMaxBytes(): Unit = {
      // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
      // immediately, defer the request until the next time it can be processed.
    
      // first, process any deferred fetch requests
      if (deferredFetchRequests.nonEmpty) {
        for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
          while (isRemoteBlockFetchable(defReqQueue) &&
              !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
            val request = defReqQueue.dequeue()
            // send the fetch request
            send(remoteAddress, request)
            if (defReqQueue.isEmpty) {
              deferredFetchRequests -= remoteAddress
            }
          }
        }
      }
    
      // Process any regular fetch requests if possible.
      while (isRemoteBlockFetchable(fetchRequests)) {
        val request = fetchRequests.dequeue()
        val remoteAddress = request.address
        if (isRemoteAddressMaxedOut(remoteAddress, request)) {
          // too many blocks are already in flight to this address, so defer the request
          val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
          defReqQueue.enqueue(request)
          deferredFetchRequests(remoteAddress) = defReqQueue
        } else {
          // otherwise send the request right away
          send(remoteAddress, request)
        }
      }
    
      // helper that sends a request and updates the per-address in-flight counter
      def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
        sendRequest(request)
        numBlocksInFlightPerAddress(remoteAddress) =
          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
      }
    
      def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
        fetchReqQueue.nonEmpty &&
          (bytesInFlight == 0 ||
            (reqsInFlight + 1 <= maxReqsInFlight &&
              bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
      }
    
      // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
      // given remote address.
      def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
          // mentioned above:
          // defaults to Int.MaxValue
          // configurable via spark.reducer.maxBlocksInFlightPerAddress
          maxBlocksInFlightPerAddress
      }
    }
    

    All remote fetches are ultimately issued by sendRequest(); let's look at its implementation:

    private[this] def sendRequest(req: FetchRequest) {
        
      bytesInFlight += req.size
      reqsInFlight += 1
    
      // so we can look up the size of each blockID
      val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
      val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
      val blockIds = req.blocks.map(_._1.toString)
      val address = req.address
    
      // create the listener that receives fetch results
      val blockFetchingListener = new BlockFetchingListener {
        // a block was fetched successfully
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          ShuffleBlockFetcherIterator.this.synchronized {
            if (!isZombie) {
              // Increment the ref count because we need to pass this to a different thread.
              // This needs to be released after use.
              buf.retain()
              remainingBlocks -= blockId
              // put the successful result into the results queue
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                remainingBlocks.isEmpty))
            }
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }
        // a block fetch failed
        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          // record the failure in the results queue
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    
      // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
      // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
      // the data and write it to file directly.
      if (req.size > maxReqSizeShuffleToMem) {
        // the requested data is too large to hold in memory
        // (it exceeds maxReqSizeShuffleToMem)
        // so stream it straight to disk
        shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
          blockFetchingListener, this)
      } else {
        // fetch into memory
        shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
          blockFetchingListener, null)
      }
    }
    

    As mentioned above, ShuffleBlockFetcherIterator uses NettyBlockTransferService for communication, so let's look at how NettyBlockTransferService.fetchBlocks() works:

    override def fetchBlocks(
        host: String,
        port: Int,
        execId: String,
        blockIds: Array[String],
        listener: BlockFetchingListener,
        tempShuffleFileManager: TempShuffleFileManager): Unit = {
        
      try {
        // a fetch starter that supports retries
        val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
          override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
            // the transport client
            val client = clientFactory.createClient(host, port)
            // the object that does the real work
            new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
              transportConf, tempShuffleFileManager).start()
          }
        }
    
        // maximum number of IO retries
        val maxRetries = transportConf.maxIORetries()
        if (maxRetries > 0) {
          // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
          // a bug in this code. We should remove the if statement once we're sure of the stability.
          // retries enabled: wrap the starter in a RetryingBlockFetcher
          new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
        } else {
          // no retries: call blockFetchStarter.createAndStart() directly
          blockFetchStarter.createAndStart(blockIds, listener)
        }
      } catch {
      // ...
      }
    }
    

    The implementation of OneForOneBlockFetcher.start():

    public void start() {
        if (this.blockIds.length == 0) {
            throw new IllegalArgumentException("Zero-sized blockIds array");
        } else {
            this.client.sendRpc(this.openMessage.toByteBuffer(), new RpcResponseCallback() {
              // the fetch RPC succeeded
                public void onSuccess(ByteBuffer response) { // the response payload
                    try {
                      // handle describing the stream of chunks
                        OneForOneBlockFetcher.this.streamHandle = (StreamHandle)Decoder.fromByteBuffer(response);
                        // iterate over the chunks one by one
                        for(int i = 0; i < OneForOneBlockFetcher.this.streamHandle.numChunks; ++i) {
                            if (OneForOneBlockFetcher.this.tempShuffleFileManager != null) {
                              // stream the chunk straight to disk
                                OneForOneBlockFetcher.this.client.stream(OneForOneStreamManager.genStreamChunkId(OneForOneBlockFetcher.this.streamHandle.streamId, i), OneForOneBlockFetcher.this.new DownloadCallback(i));
                            } else {
                              // fetch the chunk into memory
                                OneForOneBlockFetcher.this.client.fetchChunk(OneForOneBlockFetcher.this.streamHandle.streamId, i, OneForOneBlockFetcher.this.chunkCallback);
                            }
                        }
                    } catch (Exception var3) {
                        OneForOneBlockFetcher.logger.error("Failed while starting block fetches after success", var3);
                        OneForOneBlockFetcher.this.failRemainingBlocks(OneForOneBlockFetcher.this.blockIds, var3);
                    }
    
                }
              
              // the fetch RPC failed
                public void onFailure(Throwable e) {
                    OneForOneBlockFetcher.logger.error("Failed while starting block fetches", e);
                    OneForOneBlockFetcher.this.failRemainingBlocks(OneForOneBlockFetcher.this.blockIds, e);
                }
            });
        }
    }
    

    That is as far as we will take the remote fetch path.

  • Compared with the remote fetch path, fetchLocalBlocks() is much simpler:

private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  // iterate over the local block ids
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      // read the block data from the local BlockManager
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      // retain the buffer
      buf.retain()
      // add the result to the results queue
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
    } catch {
      // ..
    }
  }
}

That is all there is to the initialization of ShuffleBlockFetcherIterator. To summarize: using the transport client and the block metadata handed to it, it fetches the required blocks both remotely and locally, and puts every result into the results queue.
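
As promised above, here is a standalone, REPL-style sketch (not Spark source) of the batching performed inside splitLocalRemoteBlocks(). With the default spark.reducer.maxSizeInFlight of 48m, targetRequestSize is 48 MB / 5 ≈ 9.6 MB, and the non-empty blocks of one remote executor are grouped into FetchRequests of roughly that size:

import scala.collection.mutable.ArrayBuffer

def batchBlocks(blocks: Seq[(String, Long)], targetRequestSize: Long): Seq[Seq[(String, Long)]] = {
  val requests = ArrayBuffer[Seq[(String, Long)]]()
  var cur = ArrayBuffer[(String, Long)]()
  var curSize = 0L
  for ((blockId, size) <- blocks if size > 0) {   // zero-sized blocks are skipped
    cur += ((blockId, size))
    curSize += size
    if (curSize >= targetRequestSize) {           // the batch is large enough: close the request
      requests += cur.toSeq
      cur = ArrayBuffer[(String, Long)]()
      curSize = 0L
    }
  }
  if (cur.nonEmpty) requests += cur.toSeq         // the leftover blocks form the last request
  requests.toSeq
}

// e.g. ten 2 MB blocks with a 9.6 MB target produce two requests of five blocks each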

2. Building aggregatedIter, the iterator over aggregated data

aggregatedIter is the iterator over aggregated records; in other words, this step performs the reduce-side aggregation:

// aggregation
// obtain the iterator over the aggregated data
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
  if (dep.mapSideCombine) {
    // We are reading values that are already combined
    // map-side combine already happened
    // treat the input as an iterator over already-combined key-value pairs
    val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
    // merge the combiners
    dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
  } else {
    // We don't know the value type, but also don't care -- the dependency *should*
    // have made sure its compatible w/ this aggregator, which will convert the value
    // type to the combined type C
    // no map-side combine happened
    val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
    dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
  }
} else {
  require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
  interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}

The implementation of combineCombinersByKey(), which handles data that was already combined on the map side:

def combineCombinersByKey(
    iter: Iterator[_ <: Product2[K, C]],
    context: TaskContext): Iterator[(K, C)] = {
  val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}

The implementation of combineValuesByKey(), which handles data that was not combined on the map side:

def combineValuesByKey(
    iter: Iterator[_ <: Product2[K, V]],
    context: TaskContext): Iterator[(K, C)] = {
  // only the functions passed in differ
  val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}

The flow is the same in both cases; the only difference is which functions are passed when the ExternalAppendOnlyMap is created.
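
To make the difference concrete, here is a sketch (hypothetical example, not Spark source) of the three functions an aggregator is built from, using a per-key average. combineValuesByKey needs all three, while combineCombinersByKey only needs mergeCombiners, because the map side already produced values of the combined type C:

object AveragingAggregatorSketch {
  type V = Double          // raw value type
  type C = (Double, Int)   // combined type: (running sum, count)

  val createCombiner: V => C      = v => (v, 1)
  val mergeValue: (C, V) => C     = { case ((sum, n), v) => (sum + v, n + 1) }
  val mergeCombiners: (C, C) => C = { case ((s1, n1), (s2, n2)) => (s1 + s2, n1 + n2) }
}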

ExternalAppendOnlyMap class structure

The implementation of ExternalAppendOnlyMap.insertAll():

// similar to ExternalSorter.insertAll() on the shuffle-write side
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "Cannot insert new elements into a map after calling iterator")
  }
  // An update function for the map that we reuse across entries to avoid allocating
  // a new closure each time
  var curEntry: Product2[K, V] = null
  // the update function, also seen on the shuffle-write side
  val update: (Boolean, C) => C = (hadVal, oldVal) => {
    if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
  }

  while (entries.hasNext) {
    curEntry = entries.next()
    val estimatedSize = currentMap.estimateSize()
    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
    // spill to disk if necessary
    if (maybeSpill(currentMap, estimatedSize)) {
      currentMap = new SizeTrackingAppendOnlyMap[K, C]
    }
    // merge the value into an existing combiner, or create a new one
    currentMap.changeValue(curEntry._1, update)
    addElementsRead()
  }
}

Now let's look at the implementation of ExternalAppendOnlyMap.iterator:

override def iterator: Iterator[(K, C)] = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
  }
  // create a different iterator depending on whether spill files were produced
  if (spilledMaps.isEmpty) {
    // no spill files were produced
    // all the data is still in memory
    CompletionIterator[(K, C), Iterator[(K, C)]](
      destructiveIterator(currentMap.iterator), freeCurrentMap())
  } else {
    // spill files exist, so create an external merging iterator
    new ExternalIterator()
  }
}

Let's see what happens when an ExternalIterator is instantiated:

// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
// sort the in-memory data by key
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
  currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
// merge the iterator over the in-memory data with the iterators over the spilled files
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
// read the first batch from each stream and push it onto mergeHeap
inputStreams.foreach { it =>
  val kcPairs = new ArrayBuffer[(K, C)]
  readNextHashCode(it, kcPairs)
  if (kcPairs.length > 0) {
    mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
  }
}

With that, we have the iterator over the aggregated data.
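
To make the merge step concrete, here is a standalone, REPL-style sketch (not Spark source) of the k-way merge idea that ExternalIterator implements with its mergeHeap. The real code additionally groups entries with the same key hash and merges their combiners, which this sketch omits:

import scala.collection.mutable

def mergeSorted[K: Ordering, V](inputs: Seq[Iterator[(K, V)]]): Iterator[(K, V)] = {
  // PriorityQueue is a max-heap, so reverse the ordering to pop the smallest head first
  implicit val byHead: Ordering[BufferedIterator[(K, V)]] =
    Ordering.by[BufferedIterator[(K, V)], K](_.head._1).reverse
  val heap = mutable.PriorityQueue(inputs.map(_.buffered).filter(_.hasNext): _*)

  new Iterator[(K, V)] {
    override def hasNext: Boolean = heap.nonEmpty
    override def next(): (K, V) = {
      val stream = heap.dequeue()                // the stream with the smallest current key
      val kv = stream.next()
      if (stream.hasNext) heap.enqueue(stream)   // re-insert it if it still has elements
      kv
    }
  }
}

// mergeSorted(Seq(Iterator(1 -> "a", 3 -> "c"), Iterator(2 -> "b"))).toList
// => List((1,a), (2,b), (3,c))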

3. Building the sorted iterator (when needed)

The sorted iterator is produced by ExternalSorter, which was analyzed in the shuffle-write article:

dep.keyOrdering match {
  case Some(keyOrd: Ordering[K]) =>
    // Create an ExternalSorter to sort the data.
    val sorter =
      new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
    // insert the aggregated data into the ExternalSorter
    // (analyzed in the shuffle-write article)
    sorter.insertAll(aggregatedIter)
    context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
    // the result is, once again, an iterator
    CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
  case None =>
    aggregatedIter
}
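
A quick sketch of when keyOrdering is actually defined (assuming an existing SparkContext named sc; not part of the quoted source): sortByKey builds its ShuffledRDD with setKeyOrdering, so on the reduce side dep.keyOrdering is Some(...) and the ExternalSorter branch above is taken.

val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
pairs.sortByKey().collect()   // this shuffle read ends with the ExternalSorter-backed CompletionIterator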

Here, let's look at the implementation of ExternalSorter.iterator:

def iterator: Iterator[Product2[K, C]] = {
  isShuffleSort = false
  partitionedIterator.flatMap(pair => pair._2)
}

The implementation of partitionedIterator:

def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
  val usingMap = aggregator.isDefined
  val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
  if (spills.isEmpty) {
    // no spill files
    if (!ordering.isDefined) {
      // group the data by partition
      groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
    } else {
      // group the data by partition, sorted by key within each partition
      groupByPartition(destructiveIterator(
        collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
    }
  } else {
    // merge the in-memory data with the spilled files
    merge(spills, destructiveIterator(
      collection.partitionedDestructiveSortedIterator(comparator)))
  }
}

We will stop here; for the rest, see the shuffle-write article.
