- dataframe字段操作
- 打印权重
- 解析概率
- 模型调参
- 初始化spark
- 常用缺失值填充
- StringIndexer 多字段处理
dataframe字段操作
字段split为array: withColumn('catenew', split(col('cates'), ','))
打印权重
rfModel = model_pipe.stages[-1]
attrs = sorted(
(attr['idx'], attr['name']) for attr in
(chain(*df_test_result.schema['features_asb'].metadata['ml_attr']['attrs'].values())) # features_asb为assemble的output
)
feature_weight = [(idx, name, float(rfModel.featureImportances[idx])) for idx, name in attrs]
df_weight = spark.createDataFrame(feature_weight, ['idx', 'feature', 'weight'])
df_weight.orderBy(df_weight.weight.desc()).show(df_weight.count(), truncate=False)
解析概率
from pyspark.sql import functions as F
split_udf = udf(lambda value: float(value[1])) # 需将dataframe的numpy.float64 cast to a python float.
df_result = df_result.withColumn('proba', split_udf('probability')).select('member_id', 'prediction', F.round('proba', 3).alias('proba'))
模型调参
see databricks
pipeline = Pipeline(stages=[assembler, gbdt])
paramGrid = (ParamGridBuilder()
.addGrid(gbdt.maxDepth, [3, 5, 7])
.addGrid(gbdt.maxIter, [15, 20, 25])
.build()) # 参数搜索范围
cv = CrossValidator(estimator=pipeline,estimatorParamMaps=paramGrid,evaluator=BinaryClassificationEvaluator(), numFolds=3)
cvModel = cv.fit(df_train)
df_test_result = cvModel.transform(df_test)
gbdtModel = cvModel.bestModel.stages[-1] # 获得模型
初始化spark
spark = SparkSession.builder.appName('pspredict').enableHiveSupport().config('spark.driver.memory', '8g').getOrCreate() # jupyter
spark.sparkContext.setLogLevel('ERROR')
常用缺失值填充
(1) replace(to_replace, values, subset)
(2) replace('', 'unknown', 'country_nm')
(3) replace(['a', 'b'], ['c', 'd'], 'country_nm'): 将国家(可列表)中a->c, b->d, ab需同类型,b不能为None
(4) replace({-1: 14}, 'stature'): 将stature的-1->14,values参数无效,字典里多个需同类型(string与None不能混用)
(5) fillna('haha'): 将null->'haha', 非string值跳过
(6) fillna('xx', [columns_name]): 将多列统一替换na->xx
(7) fillna({'f1': 24, 'f2': 'hah'}): 多列分别替换
StringIndexer 多字段处理
pyspark StringIndexer 输入列不支持多字段, 考虑使用表达式列表实现
indexer = [StringIndexer(inputCol=x, outputCol='{}_idx'.format(x), handleInvalid='keep') for x in feature_index]