1. 前言
因为最近工作中有需要自定义udf,所以本文记录下最近所了解到的udf的知识。主要讲述hive中如何自定义udf,至于udf一些原理性的东西,比如udf在mr过程中怎么起作用的,这个涉及到hive的细节,我也不清楚,所以本文不会涉及,知道多少写多少吧。
2. UDF分类
hive中udf主要分为三类:
- 标准UDF
这种类型的udf每次接受的输入是一行数据中的一个列或者多个列(下面我把这个一行的一列叫着一个单元吧,类似表格中的一个单元格),然后输出是一个单元。比如abs, array,asin这种都是标准udf。
自定义标准函数需要继承实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDF
- 自定义聚合函数(UDAF)
比如max,min这种函数都是hive内置聚合函数。聚合函数和标准udf的区别是:聚合函数需要接收多行输入才能计算出结果,比如max就需要接收表中所有数据(或者group by中分组内所有数据)才能计算出最大值。
自定义聚合函数需要实现抽象类org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
- 自定义表生成函数(UDTF)
上面1,2中的udf都只输出一个标量的数据(一个单元)。表生成函数故名思义,其输出有点像子查询,可以一次输出多行多列。
自定义表生成函数需要实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
。
2. 自定义UDF
引入maven依赖
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>2.3.0</version>
</dependency>
2.1 自定义标准UDF
2.1.1 实现抽象类GenericUDF
该类的全路径为:org.apache.hadoop.hive.ql.udf.generic.GenericUDF
1. 抽象类GenericUDF解释
GenericUDF类如下:
public abstract class GenericUDF implements Closeable {
...
/* 实例化后initialize方法只会调用一次
- 参数arguments即udf接收的参数列表对应的objectinspector
- 返回的ObjectInspector对象就是udf返回值的对应的objectinspector
initialize方法中往往做的工作是检查一下arguments是否和你udf需要的参数个数以及类型是否匹配。
*/
public abstract ObjectInspector initialize(ObjectInspector[] arguments)
throws UDFArgumentException;
...
// 真正的udf逻辑在这里实现
// - 参数arguments即udf函数输入数据,这个数组的长度和initialize的参数长度一样
//
public abstract Object evaluate(DeferredObject[] arguments)
throws HiveException;
}
GenericUDF有很多的方法,但是只有上面两个抽象方法需要自己实现。
关于ObjectInspector,HIVE在传递数据时会包含数据本身以及对应的ObjectInspector,ObjectInspector中包含数据类型信息,通过oi去解析获得数据。
2.1.2 实例
假设这里要实现下面这种功能标准udf:
cycle_range(col_name, num)
它的接收一列,以及一了整数值为参数,然后将这列转换为一个index(index 属于[0,num))到列值的映射,像下面这样:
> SELECT cycle_range(name, 3) FROM src_table;
INDEX NAME
{1, "eric"}
{2, "aaron"}
{0, "john"}
{1, "marry"}
{2, "hellen"}
{0, "jerry"}
{1, "ellen"}
...
这里定义一个叫cycle_range的标准udf去实现列值的转换,实现如下:
/**
这里使用注解描述udf信息,当使用beeline命令'describe function cycle_range'时,会输出value中的介绍信息,其中_FUNC_会被替换成真实的udf名称。
*/
@Description(name = "cycle_range",
value = "_FUNC_(x, num) - return a map containes an index as key and x as value",
extended = "Example:\n"
+ " > SELECT _FUNC_(x, 3) FROM src;\n"
+ "{i,x}, i in [0 - 3)\n"
)
public class GenericUDFRange extends GenericUDF {
// 第二个参数是一个整型常量,放在这里
private static LongWritable rangeNum = null;
// index 递增并对rangeNum取模的结果
private static Long index = 0L;
// udf 返回值
private transient Map<Object,Object> ret = new HashMap<Object,Object>();
// udf参数整型常量可以是BYTE/SHORT/INT/LONG 这个converter将它们都转换成long处理
private transient ObjectInspectorConverters.Converter rangeConverter;
// 在inittialize里检查一下参数个数与类型
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
// 只接受两个参数
if(arguments.length != 2){
throw new UDFArgumentException(
"RANGE() requires 2 arguments, got " + arguments.length
);
}
// 第二个参数必须是PRIMITIVE这一类的,这是hive sql内置的类型,可以对应到java的primitive type,此外还必须是BYTE/SHORT/INT/LONG之一
if(arguments[1].getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentException(
"RANGE() only take primitive Integer type, got " + arguments[1].getTypeName()
);
}
PrimitiveObjectInspector poi = (PrimitiveObjectInspector)arguments[1];
// 获取到第二个参数的具体类型枚举
PrimitiveObjectInspector.PrimitiveCategory rangeNumType = poi.getPrimitiveCategory();
ObjectInspector outputInspector = null;
switch (rangeNumType){
case BYTE:
case SHORT:
case INT:
case LONG:
// 以上4个case是合法类型,获得一个converter将这四类都转换成WritableLong处理
rangeConverter = ObjectInspectorConverters.getConverter(
arguments[1], PrimitiveObjectInspectorFactory.writableLongObjectInspector
);
// udf的输出值的oi,输出的是一个map对应的ObjectInspector, key是long,value还是原来的列的oi
outputInspector = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaLongObjectInspector, arguments[0]);
return outputInspector;
default:
throw new UDFArgumentException(
"RANGE only takes BYTE/SHORT/INT/LONG types as the second arguments type, got " + arguments[1].getTypeName()
);
}
}
// 这里开始接收实际的一行一行的输入数据,然后返回处理后的值
// deferredObjects应该包含两个值,第一个值是列的值,第二个值是那个整型常量range值
public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
// 拿到range值
Object rangeObject = deferredObjects[1].get();
if(rangeNum == null){
rangeNum = new LongWritable();
// 用coverter都转换成LongWritable,然后保存起来.
rangeObject = rangeConverter.convert(rangeObject);
rangeNum.set(Math.abs(((LongWritable)rangeObject).get()));
}
// 计算index 对range的模
index = (index + 1) % rangeNum.get();
// 由于ret是这个udf实例的成员,用来保存返回的map,而evaluate又会不停的调用,所以这里put前都会clear一下,保证始终只有当前处理后的返回值。
ret.clear();
// 设置返回值,返回
ret.put(index, deferredObjects[0].get());
return ret;
}
public String getDisplayString(String[] strings) {
return getStandardDisplayString("range", strings,",");
}
}
编写好后:
- 打jar包,最好打fat jar,把依赖都打进去,假设我的jar包的路径:"/Users/eric/udf-1.0-SNAPSHOT.jar"
- 在beeline 终端将jar加入hive的classpath:
add jar /Users/eric/udf-1.0-SNAPSHOT.jar
- 创建udf
create temporary function cycle_range as 'me.eric.udfs.GenericUDFRange'
成功后就可以使用了。
2.2 自定义聚合函数UDAF
2.2.1 实现抽象类AbstractGenericUDAFResolver
实现自定义UDAF首先要继承并实现类AbstractGenericUDAFResolver
,有下面两个方法:
public abstract class AbstractGenericUDAFResolver
implements GenericUDAFResolver2
{
@SuppressWarnings("deprecation")
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
throws SemanticException {
if (info.isAllColumns()) {
throw new SemanticException(
"The specified syntax for UDAF invocation is invalid.");
}
return getEvaluator(info.getParameters());
}
/**
由于上面的getEvaluator也是调用的这个方法实现,所以只需要重写着这个
getEvaluator即可。 udaf函数的主要逻辑不是getEvaluator方法里里完成的。
而是在其返回的GenericUDAFEvaluator中实现的,那么在getEvaluator方法中往往只需要根据参数info(info中保存了传递给udaf的实际参数信息)做一下udaf的参数类型检查即可,
然后返回用户自定义的GenericUDAFEvaluator。
*/
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] info)
throws SemanticException {
throw new SemanticException(
"This UDAF does not support the deprecated getEvaluator() method.");
}
上面介绍中说到GenericUDAFEvaluator才是真正实现udaf业务逻辑的地方,下面是GenericUDAFEvaluator抽象类的的实现:
public abstract class GenericUDAFEvaluator implements Closeable {
@Retention(RetentionPolicy.RUNTIME)
public static @interface AggregationType {
boolean estimable() default false;
}
...
public static enum Mode {
/**
* PARTIAL1: from original data to partial aggregation data: iterate() and
* terminatePartial() will be called.
*/
PARTIAL1,
/**
* PARTIAL2: from partial aggregation data to partial aggregation data:
* merge() and terminatePartial() will be called.
*/
PARTIAL2,
/**
* FINAL: from partial aggregation to full aggregation: merge() and
* terminate() will be called.
*/
FINAL,
/**
* COMPLETE: from original data directly to full aggregation: iterate() and
* terminate() will be called.
*/
COMPLETE
};
Mode mode;
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
// This function should be overriden in every sub class
// And the sub class should call super.init(m, parameters) to get mode set.
mode = m;
return null;
}
public abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
public abstract void reset(AggregationBuffer agg) throws HiveException;
public void aggregate(AggregationBuffer agg, Object[] parameters) throws HiveException {
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
iterate(agg, parameters);
} else {
assert (parameters.length == 1);
merge(agg, parameters[0]);
}
}
public Object evaluate(AggregationBuffer agg) throws HiveException {
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
return terminatePartial(agg);
} else {
return terminate(agg);
}
}
public abstract void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
public abstract Object terminatePartial(AggregationBuffer agg) throws HiveException;
public abstract void merge(AggregationBuffer agg, Object partial) throws HiveException;
public abstract Object terminate(AggregationBuffer agg) throws HiveException;
首先是上面枚举类型Mode的几个枚举值:PARTIAL1,PARTIAL2,FINAL,COMPLETE, 同时mode也是方法init的参数。这几个枚举值跟聚合涉及到的过程有关系, map-reduce中聚合往往涉及到shuffle的过程,这其中又可能涉及到map端的combine,然后map到reduce过程中数据的shuffle,然后在在reduce端merge。
下面这张图大概的描述了一下各个阶段的对应关系:
这张图中没有包含COMPLETE,从上面代码中COMPLETE的注释可以看出来,COMPLETE表示直接从原始数据聚合到最终结果,也就是说不存在中间需要先在map端完成部分聚合结果,然后再到reduce端完成最终聚合一个过程,COMPLETE出现在一个完全map only的任务中,所以没有和其他三个阶段一起出现。
上图描述了三个阶段调用的方法,这也就是需要自己实现的方法:
- PARTIAL1
- iterate(AggregationBuffer agg, Object[] parameters)
AggregationBuffer是一个需要你实现的数据结构,用来临时保存聚合的数据,parameters是传递给udaf的实际参数,这个方法的功能可以描述成: 拿到一条条数据记录方法在parameters里,然后聚合到agg中,怎么聚合自己实现,比如agg就是一个数组,你把所有迭代的数据保存到数组中都可以。agg相当于返回结果, - terminatePartial(AggregationBuffer agg)
iterate迭代了map中的数据并保存到agg中,并传递给terminatePartial,接下来terminatePartial完成计算,terminatePartial返回Object类型结果显然还是要传递给下一个阶段PARTIAL2的,但是PARTIAL2怎么知道Object到底是什么?前面提到HIVE都是通过ObjectInspector来获取数据类型信息的,但是PARTIAL2的输入数据ObjectInspector怎么来的?显然每个阶段输出数据对应的ObjectInspector只有你自己知道,上面代码中还有一个init()方法是需要你实现了(init在每一个阶段都会调用一次 ),init的参数m表明了当前阶段(当前处于PARTIAL1),你需要在init中根据当前阶段m,设置一个ObjectInspector表示当前的输出oi就行了,init返回一个ObjectInspcetor表示当前阶段的输出数据类信息(也就是下一阶段的输入数据信息)。
- iterate(AggregationBuffer agg, Object[] parameters)
- PARTIAL2
PARTIAL2的输入是基于PARTIAL1的输出的,PARTIAL1输出即terminatePartial的返回值。- merge(AggregationBuffer agg, Object partial)
agg和partial1中的一样,既是参数,也是返回值。partial就是partial1中terminatePartial的返回值,partial的具体数据信息需要你根据ObjectInspector获取了。merger就表示把partial值先放到agg里,待会计算。 - terminatePartial
和partial1一样。
- merge(AggregationBuffer agg, Object partial)
- FINAL
FINAL进入到reduce阶段,也就是要完成最终结果的计算,和PARTIAL2不同的是它调用terminate,没什么好说的,输出最终结果而已。
关于init方法,方法原型:
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
这个方法会在每个阶段都会调用一次,参数m表示当前调用的阶段,parameters表示当前阶段输入数据的oi。前面提到partial1的terminatePartial的输出就是partial2的输入数据,那么此时partial1的输出数据对应的oi,应该和partial2时调用init的参数parameters对应起来才能保存不出错。
2.2.1 UDAF实例
这里实现的udaf的实例,他完成如下功能:
> SELECT col_concat(id, '<' , '>', ',' ) FROM person;
输出:
<1,2,3,4,5,6>
udaf实现将某一个使用特定符号连接起来,并使用另外的字符包围左右。
第一个参数就是列名,然后open,close,seperator
代码如下:
public class GenericUDAFColConcat extends AbstractGenericUDAFResolver{
public GenericUDAFColConcat() {
}
/**
在getEvaluator中做一些类型检查,
*/
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
// col_concat这个udaf需要接收4个参数
if(parameters.length != 4){
throw new UDFArgumentTypeException(parameters.length - 1,
"COL_CONCAT requires 4 argument, got " + parameters.length);
}
// 且只能用于连接一下PRIMITIVE类型的列
if(parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentTypeException(0,
"COL_CONCAT can only be used to concat PRIMITIVE type column, got " + parameters[0].getTypeName());
}
// 分隔符和包围符,只能时char或者STRING
for(int i = 1; i < parameters.length; ++i){
if(parameters[i].getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentTypeException(i,
"COL_CONCAT only receive type CHAR/STRING as its 2nd to 4th argument's type, got " + parameters[i].getTypeName());
}
PrimitiveObjectInspector poi = (PrimitiveObjectInspector) TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[i]);
if(poi.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.CHAR &&
poi.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
throw new UDFArgumentTypeException(i,
"COL_CONCAT only receive type CHAR/STRING as its 2nd to 4th argument's type, got " + parameters[i].getTypeName());
}
}
// 返回自定义的XXXEvaluator
return new GenericUDAFCOLCONCATEvaluator();
}
// 前一节也说过需要实现AbstractAggregationBuffer用来保存聚合的值
private static class ColCollectAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer{
// 遍历的列值暂时都方放到列表中保存起来。
private List<String> colValueList ;
private String open;
private String close;
private String seperator;
private boolean isInit;
public ColCollectAggregationBuffer() {
colValueList = new LinkedList<>();
this.isInit = false;
}
public void init(String open, String close, String seperator){
this.open = open;
this.close = close;
this.seperator = seperator;
this.isInit = true;
}
public boolean isInit(){
return isInit;
}
public String concat(){
String c = StringUtils.join(colValueList,seperator);
return open + c + close;
}
}
public static class GenericUDAFCOLCONCATEvaluator extends GenericUDAFEvaluator{
// transient避免序列化,因为这些成员其实都是在init中初始化了,没有序列化的意义
// inputOIs用来保存PARTIAL1和COMPELE输入数据的oi,这个各个阶段都可能不一样
private transient List<ObjectInspector> inputOIs = new LinkedList<>();
private transient Mode m;
private transient String pString;
// soi保存PARTIAL2和FINAL的输入数据的oi
private transient StructObjectInspector soi;
private transient ListObjectInspector valueFieldOI;
private transient PrimitiveObjectInspector openFieldOI;
private transient PrimitiveObjectInspector closeFieldOI;
private transient PrimitiveObjectInspector seperatorFieldOI;
private transient StructField valueField;
private transient StructField openField;
private transient StructField closeField;
private transient StructField seperatorField;
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
// 父类的init必须调用
super.init(m,parameters);
this.m = m;
pString = "";
for(ObjectInspector p : parameters){
pString += p.getTypeName();
}
if(m == Mode.PARTIAL1 || m == Mode.COMPLETE){
// 在PARTIAL1和COMPLETE阶段,输入数据都是原始表中数据,而不是中间聚合数据,这里初始化inputOIs
inputOIs.clear();
for(ObjectInspector p : parameters){
inputOIs.add((PrimitiveObjectInspector)p);
}
}else {
// FINAL和PARTIAL2的输入数据OI都是上一阶段的输出,而不是原始表数据,这里parameter[0]其实就是上一阶段的输出oi,具体情况看下面
soi = (StructObjectInspector)parameters[0];
valueField = soi.getStructFieldRef("values");
valueFieldOI = (ListObjectInspector)valueField.getFieldObjectInspector();
openField = soi.getStructFieldRef("open");
openFieldOI = (PrimitiveObjectInspector) openField.getFieldObjectInspector();
closeField = soi.getStructFieldRef("close");
closeFieldOI = (PrimitiveObjectInspector)closeField.getFieldObjectInspector();
seperatorField = soi.getStructFieldRef("seperator");
seperatorFieldOI = (PrimitiveObjectInspector)seperatorField.getFieldObjectInspector();
}
// 这里开始返回各个阶段的输出OI
if(m == Mode.PARTIAL1 || m == Mode.PARTIAL2){
// 后面的terminatePartial实现中,PARTIAL1 PARTIAL2的输出数据都是一个列表,我把中间聚合和结果values, 以及open,close, seperator
// 按序方到列表中,所以这个地方返回的oi是一个StructObjectInspector的实现类,它能够获取list中的值。
ArrayList<ObjectInspector> foi = new ArrayList<>();
foi.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
foi.add(
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
foi.add(
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
foi.add(
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
ArrayList<String> fname = new ArrayList<String>();
fname.add("values");
fname.add("open");
fname.add("close");
fname.add("seperator");
return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
}else{
// COMPLETE和FINAL都是返回最终聚合结果了,也就是String,所以这里返回javaStringObjectInspector即可
return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
}
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
return new ColCollectAggregationBuffer();
}
@Override
public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
((ColCollectAggregationBuffer)aggregationBuffer).colValueList.clear();
}
// PARTIAL1和COMPLETE调用,iterate里就是把原始数据(参数objects[0])中的值保存到aggregationBuffer的列表中
@Override
public void iterate(AggregationBuffer aggregationBuffer, Object[] objects) throws HiveException {
assert objects.length == 4;
ColCollectAggregationBuffer ccAggregationBuffer = (ColCollectAggregationBuffer)aggregationBuffer;
ccAggregationBuffer.colValueList.add(
PrimitiveObjectInspectorUtils.getString(objects[0], (PrimitiveObjectInspector)inputOIs.get(0)));
if(!ccAggregationBuffer.isInit()){
ccAggregationBuffer.init(
PrimitiveObjectInspectorUtils.getString(objects[1], (PrimitiveObjectInspector)inputOIs.get(1)),
PrimitiveObjectInspectorUtils.getString(objects[2],(PrimitiveObjectInspector)inputOIs.get(2)),
PrimitiveObjectInspectorUtils.getString(objects[3],(PrimitiveObjectInspector)inputOIs.get(3))
);
}
}
// PARTIAL1和PARTIAL2调用,没做什么,但是返回的值的一个‘List<Object> partialRet’ 和init中返回的StructObjectInspector对应,
@Override
public Object terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {
ColCollectAggregationBuffer ccAggregationBuffer = (ColCollectAggregationBuffer)aggregationBuffer;
List<Object> partialRet = new ArrayList<>();
partialRet.add(ccAggregationBuffer.colValueList);
partialRet.add(ccAggregationBuffer.open);
partialRet.add(ccAggregationBuffer.close);
partialRet.add(ccAggregationBuffer.seperator);
return partialRet;
}
// PARTIAL2和FINAL调用,参数partial对应上面terminatePartial返回的列表,
@Override
public void merge(AggregationBuffer aggregationBuffer, Object partial) throws HiveException {
ColCollectAggregationBuffer ccAggregationBuffer = (ColCollectAggregationBuffer)aggregationBuffer;
if(partial != null){
// soi在init中初始化了,用它来获取partial中数据。
List<Object> partialList = soi.getStructFieldsDataAsList(partial);
// terminalPartial中数据被保存在list中,这个地方拿出来只是简单了合并两个list,其他不变。
List<String> values = (List<String>)valueFieldOI.getList(partialList.get(0));
ccAggregationBuffer.colValueList.addAll(values);
if(!ccAggregationBuffer.isInit){
ccAggregationBuffer.open = PrimitiveObjectInspectorUtils.getString(partialList.get(1), openFieldOI);
ccAggregationBuffer.close = PrimitiveObjectInspectorUtils.getString(partialList.get(2), closeFieldOI);
ccAggregationBuffer.seperator = PrimitiveObjectInspectorUtils.getString(partialList.get(3), seperatorFieldOI);
}
}
}
// FINAL和COMPLETE调用,此时aggregationBuffer中用list保存了原始表表中一列的所有值,这里完成连接操作,返回一个string类型的连接结果。
@Override
public Object terminate(AggregationBuffer aggregationBuffer) throws HiveException {
return ((ColCollectAggregationBuffer)aggregationBuffer).concat();
}
}
}
2.3 自定义表生成函数
待完成。。。