def LoadUserDefineUDF(user: String, spark: SparkSession): Unit = {
  val brainUrl: String = PropertiesLoader.getProperties("database.properties").getProperty("brain.url")
  // Strip the trailing "/feature-panel/online-feature" path (and its slash) to get the service root.
  val brainPrefix = brainUrl.substring(0, brainUrl.indexOf("feature-panel/online-feature") - 1)
  val udfURL = s"$brainPrefix/udfWarehouse/findInfoByUserId?userId=$user"
  val simpleHttp = new SimpleHttp
  val result = simpleHttp.fetchResponseText(udfURL)
  logger.info(s"result: $result")
  try {
    val resultJson = JSON.parseObject(result)
    resultJson.getIntValue("code") match {
      case 0 => LoadUDF(resultJson, spark)
      case _ => logger.error(s"Failed to load user-defined offline feature UDFs! Reason: ${resultJson.getString("msg")}")
    }
  } catch {
    case e: Exception =>
      logger.error(s"Failed to load user-defined offline feature UDFs! Reason: server error! ${e.getMessage}", e)
  }
}
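For reference, LoadUDF below reads a fixed set of fields out of this response. A plausible payload looks like the following; the field names are taken from the parsing code, while the values and the download URL are invented for illustration:

{
  "code": 0,
  "msg": "success",
  "data": {
    "data": [
      {
        "udfName": "str2VecJson",
        "functionName": "str2VecJson",
        "entryClass": "com.vivo.ai.temp.Method",
        "jarName": "data_flow_test-1.0-SNAPSHOT",
        "downLoadJarUrl": "http://example.com/jars/data_flow_test-1.0-SNAPSHOT.jar"
      }
    ]
  }
}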
def LoadUDF(jsonObj: JSONObject, spark: SparkSession): Unit = {
  val udfArray = jsonObj.getJSONObject("data").getJSONArray("data")
  val array = mutable.ArrayBuilder.make[URL]()
  // udfName -> (functionName, entryClass, jarName)
  val methodMap = new mutable.HashMap[String, (String, String, String)]()
  for (i <- 0 until udfArray.size()) {
    val udfJson = udfArray.getJSONObject(i)
    val udfName = udfJson.getString("udfName")
    val downLoadJarUrl = udfJson.getString("downLoadJarUrl")
    val entryClass = udfJson.getString("entryClass")
    val jarName = udfJson.getString("jarName") + ".jar"
    val functionName = udfJson.getString("functionName")
    try {
      downLoadJar(downLoadJarUrl, jarName)
      // Make the jar visible to the executors; the local copy is used for driver-side class loading.
      spark.sparkContext.addJar(HdfsPrefix + jarName)
      val url2 = new URL(s"file:./$jarName")
      logger.info(s"Loaded UDF $udfName successfully")
      methodMap.put(udfName, (functionName, entryClass, jarName))
      array += url2
    } catch {
      case e: Exception =>
        logger.error(s"$jarName $functionName $entryClass Exception!!!", e)
    }
  }
  ScalaGenerateFunctions(array.result())
  methodMap.foreach {
    case (udfName, (functionName, entryClass, jarName)) =>
      try {
        val (fun, _, returnType) = ScalaGenerateFunctions.genetateFunction(functionName, entryClass, jarName)
        spark.udf.register(udfName, fun, returnType)
        logger.info(s"Registered Spark UDF $udfName successfully")
      } catch {
        case e: Exception =>
          logger.error(s"Failed to register Spark UDF $udfName!", e)
      }
  }
}
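ScalaGenerateFunctions itself is not shown in this post. A minimal sketch of what it plausibly does, assuming a URLClassLoader over the downloaded jars plus reflection on the entry class — the names, the one-argument restriction, and the DataType mapping below are illustrative assumptions, not the actual implementation:

import java.net.{URL, URLClassLoader}
import org.apache.spark.sql.types._

object ScalaGenerateFunctions {
  private var classLoader: URLClassLoader = _

  // Install a class loader that can see the downloaded jars.
  def apply(urls: Array[URL]): Unit = {
    classLoader = new URLClassLoader(urls, getClass.getClassLoader)
  }

  // Look up `functionName` on `entryClass` and wrap it as a one-argument function,
  // returning (function, input DataType, return DataType) for spark.udf.register.
  def genetateFunction(functionName: String, entryClass: String, jarName: String): (AnyRef, DataType, DataType) = {
    val clazz = classLoader.loadClass(entryClass)
    val instance = clazz.getDeclaredConstructor().newInstance()
    val method = clazz.getMethods.find(_.getName == functionName)
      .getOrElse(throw new NoSuchMethodException(functionName))
    val fun = (arg: Any) => method.invoke(instance, arg.asInstanceOf[AnyRef])
    // Assumes a single-parameter method, matching how the UDFs here are written.
    (fun, toDataType(method.getParameterTypes.head), toDataType(method.getReturnType))
  }

  // Deliberately small Java-class-to-Catalyst-type mapping; a real version would cover more types.
  private def toDataType(c: Class[_]): DataType = c match {
    case x if x == classOf[String]                       => StringType
    case x if x == classOf[Int] || x == classOf[Integer] => IntegerType
    case x if x == classOf[Float]                        => FloatType
    case _                                               => StringType
  }
}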
def downLoadJar(url: String, jarName: String): Unit = {
  logger.info(s"url: $url")
  val path = "./"
  val file = new File(path)
  if (!file.exists()) {
    // Create the target directory if it does not exist.
    file.mkdirs()
  }
  var fileOut: FileOutputStream = null
  var conn: HttpURLConnection = null
  var inputStream: InputStream = null
  try {
    val httpUrl = new URL(url)
    conn = httpUrl.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("GET")
    conn.setDoInput(true)
    // Always fetch a fresh copy of the jar; never serve it from a cache.
    conn.setUseCaches(false)
    // Connect to the resource.
    conn.connect()
    // Stream the response body to a local file.
    inputStream = conn.getInputStream
    val bis = new BufferedInputStream(inputStream)
    fileOut = new FileOutputStream(path + jarName)
    val bos = new BufferedOutputStream(fileOut)
    val buf = new Array[Byte](4096)
    var length = bis.read(buf)
    while (length != -1) {
      bos.write(buf, 0, length)
      length = bis.read(buf)
    }
    bos.close()
    bis.close()
  } catch {
    case e: Exception =>
      logger.error(s"Error downloading jar $jarName: ${e.getMessage}", e)
  } finally {
    // Release resources even when the download fails midway.
    if (inputStream != null) inputStream.close()
    if (fileOut != null) fileOut.close()
    if (conn != null) conn.disconnect()
  }
}
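As an aside, the manual buffer loop can be replaced with java.nio; a minimal equivalent sketch, assuming the same "./" target directory and overwriting any stale copy of the jar:

import java.net.URL
import java.nio.file.{Files, Paths, StandardCopyOption}

def downLoadJarNio(url: String, jarName: String): Unit = {
  val in = new URL(url).openStream()
  try {
    // Files.copy streams the body straight to disk, replacing an existing file.
    Files.copy(in, Paths.get("./", jarName), StandardCopyOption.REPLACE_EXISTING)
  } finally {
    in.close()
  }
}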
Below is a unit test:
@Test
def testStr2VecJson(): Unit = {
  System.setProperty("hadoop.home.dir", "D:\\winutils")
  val conf = new SparkConf().setAppName("test").setMaster("local[2]")
  val spark = SparkSession.builder().config(conf).getOrCreate()
  import spark.implicits._
  val data = Array("1", "2")
  val rdd = spark.sparkContext.parallelize(data)
  val df = rdd.toDF("str")
  // Load the UDF jar from the local build output (E:\adworkSpace\autotask\target in the original note).
  val url = new URL("file:F:/ad_codes/data_flow_test/target/data_flow_test-1.0-SNAPSHOT.jar")
  val urls = Array(url)
  ScalaGenerateFunctions(urls)
  val className = "com.vivo.ai.temp.Method"
  val methodArray = Array("str2VecJson")
  methodArray.foreach {
    methodName =>
      val (fun, inputType, returnType) = ScalaGenerateFunctions.genetateFunction(methodName, className, "autotask-2.0-SNAPSHOT.jar")
      // Kept for the FunctionRegistry variant sketched after this test.
      val inputTypes = Try(List(inputType)).toOption
      spark.udf.register(methodName, fun, returnType)
  }
  df.createTempView("strDF")
  df.show()
  spark.sql("select str2VecJson(str) from strDF").show()
}
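The commented-out lines in the original test point at an alternative registration path through Spark's internal FunctionRegistry. A sketch of that variant; this is internal API, and the ScalaUDF constructor signature changes between Spark versions, so treat the argument list as an assumption to adapt:

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}

// Inside the methodArray.foreach loop, instead of spark.udf.register:
def builder(e: Seq[Expression]) =
  ScalaUDF(fun, returnType, e, inputTypes.getOrElse(Nil), udfName = Some(methodName))

spark.sessionState.functionRegistry.registerFunction(new FunctionIdentifier(methodName), builder)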
The class com.vivo.ai.temp.Method referenced above is defined as follows:
import com.alibaba.fastjson.JSON
import com.vivo.ai.encode.ContinuousEncoder
import com.vivo.ai.encode.util.EncodeEnv
import com.vivo.vector.Vector
import com.alibaba.fastjson.serializer.SerializerFeature
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
/**
 * @author liangwt
 * @create 2020/11/16
 * @since 1.0.0
 * @Description : Example entry class whose methods are loaded by reflection and registered as Spark UDFs.
 */
class Method {
  def method(value: String): String = {
    // User-defined processing logic goes here.
    value
  }
  def method2(value: Int): String = {
    (value + 100).toString
  }
  def testMap(value: Int): scala.collection.Map[String, String] = {
    scala.collection.Map("1" -> "1")
  }
  def testJMap(value: Int): java.util.Map[String, String] = {
    scala.collection.Map("1" -> "1").asJava
  }
  def testMap2(value: Int): Map[String, String] = {
    Map("1" -> "1")
  }
  def testSet(value: Int) = {
    val set = mutable.Set("1")
    set.asJava
  }
  def testSeq(value: Int) = {
    Seq("1")
  }
  def inputSeq(seq: Seq[Int]): String = {
    "1"
  }
  def inputMap(map: Map[String, Integer]): String = {
    "1"
  }
  def str2VecJson(str: String): String = {
    val userNewsRTFeatureVec: Vector = Vector.builder(24).build()
    val userArrayBuffer = new ArrayBuffer[(Int, Float)]()
    // In the real implementation, `str` would be parsed (e.g. via Tools.str2Map) and each key
    // encoded to an index with ContinuousEncoder.encode("news_category_v3", k, EncodeEnv.PRD);
    // a stub entry stands in for that here.
    userArrayBuffer += (1 -> 0.1f)
    // sortWith returns a new collection, so the sorted result must be used, not discarded.
    val sorted = userArrayBuffer.sortWith((x, y) => x._1 < y._1)
    userNewsRTFeatureVec.setIndices(sorted.map(_._1).toArray)
    userNewsRTFeatureVec.setValues(sorted.map(_._2).toArray)
    JSON.toJSONString(userNewsRTFeatureVec, SerializerFeature.IgnoreNonFieldGetter)
  }
}
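Assuming fastjson serializes the Vector's indices, values, and size (the exact field set depends on com.vivo.vector.Vector, which is not shown here), the stub entry above would produce output roughly like:

{"indices":[1],"size":24,"values":[0.1]}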