Spark Sql Aggregate 源码剖析
本文基于 Spark 2.3.0 源码,其他版本实现可能会略有不同
纵观 Spark Sql 源码,聚合的实现是其中较为复杂的部分,本文希望能以例子结合流程图的方式来说清楚整个过程。这里仅关注 Aggregate 在物理执行计划相关的内容,之前的 parse、analyze 及 optimize 阶段暂不做分析。在 Spark Sql 中,有一个专门的 Aggregation strategy 用来处理聚合,我们先来看看这个策略。
一、Strategy Aggregation

本文暂不讨论 distinct Aggregate 的实现(有兴趣的可以看看另一篇博文 https://www.jianshu.com/p/77e0a70db8cd),我们来看看 AggUtils#planAggregateWithoutDistinct 是如何生成聚合的物理执行计划的
二、Create Aggregate 核心流程
创建聚合分为两个阶段:
- 创建 partial agg
- 创建以
partial agg为 child 的final agg
AggregateExpression 共有以下几种 mode:
- Partial: 局部数据的聚合。会根据读入的原始数据更新对应的聚合缓冲区,当处理完所有的输入数据后,返回的是局部聚合的结果
- PartialMerge: 主要是对 Partial 返回的聚合缓冲区(局部聚合结果)进行合并,但此时仍不是最终结果,还要经过 Final 才是最终结果
- Final: 起到的作用是将聚合缓冲区的数据进行合并,然后返回最终的结果
-
Complete: 和 Partial/Final 不同,不进行局部聚合计算用于应用在不支持 Partial 模式的聚合函数上
大家常用的 min/max, avg, sum 等聚合函数都是包含 Partial 和 Final 两个 mode,也是两个阶段。举例来说 sum 函数在 map 阶段处于 Partial 模式,在 reduce 阶段的 sum 函数处于 Final 模式
2.1、partial agg

Q:是否支持使用 hash based agg 是如何判断的?

摘自我另一篇文章:https://www.jianshu.com/p/77e0a70db8cd
2.2、final agg

三、HashAggregateExec 详解
为了说明最常用也是最复杂的的 hash based agg,本小节暂时将示例 sql 改为
SELECT a, COUNT(b), COUNT(b) , SUM(c) + 100 FROM alifin_jtest_dev.testagg GROUP BY a
这样就能进入 HashAggregateExec 的分支
3.1、构造函数
构造函数主要工作就是对 groupingExpressions、aggregateExpressions、aggregateAttributes、resultExpressions 进行了初始化


3.2、HashAggregateExec#doExecute
在 enable code gen 的情况下,会调用 HashAggregateExec#inputRDDs 来生成 RDD,为了分析 HashAggregateExec 是如何生成 RDD 的,我们设置 spark.sql.codegen.wholeStage 为 false 来 disable code gen,这样就会调用 HashAggregateExec#doExecute 来生成 RDD,如下:
protected override def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
// 如果输入迭代器为空,返回一个空迭代器,这里不展开
val aggregationIterator = new TungstenAggregationIterator(partIndex, ...)
if (!hasInput && groupingExpressions.isEmpty) {
// 非分组聚合,输入迭代器不为空
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
// 分组聚合,迭代器不为空
aggregationIterator
}
}
可以看到,关键的部分就是根据 child.execute() 生成的 RDD 的每一个 partition 的迭代器转化生成一个新的 TungstenAggregationIterator,即 HashAggregateExec 生成的 RDD 的各个 partition。由于 TungstenAggregationIterator 涉及内容非常多,我们单开一大节来进行介绍。
四、TungstenAggregationIterator
此迭代器:
- 使用 hashMap(
UnsafeKVExternalSorter类型,内部存储 groupingKey 和 UnsafeRow 的映射关系)来存储 group 及其相应的agg buffer - 如果此 hashMap 无法从内存管理器分配内存(说明内存已满),则会将 hashMap spill 到磁盘并创建一个新 hashMap(如果无法创建则抛出 OOM Error)。
- 处理完所有输入后,使用 external sorter 将所有 spills merge 在一起,并进行基于 sort 的聚合。
注:UnsafeKVExternalSorter 的实现可以参考:
- https://blog.csdn.net/asongoficeandfire/article/details/53728182
- https://blog.csdn.net/asongoficeandfire/article/details/61668186
UnsafeRow 是 InternalRow(表示一行记录) 的 unsafe 实现,由原始内存(byte array)而不是 Java 对象支持,由三个区域组成:
- 空位设置位图区域:用于跟踪空(null)值
- 定长8字节值区域:为每个字段存储一个 8 字节的 word:
- 对于包含固定长度基本类型的字段,例如 long,double 或 int,我们将值直接存储在 word 中
- 对于具有非原始值或可变长度值的字段,存储指向可变长度字段的开头的相对偏移量(该行的基址)和长度(它们组合成长整数);充当指针的功能
- 可变长度数据部分
使用 UnsafeRow 的收益:
- 自主管理内存资源(自己申请、自己释放),不需要 gc
- 精确使用内存,不会浪费
4.1、构造函数

构造函数的主要流程已在上图中说明,需要注意的是:当内存不足时(毕竟每个 grouping 对应的 agg buffer 直接占用内存,如果 grouping 非常多,或者 agg buffer 较大,容易出现内存用尽)会从 hash based aggregate 切换为 sort based aggregate(会 spill 数据到磁盘),后文会进行详述。先来看看最关键的 processInputs 方法的实现
4.2、TungstenAggregationIterator#processInputs
函数 processInputs 用于读取和处理输入行
上图中,需要注意的是:hashMap 中 get 一个 groupingKey 对应的 agg buffer 时,若已经存在该 buffer 则直接返回;若不存在,尝试申请内存新建一个:
- 若成功则返回
- 若因为内存不足导致申请失败,则返回 null,这个时候就要进行 spill 了

- hashMap 在处理 rows 条数超过
Integer.MaxValue时或因内存不足无法为新的 groupingKey 分配新的 agg buffer 时,需要进行 spill。多次 spill 的数据会通过 externalSorter 进行 merge,需要注意的是这里的 merge 并不是把数据都合并了,而是externalSorter.merge(sorter)后 externalSorter 包含了 sorter 对应的 spill 文件的 reader,即可以通过 externalSorter 读取 sorter 对应的 spill 文件 - 当发生过 hashMap spill,就会从 hash based agg 切换为 sort based agg
上图中,用于真正处理一条 row 的 AggregationIterator#processRow 还需进一步展开分析。在此之前,我们先来看看 AggregateFunction 的分类
4.3、AggregateFunction 的分类
AggregateFunction 可以分为 DeclarativeAggregate 和 ImperativeAggregate 两大类,具体的聚合函数均为这两类的子类。
①. DeclarativeAggregate
DeclarativeAggregate 是一类直接由 Catalyst 中的 Expressions 构成的聚合函数,主要逻辑通过调用 4 个表达式完成,分别是:
- initialValues:聚合缓冲区初始化表达式
- updateExpressions:聚合缓冲区更新表达式,Partial mode 下
AggregationIterator#processRow会调用该方法读取一行行的输入来更新聚合聚合缓冲区 - mergeExpressions:聚合缓冲区合并表达式,Final mode 下
AggregationIterator#processRow会调用该方法来对 Partial mode 下生成的相同 groupingKey 的一个个聚合缓冲区进行 merge - evaluteExpression:最终结果生成表达式
我们再次以容易理解的 Count 来举例说明:
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
// Return data type.
override def dataType: DataType = LongType
private lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
// 聚合缓冲区初始化表达式,初始值为 0
override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)
// 聚合缓冲区更新表达式,当 input 为非 nulll 的时候对 count 加 1
override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}
// 聚合缓冲区合并表达式,将两个 agg buffer 进行 merge,这里直接进行相加
override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
// 最终结果生成表达式,即 count
override lazy val evaluateExpression = count
override def defaultResult: Option[Literal] = Option(Literal(0L))
}
通常来讲,实现一个基于 Expressions 的 DeclarativeAggregate 函数包含以下几个重要的组成部分:
- 定义一个或多个聚合缓冲区的属性(bufferAttribute),例如 Count 只需要 count,这些属性会在 updateExpressions 等各种表达式中用到
- 设定 DeclarativeAggregate 函数的初始值,count 函数的初始值为 0
- 实现数据处理逻辑表达式 updateExpressions,在 count 函数中,当处理新的数据时,上述定义的 count 属性转换为 Add 表达式,即
count + 1L,注意其中对 Null 的处理逻辑 - 实现 merge 处理逻辑的表达式,函数中直接把 count 相加,对应上述代码中的
count.left + count.right,由 DeclarativeAggregate 中定义的 RichAttribute 隐式类完成 - 实现结果输出的表达式
evaluteExpression,返回 count 的值
②. ImperativeAggregate
再来看看 AggregationIterator#processRow
4.4、AggregationIterator#processRow
AggregationIterator#processRow 会调用
def generateProcessRow(
expressions: Seq[AggregateExpression],
functions: Seq[AggregateFunction],
inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit
生成用于处理一行数据(row)的函数

说白了 processRow 生成了函数才是直接用来接受一条 input row 来更新对应的 agg buffer,具体是根据 mode 及 aggExpression 中的 aggFunction 的类型调用其 updateExpressions 或 mergeExpressions 方法:
比如,对于 aggFunction 为 DeclarativeAggregate 类型的 Partial 下的 Count 来说就是调用其 updateExpressions 方法,即:
val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}
对于 Final 的 Count 来说就是调用其 mergeExpressions 方法,即:
val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
对于 aggFunction 为 ImperativeAggregate 类型的 Partial 下的 Collect 来说就是调用其 update 方法,即:
def update(buffer: T, input: InternalRow): T = {
val value = child.eval(input)
if (value != null) {
buffer += InternalRow.copyValue(value)
}
buffer
}
对于 Final 的 Collect 来说就是调用其 merge 方法,即:
def merge(buffer: T, other: T): T = {
buffer ++= other
}
4.6、读取聚合数据
我们都知道,读取一个迭代器的数据,是要不断调用 hasNext 方法进行 check 是否还有数据,当该方法返回 true 的时候再调用 next 方法取得下一条数据。所以要知道如何读取 TungstenAggregationIterator 的数据,就得分析其这两个方法。
①、TungstenAggregationIterator#hasNext
override final def hasNext: Boolean = {
(sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
}
分为两种情况,分别是:
- sortBased 为 true,即由于发生过 spill 切换为 sort based agg 了:sortedInputHasNewGroup 表示是否还有下一条数据,该值在
switchToSortBasedAggregation初始化 - sortBased 为 false,即尚未发生过 spill,依然是 hash based agg:mapIteratorHasNext 表示是否还有下一条数据,在完成 processInputs 后进行初始化
②、TungstenAggregationIterator#next

Agg 的实现确实复杂,本文虽然篇幅已经很长,但还有很多方面没有 cover 到,但基本最核心、最复杂的点都详细介绍了,如果对于未 cover 的部分有兴趣,请自行阅读源码进行分析~
