自定义聚合函数的场景
业务需要统计最接近两年某商品在门店销售价格的中位数
由于spark 原生并不支持这样的聚合操作,所这个时候自定义聚合函数产生了。
中位数:所有输入数据排序,取中间的一个结果,或者中间两个结果的平均数。
自定义聚合函数开发步骤
1、 自定义类 class,并且继承 UserDefinedAggregateFunction。
2、 重写父类方法、、以及属性。
3、 注册自方法 使用 session.udf.register。
实现类
package cn.harsons.mbd.fun
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer
/**
* 自定义聚合函数
*
* @author liyabin
* @date 2020/3/11 0011
*/
class Middle extends UserDefinedAggregateFunction {
/**
* 分割字符串
*/
val split_str = "_"
// 输入值 类型
override def inputSchema: StructType = StructType(StructField("data", DoubleType) :: Nil)
// 缓冲类型
override def bufferSchema: StructType = StructType(StructField("middle", StringType) :: Nil)
// 返回值类型
override def dataType: DataType = DoubleType
//对于数据一样的情况下 返回值时候一样
override def deterministic: Boolean = true
/**
* 初始化时调用
*
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, "")
}
/**
* 一个节点统计操作,每次输入一行记录。需要根据旧的缓冲和新来的数据 做逻辑处理
*
* @param buffer 缓冲引用
* @param input 新的值
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.get(0).asInstanceOf[String] + split_str + input.getDouble(0).toString)
}
/**
* 多条记录时如何处理 -》 其实就是两个Node计算出来的结果合并操作
*
* @param buffer1 节点一的缓冲区
* @param buffer2 节点二缓冲区
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.get(0).asInstanceOf[String] + split_str + buffer2.get(0).asInstanceOf[String])
}
/**
* 最后输出 即 函数输出。 这里作用主要是取中位数。
*
* @param buffer 汇集后的缓冲区
* @return
*/
override def evaluate(buffer: Row): Any = {
val str = buffer.get(0).asInstanceOf[String]
val arrays = str.split(split_str)
val list = new ListBuffer[Double]
for (str <- arrays) {
if (str != null && !str.isEmpty) {
list.append(str.toDouble)
}
}
if (list.isEmpty) {
return null
}
val sorted = list.sorted
var size = sorted.size
size = sorted.size
// 偶数
if (size % 2 == 0) {
val middle_first = size / 2
val middle_second = (size / 2) - 1
(sorted(middle_first) + sorted(middle_second)) / 2
} else {
sorted(size / 2)
}
}
}
执行查询
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[2]").getOrCreate()
val middle = spark.udf.register("middle", new Middle)
val data = spark.createDataFrame(Seq(
("篮球", 56.0), ("足球", 66.0), ("高尔夫", 666.0),
("篮球", 57.0), ("足球", 166.0), ("高尔夫", 424.0),
("篮球", 58.0), ("足球", 266.0), ("高尔夫", 369.0),
("篮球", 59.0), ("足球", 111.0), ("高尔夫", 99.0),
("篮球", 66.0), ("足球", 99.0), ("高尔夫", 100.0))).toDF("name", "price")
data.createOrReplaceTempView("orders")
spark.sql("select name , middle(price) as middlePrice from orders group by name ").show(10)
spark.stop()
}
结果输出
踩过的坑
楼主也是刚接触Spark,刚接触这个自定义函数时使用的是强类型自定义聚合函数。当时是想着使用ListBuffer 还缓冲列中所有结果,发现使用ListBuffer Spark 在生成代码时会报错,类型不支持。后面改成弱类型的ObjectType 也是报错。最终无奈之下只能用String 拼接。拼接完后在切割。如果大佬有好的解决办法还请赐教 !