Project Tungsten: Bringing Apache Spark Closer to Bare Metal
A UDAF must be registered with sqlContext.udf.register(name, udaf) before it can be invoked by name.
A UDAF extends the base class UserDefinedAggregateFunction and implements the following eight methods:
inputSchema: inputSchema returns a StructType, and every field of this StructType represents an input argument of this UDAF.
bufferSchema: bufferSchema returns a StructType, and every field of this StructType represents a field of this UDAF's intermediate results.
dataType: dataType returns a DataType representing the data type of this UDAF's returned value.
deterministic: deterministic returns a boolean indicating whether this UDAF always generates the same result for a given set of input values.
initialize: initialize is used to initialize the values of an aggregation buffer, represented by a MutableAggregationBuffer.
update: update is used to update an aggregation buffer, represented by a MutableAggregationBuffer, for an input Row.
merge: merge is used to merge two aggregation buffers and store the result in a MutableAggregationBuffer.
evaluate: evaluate is used to generate the final result value of this UDAF based on the values stored in an aggregation buffer, represented by a Row.
https://databricks.com/blog/2015/09/16/apache-spark-1-5-dataframe-api-highlights.html
Definition:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class GeometricMean extends UserDefinedAggregateFunction {
  // Schema of the input data
  def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Schema of the buffer that holds intermediate results
  def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("product", DoubleType) :: Nil
  )

  // Data type of the returned result
  def dataType: DataType = DoubleType

  // Deterministic: repeated runs over the same input give the same result
  def deterministic: Boolean = true

  // Initialize the buffer: slot 0 holds the count, slot 1 holds the product
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  // Update: increment the count and multiply the running product by the input value
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
  }

  // Merge two buffers: add the counts and multiply the products
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
  }

  // Compute the final result: product ^ (1 / count)
  def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
  }
}
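To make the roles of initialize, update, merge and evaluate more concrete, here is a minimal plain-Scala sketch that mimics how the (count, product) buffer could be driven across partitions. It does not use Spark at all: the Buffer case class, the UdafLifecycleSketch object and the aggregate helper are hypothetical names introduced only for illustration, and the partition handling is a simplified model rather than Spark's actual execution path.

```scala
object UdafLifecycleSketch {
  // Hypothetical stand-in for the (count, product) aggregation buffer.
  final case class Buffer(count: Long, product: Double)

  // Mirrors initialize: count starts at 0, product at 1.0.
  val initial: Buffer = Buffer(0L, 1.0)

  // Mirrors update: fold one input value into a buffer.
  def update(b: Buffer, x: Double): Buffer =
    Buffer(b.count + 1, b.product * x)

  // Mirrors merge: combine two partial buffers.
  def merge(a: Buffer, b: Buffer): Buffer =
    Buffer(a.count + b.count, a.product * b.product)

  // Mirrors evaluate: product ^ (1 / count).
  def evaluate(b: Buffer): Double =
    math.pow(b.product, 1.0 / b.count)

  // Aggregate every partition on its own, then merge the partial buffers.
  def aggregate(partitions: Seq[Seq[Double]]): Double = {
    val partials = partitions.map(_.foldLeft(initial)(update))
    evaluate(partials.foldLeft(initial)(merge))
  }

  def main(args: Array[String]): Unit = {
    val data = (1 to 10).map(_.toDouble)
    val onePartition = aggregate(Seq(data))
    val twoPartitions = aggregate(Seq(data.take(4), data.drop(4)))
    println(s"one partition: $onePartition, two partitions: $twoPartitions")
  }
}
```

Because merge only adds counts and multiplies products, the result should not depend on how the rows are split into partitions, which is the property the two calls in main are meant to illustrate.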
Usage:
import org.apache.spark.sql.functions._

// Create a simple DataFrame with a single column called "id"
// containing the numbers 1 to 10.
val df = sqlContext.range(1, 11) // the generated column is named "id" by default

// Create an instance of the UDAF GeometricMean.
val gm = new GeometricMean

// Show the geometric mean of the values in column "id".
df.groupBy().agg(gm(col("id")).as("GeometricMean")).show()

// Register the UDAF and call it "gm".
sqlContext.udf.register("gm", gm)

// Invoke the UDAF by its assigned name.
df.groupBy().agg(expr("gm(id) as GeometricMean")).show()
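As a sanity check on the value both show() calls should report, the geometric mean of 1 to 10 can be computed directly in plain Scala with the same formula that evaluate uses. This is only a sketch for cross-checking: GeometricMeanCheck and geometricMean are hypothetical names, and no Spark is involved.

```scala
object GeometricMeanCheck {
  // Same formula as GeometricMean.evaluate: product ^ (1 / count).
  def geometricMean(values: Seq[Double]): Double =
    math.pow(values.product, 1.0 / values.size)

  def main(args: Array[String]): Unit = {
    // Mirrors the contents of sqlContext.range(1, 11): the values 1 to 10.
    val ids = (1 to 10).map(_.toDouble)
    println(geometricMean(ids))
  }
}
```

The product of 1 to 10 is 3628800, so the expected result is its tenth root, roughly 4.53.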