Source:
https://stackoverflow.com/questions/33372838/dealing-with-unbalanced-datasets-in-spark-mllib
There are many ways to deal with an unbalanced dataset; this is a fairly simple one: add a weight column to the DataFrame, then pass that column name to setWeightCol when configuring Spark's LogisticRegression. The scheme below weights each record by the frequency of the opposite class, so the minority class counts more: if 90% of the records are negative, balancingRatio is 0.1, each negative example gets weight 0.1, and each positive example gets weight 0.9.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf

def balanceDataset(dataset: DataFrame): DataFrame = {
  // Re-balancing (weighting) of records to be used in the logistic loss objective function
  val numNegatives = dataset.filter(dataset("label") === 0).count
  val datasetSize = dataset.count
  // Fraction of positive (minority-class) records
  val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize

  val calculateWeights = udf { d: Double =>
    if (d == 0.0) balancingRatio  // negative class: down-weighted
    else 1.0 - balancingRatio     // positive class: up-weighted
  }

  dataset.withColumn("classWeightCol", calculateWeights(dataset("label")))
}
import org.apache.spark.ml.classification.LogisticRegression

val df_weighted = balanceDataset(df)
val lr = new LogisticRegression()
  .setLabelCol("label")  // balanceDataset above hard-codes the "label" column
  .setWeightCol("classWeightCol")
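For completeness, here is a minimal sketch of fitting on the weighted data. The toy DataFrame below (the column names label/features and an in-scope SparkSession named spark) is an illustrative assumption, not part of the original answer:

import org.apache.spark.ml.linalg.Vectors

// Hypothetical imbalanced data: three negatives, one positive.
val toyDF = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(0.1, 1.2)),
  (0.0, Vectors.dense(0.2, 0.9)),
  (0.0, Vectors.dense(0.0, 1.1)),
  (1.0, Vectors.dense(2.0, 0.1))
)).toDF("label", "features")

// Here balancingRatio = 1/4: negatives get weight 0.25 and the positive gets 0.75,
// so the single positive contributes as much to the loss as the three negatives combined.
val model = lr.fit(balanceDataset(toyDF))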