Spark Sql Aggregate 源码剖析
本文基于 Spark 2.3.0 源码,其他版本实现可能会略有不同
纵观 Spark Sql 源码,聚合的实现是其中较为复杂的部分,本文希望能以例子结合流程图的方式来说清楚整个过程。这里仅关注 Aggregate 在物理执行计划相关的内容,之前的 parse、analyze 及 optimize 阶段暂不做分析。在 Spark Sql 中,有一个专门的 Aggregation
strategy 用来处理聚合,我们先来看看这个策略。
一、Strategy Aggregation
本文暂不讨论 distinct Aggregate 的实现(有兴趣的可以看看另一篇博文 https://www.jianshu.com/p/77e0a70db8cd),我们来看看 AggUtils#planAggregateWithoutDistinct
是如何生成聚合的物理执行计划的
二、Create Aggregate 核心流程
创建聚合分为两个阶段:
- 创建 partial agg
- 创建以
partial agg
为 child 的final agg
AggregateExpression 共有以下几种 mode:
- Partial: 局部数据的聚合。会根据读入的原始数据更新对应的聚合缓冲区,当处理完所有的输入数据后,返回的是局部聚合的结果
- PartialMerge: 主要是对 Partial 返回的聚合缓冲区(局部聚合结果)进行合并,但此时仍不是最终结果,还要经过 Final 才是最终结果
- Final: 起到的作用是将聚合缓冲区的数据进行合并,然后返回最终的结果
-
Complete: 和 Partial/Final 不同,不进行局部聚合计算用于应用在不支持 Partial 模式的聚合函数上
大家常用的 min/max, avg, sum 等聚合函数都是包含 Partial 和 Final 两个 mode,也是两个阶段。举例来说 sum 函数在 map 阶段处于 Partial 模式,在 reduce 阶段的 sum 函数处于 Final 模式
2.1、partial agg
Q:是否支持使用 hash based agg 是如何判断的?
摘自我另一篇文章:https://www.jianshu.com/p/77e0a70db8cd
2.2、final agg
三、HashAggregateExec 详解
为了说明最常用也是最复杂的的 hash based agg,本小节暂时将示例 sql 改为
SELECT a, COUNT(b), COUNT(b) , SUM(c) + 100 FROM alifin_jtest_dev.testagg GROUP BY a
这样就能进入 HashAggregateExec 的分支
3.1、构造函数
构造函数主要工作就是对 groupingExpressions、aggregateExpressions、aggregateAttributes、resultExpressions 进行了初始化
3.2、HashAggregateExec#doExecute
在 enable code gen 的情况下,会调用 HashAggregateExec#inputRDDs
来生成 RDD,为了分析 HashAggregateExec 是如何生成 RDD 的,我们设置 spark.sql.codegen.wholeStage
为 false 来 disable code gen,这样就会调用 HashAggregateExec#doExecute
来生成 RDD,如下:
protected override def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
// 如果输入迭代器为空,返回一个空迭代器,这里不展开
val aggregationIterator = new TungstenAggregationIterator(partIndex, ...)
if (!hasInput && groupingExpressions.isEmpty) {
// 非分组聚合,输入迭代器不为空
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
// 分组聚合,迭代器不为空
aggregationIterator
}
}
可以看到,关键的部分就是根据 child.execute()
生成的 RDD 的每一个 partition 的迭代器转化生成一个新的 TungstenAggregationIterator
,即 HashAggregateExec 生成的 RDD 的各个 partition。由于 TungstenAggregationIterator 涉及内容非常多,我们单开一大节来进行介绍。
四、TungstenAggregationIterator
此迭代器:
- 使用 hashMap(
UnsafeKVExternalSorter
类型,内部存储 groupingKey 和 UnsafeRow 的映射关系)来存储 group 及其相应的agg buffer
- 如果此 hashMap 无法从内存管理器分配内存(说明内存已满),则会将 hashMap spill 到磁盘并创建一个新 hashMap(如果无法创建则抛出 OOM Error)。
- 处理完所有输入后,使用 external sorter 将所有 spills merge 在一起,并进行基于 sort 的聚合。
注:UnsafeKVExternalSorter 的实现可以参考:
- https://blog.csdn.net/asongoficeandfire/article/details/53728182
- https://blog.csdn.net/asongoficeandfire/article/details/61668186
UnsafeRow 是 InternalRow(表示一行记录) 的 unsafe 实现,由原始内存(byte array)而不是 Java 对象支持,由三个区域组成:
- 空位设置位图区域:用于跟踪空(null)值
- 定长8字节值区域:为每个字段存储一个 8 字节的 word:
- 对于包含固定长度基本类型的字段,例如 long,double 或 int,我们将值直接存储在 word 中
- 对于具有非原始值或可变长度值的字段,存储指向可变长度字段的开头的相对偏移量(该行的基址)和长度(它们组合成长整数);充当指针的功能
- 可变长度数据部分
使用 UnsafeRow 的收益:
- 自主管理内存资源(自己申请、自己释放),不需要 gc
- 精确使用内存,不会浪费
4.1、构造函数
构造函数的主要流程已在上图中说明,需要注意的是:当内存不足时(毕竟每个 grouping 对应的 agg buffer 直接占用内存,如果 grouping 非常多,或者 agg buffer 较大,容易出现内存用尽)会从 hash based aggregate 切换为 sort based aggregate(会 spill 数据到磁盘),后文会进行详述。先来看看最关键的 processInputs
方法的实现
4.2、TungstenAggregationIterator#processInputs
函数 processInputs 用于读取和处理输入行
上图中,需要注意的是:hashMap 中 get 一个 groupingKey 对应的 agg buffer 时,若已经存在该 buffer 则直接返回;若不存在,尝试申请内存新建一个:
- 若成功则返回
- 若因为内存不足导致申请失败,则返回 null,这个时候就要进行 spill 了
- hashMap 在处理 rows 条数超过
Integer.MaxValue
时或因内存不足无法为新的 groupingKey 分配新的 agg buffer 时,需要进行 spill。多次 spill 的数据会通过 externalSorter 进行 merge,需要注意的是这里的 merge 并不是把数据都合并了,而是externalSorter.merge(sorter)
后 externalSorter 包含了 sorter 对应的 spill 文件的 reader,即可以通过 externalSorter 读取 sorter 对应的 spill 文件 - 当发生过 hashMap spill,就会从 hash based agg 切换为 sort based agg
上图中,用于真正处理一条 row 的 AggregationIterator#processRow
还需进一步展开分析。在此之前,我们先来看看 AggregateFunction 的分类
4.3、AggregateFunction 的分类
AggregateFunction
可以分为 DeclarativeAggregate
和 ImperativeAggregate
两大类,具体的聚合函数均为这两类的子类。
①. DeclarativeAggregate
DeclarativeAggregate 是一类直接由 Catalyst 中的 Expressions 构成的聚合函数,主要逻辑通过调用 4 个表达式完成,分别是:
- initialValues:聚合缓冲区初始化表达式
- updateExpressions:聚合缓冲区更新表达式,Partial mode 下
AggregationIterator#processRow
会调用该方法读取一行行的输入来更新聚合聚合缓冲区 - mergeExpressions:聚合缓冲区合并表达式,Final mode 下
AggregationIterator#processRow
会调用该方法来对 Partial mode 下生成的相同 groupingKey 的一个个聚合缓冲区进行 merge - evaluteExpression:最终结果生成表达式
我们再次以容易理解的 Count
来举例说明:
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
// Return data type.
override def dataType: DataType = LongType
private lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
// 聚合缓冲区初始化表达式,初始值为 0
override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)
// 聚合缓冲区更新表达式,当 input 为非 nulll 的时候对 count 加 1
override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}
// 聚合缓冲区合并表达式,将两个 agg buffer 进行 merge,这里直接进行相加
override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
// 最终结果生成表达式,即 count
override lazy val evaluateExpression = count
override def defaultResult: Option[Literal] = Option(Literal(0L))
}
通常来讲,实现一个基于 Expressions 的 DeclarativeAggregate 函数包含以下几个重要的组成部分:
- 定义一个或多个聚合缓冲区的属性(bufferAttribute),例如 Count 只需要 count,这些属性会在 updateExpressions 等各种表达式中用到
- 设定 DeclarativeAggregate 函数的初始值,count 函数的初始值为 0
- 实现数据处理逻辑表达式 updateExpressions,在 count 函数中,当处理新的数据时,上述定义的 count 属性转换为 Add 表达式,即
count + 1L
,注意其中对 Null 的处理逻辑 - 实现 merge 处理逻辑的表达式,函数中直接把 count 相加,对应上述代码中的
count.left + count.right
,由 DeclarativeAggregate 中定义的 RichAttribute 隐式类完成 - 实现结果输出的表达式
evaluteExpression
,返回 count 的值
②. ImperativeAggregate
再来看看 AggregationIterator#processRow
4.4、AggregationIterator#processRow
AggregationIterator#processRow
会调用
def generateProcessRow(
expressions: Seq[AggregateExpression],
functions: Seq[AggregateFunction],
inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit
生成用于处理一行数据(row)的函数
说白了 processRow 生成了函数才是直接用来接受一条 input row 来更新对应的 agg buffer,具体是根据 mode 及 aggExpression
中的 aggFunction
的类型调用其 updateExpressions 或 mergeExpressions 方法:
比如,对于 aggFunction
为 DeclarativeAggregate 类型的 Partial
下的 Count
来说就是调用其 updateExpressions
方法,即:
val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}
对于 Final
的 Count
来说就是调用其 mergeExpressions
方法,即:
val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
对于 aggFunction
为 ImperativeAggregate
类型的 Partial
下的 Collect
来说就是调用其 update
方法,即:
def update(buffer: T, input: InternalRow): T = {
val value = child.eval(input)
if (value != null) {
buffer += InternalRow.copyValue(value)
}
buffer
}
对于 Final
的 Collect
来说就是调用其 merge
方法,即:
def merge(buffer: T, other: T): T = {
buffer ++= other
}
4.6、读取聚合数据
我们都知道,读取一个迭代器的数据,是要不断调用 hasNext
方法进行 check 是否还有数据,当该方法返回 true
的时候再调用 next
方法取得下一条数据。所以要知道如何读取 TungstenAggregationIterator 的数据,就得分析其这两个方法。
①、TungstenAggregationIterator#hasNext
override final def hasNext: Boolean = {
(sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
}
分为两种情况,分别是:
- sortBased 为 true,即由于发生过 spill 切换为 sort based agg 了:sortedInputHasNewGroup 表示是否还有下一条数据,该值在
switchToSortBasedAggregation
初始化 - sortBased 为 false,即尚未发生过 spill,依然是 hash based agg:mapIteratorHasNext 表示是否还有下一条数据,在完成 processInputs 后进行初始化
②、TungstenAggregationIterator#next
Agg 的实现确实复杂,本文虽然篇幅已经很长,但还有很多方面没有 cover 到,但基本最核心、最复杂的点都详细介绍了,如果对于未 cover 的部分有兴趣,请自行阅读源码进行分析~