UDAF: Custom Aggregate Functions in Practice
UDAF stands for User Defined Aggregate Function. The UserDefinedAggregateFunction API was introduced in Spark 1.5.x.
The previous lesson covered UDFs, which operate on a single row of input and return a single output.
A UDAF, by contrast, aggregates over multiple rows of input and returns a single output, which makes it considerably more powerful.
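Before the Spark code, here is a minimal plain-Scala sketch of the lifecycle a UDAF goes through: an aggregation buffer is initialized per group, updated once per input row within each partition, the partial buffers are merged across partitions, and evaluate produces the final value. The object name UdafLifecycleSketch and the bare Int buffer are illustrative assumptions, not part of the Spark API.
object UdafLifecycleSketch {
  def initialize: Int = 0                                   // empty buffer: count = 0
  def update(buffer: Int, input: String): Int = buffer + 1  // one more row seen in this partition
  def merge(b1: Int, b2: Int): Int = b1 + b2                // combine partial counts from partitions
  def evaluate(buffer: Int): Int = buffer                   // final result for the group

  def main(args: Array[String]): Unit = {
    // Two "partitions" holding rows of the same group, aggregated independently, then merged.
    val partition1 = Seq("Spark", "Spark")
    val partition2 = Seq("Spark")
    val buf1 = partition1.foldLeft(initialize)(update)
    val buf2 = partition2.foldLeft(initialize)(update)
    println(evaluate(merge(buf1, buf2)))  // prints 3
  }
}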
Example: counting how many times each string appears.
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StringCount extends UserDefinedAggregateFunction {

    // inputSchema: the type of the input data
    @Override
    public StructType inputSchema() {
        List<StructField> fieldList = new ArrayList<StructField>();
        fieldList.add(DataTypes.createStructField("str", DataTypes.StringType, true));
        return DataTypes.createStructType(fieldList);
    }

    // bufferSchema: the type of the intermediate aggregation buffer
    @Override
    public StructType bufferSchema() {
        List<StructField> fieldList = new ArrayList<StructField>();
        fieldList.add(DataTypes.createStructField("count", DataTypes.IntegerType, true));
        return DataTypes.createStructType(fieldList);
    }

    // dataType: the type of the function's return value
    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Initialize the aggregation buffer for each group
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
    }

    // Called for each input row of a group: increment the group's count
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getInt(0) + 1);
    }

    // Merge the partial aggregates computed on different nodes
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
    }

    // Produce the final result from the aggregation buffer
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(0);
    }
}
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class UDAF {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("UDAFJava").setMaster("local");
        JavaSparkContext sparkContext = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sparkContext);

        List<String> stringList = new ArrayList<String>();
        stringList.add("Feng Xiangbin");
        stringList.add("Feng Xiangbin");
        stringList.add("Feng Xiangbin");
        stringList.add("Zhao Jun");
        stringList.add("Zhao Jun");
        stringList.add("Zhao Jun");
        stringList.add("Spark");
        stringList.add("Spark");
        stringList.add("Hadoop");
        stringList.add("Hadoop");
        JavaRDD<String> rdd = sparkContext.parallelize(stringList);

        // Wrap each String in a Row so a DataFrame can be created from the RDD
        JavaRDD<Row> nameRDD = rdd.map(new Function<String, Row>() {
            @Override
            public Row call(String v1) throws Exception {
                return RowFactory.create(v1);
            }
        });
        List<StructField> fieldList = new ArrayList<StructField>();
        fieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(fieldList);
        DataFrame dataFrame = sqlContext.createDataFrame(nameRDD, structType);
        dataFrame.registerTempTable("name");

        // Register the UDAF so it can be used in SQL statements
        sqlContext.udf().register("stringCount", new StringCount());
        sqlContext.sql("select name, stringCount(name) from name group by name").javaRDD().foreach(new VoidFunction<Row>() {
            @Override
            public void call(Row row) throws Exception {
                System.out.println("row:" + row);
            }
        });

        sparkContext.close();
    }
}
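With the input list above, the query should produce one row per distinct name together with its count, e.g. [Feng Xiangbin,3], [Zhao Jun,3], [Spark,2], [Hadoop,2]; the order in which the rows are printed is not guaranteed.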
Scala version
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class StringCount extends UserDefinedAggregateFunction {
  // inputSchema: the type of the input data
  override def inputSchema: StructType = StructType(Array(StructField("str", StringType, true)))
  // bufferSchema: the type of the intermediate aggregation buffer
  override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))
  // dataType: the type of the function's return value
  override def dataType: DataType = IntegerType
  override def deterministic: Boolean = true
  // Initialize the aggregation buffer for each group
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0
  // Called whenever a group receives a new input row: update that group's running aggregate
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }
  // Because Spark is distributed, one group's rows may be partially aggregated (via update)
  // on different nodes; those partial aggregates are then combined here, via merge
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }
  // Finally, compute the group's result from the intermediate aggregation buffer
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}
object UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UDAFScala").setMaster("local")
    val sparkContext = new SparkContext(conf)
    val sqlContext = new SQLContext(sparkContext)

    val name = Array("Feng Xiangbin", "Feng Xiangbin", "Feng Xiangbin", "Zhao Jun", "Zhao Jun", "Zhao Jun", "Zhao Jun", "Spark", "Hadoop")
    val nameRDD = sparkContext.parallelize(name)
    val nameRowRDD = nameRDD.map(s => Row(s))
    val structType = StructType(Array(StructField("name", StringType, true)))
    val df = sqlContext.createDataFrame(nameRowRDD, structType)
    df.registerTempTable("name")

    // Register the UDAF so it can be used in SQL statements
    sqlContext.udf.register("stringCount", new StringCount)
    sqlContext.sql("select name, stringCount(name) from name group by name").rdd.foreach(row => println("row:" + row))
  }
}
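As an aside, the same registered UDAF can also be invoked through the DataFrame API instead of a SQL string. The sketch below assumes the df, sqlContext and the "stringCount" registration from the Scala main above; callUDF and col come from org.apache.spark.sql.functions and are available from Spark 1.5 onwards.
import org.apache.spark.sql.functions.{callUDF, col}

// Same aggregation as the SQL query, expressed with the DataFrame API
df.groupBy("name")
  .agg(callUDF("stringCount", col("name")).as("count"))
  .rdd
  .foreach(row => println("row:" + row))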