collect_list 和 collect_set 都用于把同一分组内指定字段的值收集起来，组成一个数组。二者的区别是：collect_list 保留所有值（包括重复值），collect_set 会对值去重。
常用于行转列。例如，原始数据有两行：
depId=1, employee=leo
depId=1, employee=jack
按 depId 分组聚合后得到一行：
depId=1, employees=[leo, jack]
示例代码：
/** Demonstrates the Spark SQL aggregate functions `collect_set` and `collect_list`.
  *
  * Reads `employee.json` from the classpath, groups employees by `depId`, and prints the
  * collected employee names per department, once as a set and once as a list.
  */
object AggregateFunction {

  final case class Employee(name: String, age: Long, depId: Long, gender: String, salary: Long)

  final case class Department(id: Long, name: String)

  /** Resolves a classpath resource to a local filesystem path string that Spark can read.
    *
    * @param name resource name relative to the classpath root
    * @return decoded filesystem path of the resource
    * @throws IllegalArgumentException if the resource is not on the classpath
    */
  private def resourcePath(name: String): String = {
    val url = Option(this.getClass.getClassLoader.getResource(name)).getOrElse(
      throw new IllegalArgumentException(s"Resource not found on classpath: $name")
    )
    // URL.getPath keeps percent-escapes (e.g. "%20" for a space); going through toURI decodes them.
    java.nio.file.Paths.get(url.toURI).toString
  }

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .appName("AggregateFunction")
      .master("local")
      .getOrCreate()

    // Always release the session (and its local executor threads), even if the job fails.
    try {
      import sparkSession.implicits._
      import org.apache.spark.sql.functions._

      val employeeDS = sparkSession.read.json(resourcePath("employee.json")).as[Employee]

      // collect_set de-duplicates the names within each department; collect_list keeps every
      // occurrence (element order is not guaranteed after the shuffle).
      employeeDS
        .groupBy(employeeDS("depId"))
        .agg(
          collect_set(employeeDS("name")).as("nameSet"),
          collect_list(employeeDS("name")).as("nameList")
        )
        .collect()
        .foreach(println)
    } finally {
      sparkSession.stop()
    }
  }
}