SELECT:
val df = trainData.select("ruid", "log_date","fans1","watch_num1","danmu_cnt1","gap_days1","money_num1")
JOIN:
val result2=kmeansData.join(result1,Seq("prediction"))
GROUP BY:
kmeansData.groupBy("prediction").count().show()
kmeansData.groupBy("prediction").mean("fans","gap_days")
Anchor clustering code example:
package com.bilibili.live
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler}
import org.apache.spark.sql.SparkSession
object AnchorCluster {
def main(args: Array[String]): Unit = {
println("anchor cluster label process start!")
val saveMode = args(0)
val log_dt = args(1)
println(s"log_date: ${log_dt}")
//
val spark = SparkSession
.builder()
.appName("AnchorCluster")
// .master("local[4]")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.warehouse.dir", "spark-warehouse")
.config("spark.sql.broadcastTimeout", "36000")
.enableHiveSupport()
.getOrCreate()
// spark.conf.set("hive.security.authorization.enabled", false)
println("spark version:%s".format(spark.version))
// Legacy reference: the commented-out block below uses the RDD-based spark.mllib
// KMeans API (KMeans.train / model.predict); the spark.ml pipeline further down is what actually runs.
// // Start clustering
// val numClusters = 4
// val numIterations = 500
// val clusters = KMeans.train(parsedData, numClusters, numIterations)
//
// // Evaluate clustering by computing Within Set Sum of Squared Errors
// // val WSSSE = clusters.computeCost(parsedData)
// // println(s"Within Set Sum of Squared Errors = $WSSSE")
//
// // Save and load the model
// clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel")
// val sameModel = KMeansModel.load(sc, "target/org/apache/spark/KMeansExample/KMeansModel")
//
// // Print the predicted cluster for every training sample
// trainingData.collect().foreach(
// sample => {
// val predictedCluster = model.predict(sample)
// println(sample.toString + " belongs to cluster " + predictedCluster)
// })
//
// // Return the dataset together with its predictions
// val result = data.map {
// line =>
// val lineVector = Vectors.dense(line.split(" ").map(_.toDouble))
// val prediction = model.predict(lineVector)
// line + " " + prediction
// }.collect.foreach(println)
//-----------------------------
// Read training data via SQL:
val trainDataSql = "select *,log10(fans+1) as fans1,log10(watch_num+1) as watch_num1," +
"log10(danmu_cnt+1) as danmu_cnt1,log10(gap_days+1) as gap_days1,log10(money_num+1) as money_num1 " +
"from ai_live.ruid_cluster_daily_train" +
s" where log_date = '${log_dt}' "
println("-----------------train data sql----------------")
println(trainDataSql)
val trainData = spark.sql(trainDataSql).limit(500000)
println("-----------------train feature----------------")
trainData.show(3)
// Alternative: read local test data:
// val trainData: DataFrame = spark.read.format("csv")
// .option("header", "true")
// .option("inferSchema", "true")
// .load("data/anchor_cluster_test.csv")
// .na.fill(-99)
//
// trainData.show(3)
// Prepare features: assemble the log10-transformed columns into a vector and standardize;
// the raw columns are kept only for the final output.
val df = trainData.select("ruid", "log_date","fans1","watch_num1","danmu_cnt1","gap_days1","money_num1",
"fans","watch_num","danmu_cnt","gap_days","money_num")
df.show(numRows = 3)
val assembler = new VectorAssembler()
.setInputCols(df.drop("ruid","log_date","fans","watch_num","danmu_cnt","gap_days","money_num").columns.toArray)
.setOutputCol("features")
val df1=assembler.transform(df)
val scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("new_feature")
.setWithStd(true)
.setWithMean(false)
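// Note: withStd(true) + withMean(false) divides each feature by its standard
// deviation without centering, so the log10-transformed values stay non-negative.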
val scalerModel = scaler.fit(df1)
val scaledData = scalerModel.transform(df1)
scaledData.show()
// Run KMeans clustering
val kmeans=new KMeans()
.setK(4)
.setMaxIter(500) // maximum number of iterations
.setFeaturesCol("new_feature")
.setPredictionCol("prediction")
val kmeansModel = kmeans.fit(scaledData)
// Inspect the clustering result
val kmeansData = kmeansModel.transform(scaledData)
kmeansData.show()
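// Optional quality check (an added sketch, not part of the original job): silhouette
// score via spark.ml's ClusteringEvaluator on the same feature/prediction columns.
import org.apache.spark.ml.evaluation.ClusteringEvaluator
val evaluator = new ClusteringEvaluator()
.setFeaturesCol("new_feature")
.setPredictionCol("prediction")
println(s"silhouette = ${evaluator.evaluate(kmeansData)}")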
kmeansData.groupBy("prediction").count().show()
val result1=kmeansData
.groupBy("prediction").mean("fans","gap_days")
.withColumnRenamed("avg(fans)","avg_fans")
.withColumnRenamed("avg(gap_days)","avg_gap_days")
val result2=kmeansData.join(result1,Seq("prediction"))
result2.show()
result2.createOrReplaceTempView("result_final")
// Label clusters by their average stats; keep log_date so the partitioned write below works.
// Label values: 成熟 = mature, 成长 = growing, 新 = new, 尾部主播 = tail anchor.
val result_final = spark.sql("select ruid,fans,watch_num,danmu_cnt,gap_days,money_num,prediction,"+
"case when avg_fans>10000 then '成熟'" +
" when avg_fans>1000 then '成长' " +
" when avg_gap_days<10 then '新' else '尾部主播' end as label," +
"log_date from result_final")
result_final.show()
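// Optional check (an added sketch): distribution of the final labels, to verify the
// case-when thresholds split the clusters as intended.
result_final.groupBy("label").count().show()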
// Write results to the Hive table
spark.catalog.setCurrentDatabase("bili_live")
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
if (saveMode == "overwrite") {
result_final.write.mode("overwrite")
.partitionBy("log_date")
.format("parquet")
.saveAsTable("ai_live.data_mining_anchor_cluster_predict_d")
} else {
result_final.write.mode("overwrite")
.format("parquet")
.insertInto("ai_live.data_mining_anchor_cluster_predict_d")
}
println("Write table success!")
println("Done!!!")
spark.stop()
}
}