概念
- 自定义标量函数,接收一个或多个列,输出一个列,行与行是一一对应的
- 构造函数在jobmanager上创建udf时就执行
- open方法在所有并行子任务上都执行一次,且在调用该udf时才会执行open方法
- 通过DataTypeHint注解和FunctionHint注解可以自定义udf参数和返回类型
- 通过重写getTypeInference方法动态指定udf返回类型
- deterministic为true时,若eval方法无参或传入常量参数,则eval方法仅会执行一次,所有行的调用结果都采用此次执行的eval返回值。若deterministic为false时,则无论何种情况均每行执行一次eval方法获得对应的返回值。deterministic默认为true,可通过重写isDeterministic方法指定其值。
定义
定义udf类,继承ScalarFunction,并实现eval方法,参数自定义
// 实现nvl函数,接收任意类型的参数,若第一个参数为null则返回第二个参数的值,否则返回第一个参数,且返回值类型恒等于第一个参数类型
import org.apache.flink.table.functions.ScalarFunction;
private Class<?> valueConvertClass;
private Class<?> defaultValueConvertClass;
Constructor<?> convertConstructor = null;
Method staticConvertMethod = null;
Map<Class<?>, Class<?>> typeMap;
public Nvl() {
// 保存引用类型与基本类型的对应关系, 因为所有valueOf转换方法都要求传入基本类型, 而defaultValueConvertClass获取到的有可能是其引用类型
typeMap = new HashMap<>();
typeMap.put(Integer.class, int.class);
typeMap.put(Long.class, long.class);
typeMap.put(Double.class, double.class);
typeMap.put(Character.class, char.class);
typeMap.put(Byte.class, byte.class);
typeMap.put(Short.class, short.class);
typeMap.put(Float.class, float.class);
}
@Override
public void open(FunctionContext context) throws Exception {
if (valueConvertClass != defaultValueConvertClass) {
if (valueConvertClass.equals(BigDecimal.class)) {
// 对应FlinkSQL的DECIMAL类型
// 使用BigDecimal的构造函数把目标对象转为BigDecimal对象, 源类型BigDecimal,int,long,char[],string,double,BigInteger
convertConstructor = BigDecimal.class.getConstructor(typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(String.class)) {
// 对应flinkSQL的STRING、VARCHAR、CHAR类型
// 使用String.valueOf方法把目标对象转为String对象, 原类型支持所有基本数据类型
staticConvertMethod = String.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Integer.class)) {
// 对应FlinkSQL的INT类型
// 使用Integer.valueOf方法把目标对象转为Integer对象, 源类型仅支持int类型和String类型
staticConvertMethod = Integer.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Boolean.class)) {
// 对应flinkSQL的BOOLEAN类型
// 使用Boolean.valueOf方法把目标对象转为Boolean对象, 源类型仅支持boolean类型和String类型
staticConvertMethod = Boolean.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Byte.class)) {
// 对应FlinkSQL的TINYINT类型
// 使用Byte.valueOf方法把目标对象转为Byte对象, 源类型仅支持byte类型和String类型
staticConvertMethod = Byte.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Short.class)) {
// 对应FlinkSQL的SMALLINT类型
// 使用Short.valueOf方法把目标对象转为Short对象, 源类型仅支持short类型和String类型
staticConvertMethod = Short.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
}else if (valueConvertClass.equals(Long.class)) {
// 对应FlinkSQL的BIGINT类型
// 使用Long.valueOf方法把目标对象转为Long对象, 源类型仅支持long类型和String类型
staticConvertMethod = Long.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Float.class)) {
// 对应FlinkSQL的FLOAT类型
// 使用Float.valueOf方法把目标对象转为Float对象, 源类型仅支持float类型和String类型
staticConvertMethod = Float.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Double.class)) {
// 对应FlinkSQL的DOUBLE类型
// 使用Double.valueOf方法把目标对象转为Double对象, 源类型仅支持double类型和String类型
staticConvertMethod = Double.class.getMethod("valueOf", typeMap.getOrDefault(defaultValueConvertClass, defaultValueConvertClass));
} else if (valueConvertClass.equals(Date.class)) {
// 对应FlinkSQL的DATE类型
// 使用Date.valueOf方法把目标对象转为Date对象, 源类型仅支持LocalDate类型和String类型
staticConvertMethod = Date.class.getMethod("valueOf", defaultValueConvertClass);
} else if (valueConvertClass.equals(LocalDate.class)) {
// 对应FlinkSQL的DATE类型
// 使用LocalDate.parse方法把目标对象转为LocalDate对象, 源类型仅支持CharSequence类型
staticConvertMethod = LocalDate.class.getMethod("parse", defaultValueConvertClass);
} else if (valueConvertClass.equals(Time.class)) {
// 对应FlinkSQL的TIME(0)类型
// 使用Time.valueOf方法把目标对象转为Time对象, 源类型仅支持LocalTime类型和String类型
staticConvertMethod = Time.class.getMethod("valueOf", defaultValueConvertClass);
} else {
throw new RuntimeException("unsupported datatype: " + defaultValueConvertClass.getName());
}
}
}
// eval方法,实现udf返回
// 重写getTypeInference方法,以及声明eval方法返回类型为Object,实现动态返回类型
// 使用DataTypeHin注解自定义udf参数类型,inputGroup = InputGroup.ANY时表示接收任意类型的参数,搭配Object类型的参数类型,实现对任意类型参数的接收处理
public Object eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object value,
@DataTypeHint(inputGroup = InputGroup.ANY) Object defaultValue) throws InvocationTargetException, InstantiationException, IllegalAccessException {
if (value != null) {
return value;
} else if (staticConvertMethod != null) {
return staticConvertMethod.invoke(null, defaultValue);
} else if (convertConstructor != null) {
return convertConstructor.newInstance(defaultValue);
} else {
return defaultValue;
}
}
// 获取第一个参数的类型, 此类型通过其字段类型得到,并将其作为udf返回类型
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
return TypeInference.newBuilder()
.outputTypeStrategy(callContext -> {
// getConversionClass为引用类型
valueConvertClass = callContext.getArgumentDataTypes().get(0).getConversionClass();
defaultValueConvertClass = callContext.getArgumentDataTypes().get(1).getConversionClass();
return Optional.of(callContext.getArgumentDataTypes().get(0));
})
.build();
}
}
使用
// table api
table.select($"nvl",call(new Nvl(),$"col1",$"col2"));
// flink sql
tableEnv.createTemporaryFunction("nvl",Nvl.class);
tableEnv.createTemporaryFunction("nvl",new Nvl());