原文:SparkSQL自定义 UDF 函数median求中位数
前言
我的场景:提供一个聚合组件操作Spark的DataFrame,然后支持先分组在聚合的功能,这里聚合要求支持最大值个数、求和、去重后求和、均值、中位数、最大值、最小值、方差、标准差、唯一值个数、唯一值、归一化等。
实现下来发现除中位数和归一化外其他聚合均有内置函数,实现起来也就很容易了。
但是在分组后计算中位数这里卡了很长时间,最后的解决办法是:自定义一个UDF函数实现分组后中位数的计算
自定义中位数函数:CustomMedian.scala
/**
* 自定义计算中位数聚合函数
* qi.wang<Email>1124602935@qq.com</Email>
*/
object CustomMedian extends UserDefinedAggregateFunction {
override def inputSchema: StructType = StructType(StructField("input", StringType) :: Nil)
override def bufferSchema: StructType = StructType(StructField("sum", StringType) :: StructField("count", StringType) :: Nil)
override def dataType: DataType = DoubleType
override def deterministic: Boolean = true // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = ""
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
buffer(0) = buffer.get(0) + "," + input.get(0)
}
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.get(0) + "," + buffer2.get(0))
}
override def evaluate(buffer: Row): Any = {
val list = new util.ArrayList[Integer]
val stringList:Array[String] = buffer.getString(0).split(",")
for (s <- stringList) {
if (StringUtils.isNotBlank(s))
list.add(s.toInt)
}
Collections.sort(list)
val size = list.size
var num:Double = 0L
if (size % 2 == 1) num = list.get(((size+1) / 2) - 1).toDouble
if (size % 2 == 0) num = (list.get(size / 2 - 1) + list.get(size / 2)) / 2.00
num
}
}
函数测试
- 造一个数据文件:/tmp/data.csv, 内容如下
id|name|mobile|idnumber
10|aa|11111111111|111111111111111111
12|bb|12321321321|213123123213333333
13|aa|21312332322|333333333333333334
15|dd|23114567888|872837482374932794
17|bb|44444444444|827183787373733333
18|bb|55555555555|823048320999399999
- 测试代码
package www.relaxheart.cn
import www.relaxheart.cn.CustomMedian
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._
import scala.util.Random
/**
* @author 王琦<QQ.Email>1124602935@qq.com</QQ.Email>
* @date 19/8/13 下午20:33
* @description
*/
object MedianUDFTest extends App {
val spark = SparkSession.builder().master("local[*]").appName("MedianUDFTest").config("spark.sql.crossJoin.enabled", "true").getOrCreate()
// 读取data.csv得到RDD
val rdd = spark.sparkContext.textFile("/tmp/data.csv")
// 从第一行数据中获取最后转成的DataFrame应该有多少列 并给每一列命名
val colNames = rdd.first.split("\\|")
// 设置DataFrame的结构
val schema = StructType(colNames.map(fieldName => StructField(fieldName, StringType)))
// 对每一行的数据进行处理
val rowRDD = rdd.filter(_.split("\\|")(0) != "id").map(_.split("\\|")).map(p => Row(p: _*))
// 创建DataFrame
val data = spark.createDataFrame(rowRDD, schema)
// 创建临时表
val tmpTable = "_table"+System.currentTimeMillis()+Random.nextInt(10000000)
data.createOrReplaceTempView(tmpTable)
// 这步很关键,注册我们的自定义中位数函数
spark.udf.register("median", CustomMedian)
// 利用SparkSQL + 自定义中位数函数实现分组后求中位数
// 这里对测试数据按name进行分组,然后组内id的中位数
val medianGroupDF = spark.sql(s"select name , median(id) as median from $tmpTable group by name")
// 打印分组中位数聚合结果
medianGroupDF.show()
}
结果验证
看打印结果是符合我们预期的。