Shuffle Read
When an executor runs a Task, the task's runTask() method is invoked. runTask() calls RDD.iterator() on the partition it is responsible for, and iterator() in turn delegates to RDD.getOrCompute() to carry out the actual computation:
private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
// ...
}
}
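For context, runTask() reaches getOrCompute() through RDD.iterator(), which looks roughly like this (a simplified sketch of the Spark source):
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // A storage level is set, so try the block manager first
    getOrCompute(split, context)
  } else {
    // No caching requested: compute the partition, or read it from a checkpoint
    computeOrReadCheckpoint(split, context)
  }
}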
computeOrReadCheckpoint() first checks whether this RDD has been checkpointed and materialized; if it has not, it calls compute() to compute the partition.
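For reference, computeOrReadCheckpoint() is roughly the following (again a simplified sketch of the Spark source):
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
  if (isCheckpointedAndMaterialized) {
    // Read the materialized data from the checkpoint RDD (the first parent)
    firstParent[T].iterator(split, context)
  } else {
    // No checkpoint available, so compute the partition
    compute(split, context)
  }
}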
ShuffledRDD
To dissect Shuffle Read, we start from ShuffledRDD.compute():
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
// Get the reader from the ShuffleManager
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
// Call the reader's read() method
.read()
// Return the iterator over the fetched records
.asInstanceOf[Iterator[(K, C)]]
}
Spark 2.2.3 uses SortShuffleManager as the default shuffle manager; the implementation of SortShuffleManager.getReader():
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
// Instantiate BlockStoreShuffleReader as the reader
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
After the BlockStoreShuffleReader is obtained from the ShuffleManager, its read() method is called:
override def read(): Iterator[Product2[K, C]] = {
// Instantiate the ShuffleBlockFetcherIterator
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
// The block-transfer (communication) client
blockManager.shuffleClient,
blockManager,
// Metadata describing where this reduce task's input blocks live
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Maximum amount of data to fetch from the map side at once (maxBytesInFlight)
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
// Maximum number of blocks that may be in flight to a single address at once,
// configured via spark.reducer.maxBlocksInFlightPerAddress,
// defaults to Int.MaxValue
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
// Turn the fetched streams into a key-value record iterator
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// ...
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
// Aggregation
// analyzed in detail below
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
// The values were already combined on the map side
// Treat the input as an iterator over already-combined key-value pairs
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
// Merge the combiners
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
// No map-side combine was performed
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// If a key ordering is required, sort the aggregated data
// and return a CompletionIterator
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
Since this method is quite long, I will break the discussion into three parts: the initialization of ShuffleBlockFetcherIterator, the construction of the aggregation iterator aggregatedIter, and the construction of the sorted iterator.
1. Initialization of ShuffleBlockFetcherIterator
In this part we examine two of the arguments passed when ShuffleBlockFetcherIterator is instantiated: blockManager.shuffleClient and mapOutputTracker.getMapSizesByExecutorId(...):
// Instantiate the ShuffleBlockFetcherIterator
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
// The block-transfer (communication) client
blockManager.shuffleClient,
blockManager,
// Metadata describing where this reduce task's input blocks live
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Maximum amount of data to fetch from the map side at once (maxBytesInFlight)
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
// Maximum number of blocks that may be in flight to a single address at once,
// configured via spark.reducer.maxBlocksInFlightPerAddress,
// defaults to Int.MaxValue
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
blockManager.shuffleClient:
// externalShuffleServiceEnabled defaults to false
// and can be changed via spark.shuffle.service.enabled in the conf
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
} else {
// An instance of NettyBlockTransferService,
// which uses Netty as the communication framework
blockTransferService
}
By default, then, ShuffleBlockFetcherIterator uses Netty (via NettyBlockTransferService) as its transfer framework.
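As an aside, switching the reader over to the external shuffle service (and tuning the fetch limits mentioned above) is purely a configuration change. A minimal sketch; the property values below are illustrative, not recommendations:
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Use ExternalShuffleClient instead of NettyBlockTransferService for shuffle fetches
  .set("spark.shuffle.service.enabled", "true")
  // Illustrative tuning of the limits passed to ShuffleBlockFetcherIterator
  .set("spark.reducer.maxSizeInFlight", "96m")
  .set("spark.reducer.maxBlocksInFlightPerAddress", "64")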
mapOutputTracker is an instance of MapOutputTrackerWorker; the implementation of its getMapSizesByExecutorId(...) method:
override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
// Fetch the metadata describing the data to pull
val statuses = getStatuses(shuffleId)
try {
// Convert the metadata into Seq[(BlockManagerId, Seq[(BlockId, Long)])]
MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
} catch {
// ..
}
}
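To make the return type concrete, one element of the returned sequence looks conceptually like the following (a hypothetical illustration; executor id, host, port and sizes are made up):
// One entry per executor holding map output for this reducer: the executor's BlockManagerId
// plus the shuffle blocks to fetch from it and their sizes in bytes.
// A ShuffleBlockId is (shuffleId, mapId, reduceId).
(BlockManagerId("exec-1", "host-a", 7337),
  Seq((ShuffleBlockId(0, 0, 2), 1024L), (ShuffleBlockId(0, 1, 2), 2048L)))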
The implementation of getStatuses():
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
// mapStatuses is an instance of ConcurrentHashMap
// try the in-memory cache first
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
// Not in the cache
// (the Spark source declares fetchedStatuses here, before the synchronized block)
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
// fetching is an instance of HashSet;
// it holds the shuffleIds whose metadata is currently being fetched,
// i.e. some other thread may already be fetching the metadata for the same shuffleId
while (fetching.contains(shuffleId)) {
// Another thread is fetching the metadata, so wait
try {
fetching.wait()
} catch {
case e: InterruptedException =>
}
}
// Try the cache again,
// because this thread may have just been woken up after another thread finished fetching
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// Still not there, so add the shuffleId to the fetching set
// to prevent duplicate fetches --
// which is exactly what fetching.synchronized is for
fetching += shuffleId
}
}
if (fetchedStatuses == null) {
// Here the metadata is actually fetched
try {
// Send a GetMapOutputStatuses request to MapOutputTrackerMasterEndpoint;
// this is a synchronous (ask) request,
// handled by MapOutputTrackerMasterEndpoint.receiveAndReply()
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
// Deserialize the response
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
// Put it into the cache
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
// Remove the shuffleId from fetching and wake up any waiting threads
fetching -= shuffleId
fetching.notifyAll()
}
}
}
if (fetchedStatuses != null) {
fetchedStatuses
} else {
throw new MetadataFetchFailedException(
shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
}
} else {
statuses
}
}
Next, let's see what MapOutputTrackerMasterEndpoint does when it receives a GetMapOutputStatuses message:
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
// Handle the GetMapOutputStatuses message
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
// tracker is an instance of MapOutputTrackerMaster
val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
case StopMapOutputTracker =>
context.reply(true)
stop()
}
tracker is an instance of MapOutputTrackerMaster; the implementation of its post() method:
def post(message: GetMapOutputMessage): Unit = {
// mapOutputRequests is an instance of LinkedBlockingQueue;
// the message is simply put on the queue
// and handled by a background thread
mapOutputRequests.offer(message)
}
MessageLoop is the background loop that processes the messages MapOutputTrackerMaster puts on the queue:
private class MessageLoop extends Runnable {
override def run(): Unit = {
try {
while (true) {
try {
// Take a message off the queue
val data = mapOutputRequests.take()
if (data == PoisonPill) {
mapOutputRequests.offer(PoisonPill)
return
}
// Extract the basic information
val context = data.context
val shuffleId = data.shuffleId
val hostPort = context.senderAddress.hostPort
// shuffleStatuses is an instance of ConcurrentHashMap;
// look up the cached metadata for this shuffleId
val shuffleStatus = shuffleStatuses.get(shuffleId).head
context.reply(
// serialize it and reply
shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
} catch {
case NonFatal(e) => logError(e.getMessage, e)
}
}
} catch {
case ie: InterruptedException => // exit
}
}
}
At this point ShuffleBlockFetcherIterator has both a transfer framework and the metadata describing the blocks to fetch. Next, let's look at its initialization:
initialize()
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
context.addTaskCompletionListener(_ => cleanup())
// Split local and remote blocks.
// Separate the blocks on this executor from the remote fetch requests
val remoteRequests = splitLocalRemoteBlocks()
// Randomly shuffle the remote fetch requests
// to avoid hot spots on any single node
fetchRequests ++= Utils.randomize(remoteRequests)
// Start the remote fetches
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
// Fetch the local blocks
fetchLocalBlocks()
}
At this point we again split the analysis into three parts: splitLocalRemoteBlocks(), fetchUpToMaxBytes(), and fetchLocalBlocks():
- The implementation of splitLocalRemoteBlocks():
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
  // nodes, rather than blocking on reading output from one node.
  // maxBytesInFlight was mentioned above: the maximum amount of data to fetch at once.
  // Using 1/5 of it raises parallelism: instead of pulling everything from a single node,
  // we can fetch from up to 5 nodes at the same time, a small chunk from each.
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  // The array of remote fetch requests
  val remoteRequests = new ArrayBuffer[FetchRequest]
  // Tracks total number of blocks (including zero sized blocks)
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // Blocks on the local node
      // Filter out blocks whose size is 0
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Skip empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        if (curRequestSize >= targetRequestSize ||
            curBlocks.size >= maxBlocksInFlightPerAddress) {
          // Once the accumulated size (or block count) is large enough, emit a FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          curRequestSize = 0
        }
      }
      // Build one more FetchRequest out of whatever blocks are left over
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  remoteRequests
}
splitLocalRemoteBlocks() separates the blocks that live on this executor from those that must be fetched remotely, so the two cases can be handled separately.
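A quick worked example under the default configuration (purely illustrative):
// spark.reducer.maxSizeInFlight defaults to 48m, so:
val maxBytesInFlight = 48L * 1024 * 1024
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)  // ~9.6 MB per FetchRequest
// Blocks from one executor are grouped into requests of roughly 9.6 MB each, so up to
// five requests (to different executors) can be in flight without exceeding maxBytesInFlight.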
- The implementation of fetchUpToMaxBytes():
private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
  // immediately, defer the request until the next time it can be processed.
  // First handle the fetch requests that were previously deferred
  if (deferredFetchRequests.nonEmpty) {
    for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
      while (isRemoteBlockFetchable(defReqQueue) &&
          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
        val request = defReqQueue.dequeue()
        // Send the fetch request
        send(remoteAddress, request)
        if (defReqQueue.isEmpty) {
          deferredFetchRequests -= remoteAddress
        }
      }
    }
  }

  // Process any regular fetch requests if possible.
  while (isRemoteBlockFetchable(fetchRequests)) {
    val request = fetchRequests.dequeue()
    val remoteAddress = request.address
    if (isRemoteAddressMaxedOut(remoteAddress, request)) {
      // Too many blocks are already in flight to this address, so defer the request
      val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
      defReqQueue.enqueue(request)
      deferredFetchRequests(remoteAddress) = defReqQueue
    } else {
      // Send the request directly
      send(remoteAddress, request)
    }
  }

  // Send a fetch request and track how many blocks are in flight per address
  def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
    sendRequest(request)
    numBlocksInFlightPerAddress(remoteAddress) =
      numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
  }

  def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
    fetchReqQueue.nonEmpty &&
      (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
  }

  // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
  // given remote address.
  def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
    numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
      // Mentioned above: defaults to Int.MaxValue,
      // configurable via spark.reducer.maxBlocksInFlightPerAddress
      maxBlocksInFlightPerAddress
  }
}
Every remote fetch is ultimately issued by sendRequest(); here is its implementation:
private[this] def sendRequest(req: FetchRequest) {
  bytesInFlight += req.size
  reqsInFlight += 1
  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  // Create the fetch listener
  val blockFetchingListener = new BlockFetchingListener {
    // Called when a block has been fetched successfully
    override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
      // Only add the buffer to results queue if the iterator is not zombie,
      // i.e. cleanup() has not been called yet.
      ShuffleBlockFetcherIterator.this.synchronized {
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          remainingBlocks -= blockId
          // Put the result into the results queue
          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
            remainingBlocks.isEmpty))
        }
      }
      logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
    }

    // Called when the fetch fails
    override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
      // Put the failure into the results queue
      results.put(new FailureFetchResult(BlockId(blockId), address, e))
    }
  }

  // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
  // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
  // the data and write it to file directly.
  if (req.size > maxReqSizeShuffleToMem) {
    // The data to fetch is too large to keep in memory
    // (it exceeds maxReqSizeShuffleToMem), so it is written straight to disk
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      blockFetchingListener, this)
  } else {
    // Fetch into memory
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      blockFetchingListener, null)
  }
}
As noted above, ShuffleBlockFetcherIterator communicates through NettyBlockTransferService, so let's look at how NettyBlockTransferService.fetchBlocks() works:
override def fetchBlocks(
    host: String,
    port: Int,
    execId: String,
    blockIds: Array[String],
    listener: BlockFetchingListener,
    tempShuffleFileManager: TempShuffleFileManager): Unit = {
  try {
    // A block-fetch starter that can be retried
    val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
      override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
        // The communication client
        val client = clientFactory.createClient(host, port)
        // OneForOneBlockFetcher does the actual work
        new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, transportConf,
          tempShuffleFileManager).start()
      }
    }

    // Get the maximum number of retries
    val maxRetries = transportConf.maxIORetries()
    if (maxRetries > 0) {
      // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
      // a bug in this code. We should remove the if statement once we're sure of the stability.
      // Retries are allowed
      new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
    } else {
      // Call blockFetchStarter.createAndStart() directly
      blockFetchStarter.createAndStart(blockIds, listener)
    }
  } catch {
    // ...
  }
}
The implementation of OneForOneBlockFetcher.start():
public void start() {
  if (this.blockIds.length == 0) {
    throw new IllegalArgumentException("Zero-sized blockIds array");
  } else {
    this.client.sendRpc(this.openMessage.toByteBuffer(), new RpcResponseCallback() {
      // The open-blocks RPC succeeded
      public void onSuccess(ByteBuffer response) {
        try {
          // The stream handle describing the chunks to fetch
          OneForOneBlockFetcher.this.streamHandle = (StreamHandle)Decoder.fromByteBuffer(response);
          // Iterate over the chunks one by one
          for(int i = 0; i < OneForOneBlockFetcher.this.streamHandle.numChunks; ++i) {
            if (OneForOneBlockFetcher.this.tempShuffleFileManager != null) {
              // Stream the chunk straight to disk
              OneForOneBlockFetcher.this.client.stream(
                OneForOneStreamManager.genStreamChunkId(
                  OneForOneBlockFetcher.this.streamHandle.streamId, i),
                OneForOneBlockFetcher.this.new DownloadCallback(i));
            } else {
              // Fetch the chunk into memory
              OneForOneBlockFetcher.this.client.fetchChunk(
                OneForOneBlockFetcher.this.streamHandle.streamId, i,
                OneForOneBlockFetcher.this.chunkCallback);
            }
          }
        } catch (Exception var3) {
          OneForOneBlockFetcher.logger.error("Failed while starting block fetches after success", var3);
          OneForOneBlockFetcher.this.failRemainingBlocks(OneForOneBlockFetcher.this.blockIds, var3);
        }
      }

      // The open-blocks RPC failed
      public void onFailure(Throwable e) {
        OneForOneBlockFetcher.logger.error("Failed while starting block fetches", e);
        OneForOneBlockFetcher.this.failRemainingBlocks(OneForOneBlockFetcher.this.blockIds, e);
      }
    });
  }
}
That is as far as we'll go into the remote fetch path.
Compared with the remote fetch path, fetchLocalBlocks() is much simpler:
private[this] def fetchLocalBlocks() {
val iter = localBlocks.iterator
// Iterate over the local blocks
while (iter.hasNext) {
val blockId = iter.next()
try {
// Get the block data from the local BlockManager
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
// Retain the buffer (increase its reference count)
buf.retain()
// Put the result into results
results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
} catch {
// ..
}
}
}
That is all there is to ShuffleBlockFetcherIterator's initialization. To summarize: using the transfer framework and the map-output metadata passed in, it fetches the shuffle blocks both remotely and locally, and puts the outcome of every fetch into the results queue.
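Before moving on, here is a minimal, self-contained sketch of the pattern the iterator relies on: fetch callbacks push results into a blocking queue, and the iterator's next() blocks on take(). All names here are illustrative, not Spark's:
import java.util.concurrent.LinkedBlockingQueue

sealed trait FetchResult
case class Success(blockId: String, bytes: Array[Byte]) extends FetchResult
case class Failure(blockId: String, cause: Throwable) extends FetchResult

val results = new LinkedBlockingQueue[FetchResult]()

// Producer side: a fetch callback (running on a transfer thread) delivers a block.
new Thread(() => results.put(Success("shuffle_0_1_2", Array[Byte](1, 2, 3)))).start()

// Consumer side: the "iterator" blocks until some result is available.
results.take() match {
  case Success(id, data) => println(s"got $id with ${data.length} bytes")
  case Failure(id, e)    => throw e
}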
2. Building the aggregation iterator aggregatedIter
aggregatedIter is the iterator over the aggregated data; in other words, this is the step where the reduce-side aggregation happens:
// Aggregation
// obtain the iterator over the aggregated data
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
// The values were already combined on the map side
// Treat the input as an iterator over already-combined key-value pairs
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
// Merge the combiners
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
// No map-side combine was performed
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
The implementation of combineCombinersByKey(), which handles data that was already combined on the map side:
def combineCombinersByKey(
iter: Iterator[_ <: Product2[K, C]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
The implementation of combineValuesByKey(), which handles data that was not combined on the map side:
def combineValuesByKey(
iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
// Note that a different set of functions is passed in
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
The flow is essentially the same; the only difference is the set of functions passed when creating the ExternalAppendOnlyMap.
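To make the difference concrete, here is a small sketch of the three functions an Aggregator carries, using a per-key sum (the aggregator itself is hypothetical; the names mirror the parameters above):
// For an Aggregator[String, Int, Int] that sums the values of each key:
val createCombiner: Int => Int = v => v                       // first value starts a combiner
val mergeValue: (Int, Int) => Int = (c, v) => c + v           // fold one more value in
val mergeCombiners: (Int, Int) => Int = (c1, c2) => c1 + c2   // merge two partial combiners

// combineValuesByKey needs all three, because it still sees raw values of type V.
// combineCombinersByKey only needs mergeCombiners (hence the `identity` above), because
// map-side combine has already turned every value into the combined type C.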
The implementation of ExternalAppendOnlyMap.insertAll():
// Similar to ExternalSorter.insertAll(), covered in the Shuffle Write analysis
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
if (currentMap == null) {
throw new IllegalStateException(
"Cannot insert new elements into a map after calling iterator")
}
// An update function for the map that we reuse across entries to avoid allocating
// a new closure each time
var curEntry: Product2[K, V] = null
// The update function, also seen in the Shuffle Write analysis
val update: (Boolean, C) => C = (hadVal, oldVal) => {
if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
}
while (entries.hasNext) {
curEntry = entries.next()
val estimatedSize = currentMap.estimateSize()
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
// Spill to disk if necessary
if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
// Merge the value into an existing combiner, or create one on first sight of the key
currentMap.changeValue(curEntry._1, update)
addElementsRead()
}
}
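The essence of the loop above, stripped of size tracking and spilling, is just this (a toy analogue using a plain HashMap; the aggregator is the same hypothetical sum as before):
import scala.collection.mutable

val createCombiner: Int => Int = v => v
val mergeValue: (Int, Int) => Int = _ + _

val combiners = mutable.HashMap.empty[String, Int]
for ((k, v) <- Seq(("a", 1), ("b", 2), ("a", 3))) {
  combiners(k) = combiners.get(k) match {
    case Some(c) => mergeValue(c, v)    // hadVal == true
    case None    => createCombiner(v)   // hadVal == false
  }
}
// combiners == Map("a" -> 4, "b" -> 2); Spark's version additionally estimates the map's
// size after each insert and spills it to disk when maybeSpill() says so.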
Next, the implementation of ExternalAppendOnlyMap.iterator:
override def iterator: Iterator[(K, C)] = {
if (currentMap == null) {
throw new IllegalStateException(
"ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
}
// Create a different iterator depending on whether any spill files were produced
if (spilledMaps.isEmpty) {
// No spill files were produced,
// so all the data is still in memory
CompletionIterator[(K, C), Iterator[(K, C)]](
destructiveIterator(currentMap.iterator), freeCurrentMap())
} else {
// Spill files exist, so create an external (merging) iterator
new ExternalIterator()
}
}
Let's see what the instantiation of ExternalIterator does:
// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
// Sort the in-memory map in place by key
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
// Combine the iterator over the in-memory data with the iterators over the spilled files
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
// Read the first batch of records from each stream and put it into mergeHeap
inputStreams.foreach { it =>
val kcPairs = new ArrayBuffer[(K, C)]
readNextHashCode(it, kcPairs)
if (kcPairs.length > 0) {
mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
}
}
With that, we have the iterator over the aggregated data.
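The merge itself is a classic k-way merge driven by a priority queue. A toy version, assuming every input iterator is already sorted by key; it ignores the hash-code batching that StreamBuffer performs, so it is only a sketch of the idea:
import scala.collection.mutable

// Merge already-sorted (key, value) streams, smallest key first, the way ExternalIterator
// merges the sorted in-memory map with the spilled files via mergeHeap.
def mergeSorted(streams: Seq[Iterator[(Int, Int)]]): Iterator[(Int, Int)] =
  new Iterator[(Int, Int)] {
    // A min-heap ordered by the head key of each (buffered) stream
    private val heap = mutable.PriorityQueue.empty[BufferedIterator[(Int, Int)]](
      Ordering.by[BufferedIterator[(Int, Int)], Int](_.head._1).reverse)
    streams.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))

    def hasNext: Boolean = heap.nonEmpty
    def next(): (Int, Int) = {
      val stream = heap.dequeue()              // the stream with the smallest head key
      val kv = stream.next()
      if (stream.hasNext) heap.enqueue(stream) // re-insert so its new head is considered
      kv
    }
  }

// mergeSorted(Seq(Iterator((1, 10), (3, 30)), Iterator((2, 20)))).toList
//   == List((1, 10), (2, 20), (3, 30))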
3. Building the sorted iterator (when needed)
The sorted iterator is produced by an ExternalSorter, which was analyzed in the Shuffle Write article:
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
// Insert the data into the ExternalSorter
// (covered in the Shuffle Write analysis)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
// The result is, again, an iterator
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
Here, let's look at the implementation of ExternalSorter.iterator:
def iterator: Iterator[Product2[K, C]] = {
isShuffleSort = false
partitionedIterator.flatMap(pair => pair._2)
}
The implementation of partitionedIterator():
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
if (spills.isEmpty) {
// No spill files
if (!ordering.isDefined) {
// No key ordering required: sort by partition id only, then group by partition
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else {
// A key ordering is required: sort by partition id and key, then group by partition
groupByPartition(destructiveIterator(
collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
}
} else {
// Merge the in-memory data with the spilled files
merge(spills, destructiveIterator(
collection.partitionedDestructiveSortedIterator(comparator)))
}
}
We will stop here; for the rest, see the Shuffle Write analysis.