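Parameters
ntrees: number of trees to build
max_depth: maximum depth allowed for each tree
mtries: number of features considered at each split
sample_rate: fraction of the rows sampled for each tree
histogram_type: how the histogram bins used to search for split points are chosen: UniformAdaptive, Random, QuantilesGlobal, RoundRobin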
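As a rough illustration of what sample_rate and mtries control, the sketch below mimics the row and feature subsampling for a single tree in base R. It assumes a 90-row training frame with the four iris predictors, sampling without replacement, and the common classification choice of sqrt(p) features per split. The variable names are hypothetical and this is not H2O's internal code.

```r
# Illustration only: what sample_rate and mtries mean for one tree.
set.seed(42)

n_rows      <- 90      # rows in the training frame (assumed)
predictors  <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
sample_rate <- 0.632   # fraction of rows drawn for each tree
mtries      <- floor(sqrt(length(predictors)))  # features tried per split

# Rows this tree gets to see, drawn without replacement.
rows_for_tree <- sample(n_rows, size = round(sample_rate * n_rows))

# Candidate features for one split inside that tree.
cols_for_split <- sample(predictors, size = mtries)

length(rows_for_tree)   # 57 rows
length(cols_for_split)  # 2 predictors
```

Lower values of either setting make the individual trees less alike, which is the usual reason for tuning them.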
Build a random forest with the default parameters
> # Default Random Forest
>
> m <- h2o.randomForest(1:4,5,iris.train,nfolds = 10,model_id = 'RF.fold10')
|============================================================| 100%
> m
Model Details:
==============
H2OMultinomialModel: drf
Model ID: RF.fold10
Model Summary:
number_of_trees number_of_internal_trees model_size_in_bytes
1 50 150 17892
min_depth max_depth mean_depth min_leaves max_leaves mean_leaves
1 1 7 2.95333 2 10 4.52000
H2OMultinomialMetrics: drf
** Reported on training data. **
** Metrics reported on Out-Of-Bag training samples **
Training Set Metrics:
=====================
Extract training frame with `h2o.getFrame("RTMP_sid_ba81_121")`
MSE: (Extract with `h2o.mse`) 0.02576462
RMSE: (Extract with `h2o.rmse`) 0.1605136
Logloss: (Extract with `h2o.logloss`) 0.08515898
Mean Per-Class Error: 0.03143275
Confusion Matrix: Extract with `h2o.confusionMatrix(<model>,train = TRUE)`)
=========================================================================
Confusion Matrix: Row labels: Actual class; Column labels: Predicted class
setosa versicolor virginica Error Rate
setosa 28 0 0 0.0000 = 0 / 28
versicolor 0 23 1 0.0417 = 1 / 24
virginica 0 2 36 0.0526 = 2 / 38
Totals 28 25 37 0.0333 = 3 / 90
Hit Ratio Table: Extract with `h2o.hit_ratio_table(<model>,train = TRUE)`
=======================================================================
Top-3 Hit Ratios:
k hit_ratio
1 1 0.966667
2 2 1.000000
3 3 1.000000
H2OMultinomialMetrics: drf
** Reported on cross-validation data. **
** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
Cross-Validation Set Metrics:
=====================
Extract cross-validation frame with `h2o.getFrame("RTMP_sid_ba81_121")`
MSE: (Extract with `h2o.mse`) 0.03402869
RMSE: (Extract with `h2o.rmse`) 0.1844687
Logloss: (Extract with `h2o.logloss`) 0.1451198
Mean Per-Class Error: 0.03143275
Hit Ratio Table: Extract with `h2o.hit_ratio_table(<model>,xval = TRUE)`
=======================================================================
Top-3 Hit Ratios:
k hit_ratio
1 1 0.966667
2 2 1.000000
3 3 1.000000
Cross-Validation Metrics Summary:
mean sd cv_1_valid
accuracy 0.9608333 0.043624822 1.0
err 0.039166667 0.043624822 0.0
err_count 0.3 0.32403705 0.0
logloss 0.16325581 0.15044996 0.01936583
max_per_class_error 0.08095238 0.09410035 0.0
mean_per_class_accuracy 0.97301584 0.031366784 1.0
mean_per_class_error 0.026984127 0.031366784 0.0
mse 0.038846087 0.035734665 6.680552E-4
r2 0.9322564 0.06199563 0.99895614
rmse 0.15172102 0.088957354 0.025846764
cv_2_valid cv_3_valid cv_4_valid
accuracy 0.9 0.875 1.0
err 0.1 0.125 0.0
err_count 1.0 1.0 0.0
logloss 0.15824482 0.6884255 0.083293706
max_per_class_error 0.14285715 0.33333334 0.0
mean_per_class_accuracy 0.95238096 0.8888889 1.0
mean_per_class_error 0.04761905 0.11111111 0.0
mse 0.056907933 0.12448429 0.025920173
r2 0.91244936 0.7957181 0.9278743
rmse 0.23855384 0.35282332 0.16099744
cv_5_valid cv_6_valid cv_7_valid
accuracy 1.0 0.8333333 1.0
err 0.0 0.16666667 0.0
err_count 0.0 1.0 0.0
logloss 0.01643106 0.44375053 0.05486414
max_per_class_error 0.0 0.33333334 0.0
mean_per_class_accuracy 1.0 0.8888889 1.0
mean_per_class_error 0.0 0.11111111 0.0
mse 8.837725E-4 0.14378169 0.003728628
r2 0.99893945 0.741193 0.9958053
rmse 0.029728312 0.37918556 0.061062492
cv_8_valid cv_9_valid cv_10_valid
accuracy 1.0 1.0 1.0
err 0.0 0.0 0.0
err_count 0.0 0.0 0.0
logloss 0.010296788 0.06945834 0.08842746
max_per_class_error 0.0 0.0 0.0
mean_per_class_accuracy 1.0 1.0 1.0
mean_per_class_error 0.0 0.0 0.0
mse 3.6618378E-4 0.011905115 0.019815048
r2 0.9995278 0.9787409 0.97335976
rmse 0.01913593 0.109110564 0.14076594
> summary(m)
Model Details:
==============
H2OMultinomialModel: drf
Model Key: RF.fold10
Model Summary:
number_of_trees number_of_internal_trees model_size_in_bytes
1 50 150 17892
min_depth max_depth mean_depth min_leaves max_leaves mean_leaves
1 1 7 2.95333 2 10 4.52000
H2OMultinomialMetrics: drf
** Reported on training data. **
** Metrics reported on Out-Of-Bag training samples **
Training Set Metrics:
=====================
Extract training frame with `h2o.getFrame("RTMP_sid_ba81_121")`
MSE: (Extract with `h2o.mse`) 0.02576462
RMSE: (Extract with `h2o.rmse`) 0.1605136
Logloss: (Extract with `h2o.logloss`) 0.08515898
Mean Per-Class Error: 0.03143275
Confusion Matrix: Extract with `h2o.confusionMatrix(<model>,train = TRUE)`)
=========================================================================
Confusion Matrix: Row labels: Actual class; Column labels: Predicted class
setosa versicolor virginica Error Rate
setosa 28 0 0 0.0000 = 0 / 28
versicolor 0 23 1 0.0417 = 1 / 24
virginica 0 2 36 0.0526 = 2 / 38
Totals 28 25 37 0.0333 = 3 / 90
Hit Ratio Table: Extract with `h2o.hit_ratio_table(<model>,train = TRUE)`
=======================================================================
Top-3 Hit Ratios:
k hit_ratio
1 1 0.966667
2 2 1.000000
3 3 1.000000
H2OMultinomialMetrics: drf
** Reported on cross-validation data. **
** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
Cross-Validation Set Metrics:
=====================
Extract cross-validation frame with `h2o.getFrame("RTMP_sid_ba81_121")`
MSE: (Extract with `h2o.mse`) 0.03402869
RMSE: (Extract with `h2o.rmse`) 0.1844687
Logloss: (Extract with `h2o.logloss`) 0.1451198
Mean Per-Class Error: 0.03143275
Hit Ratio Table: Extract with `h2o.hit_ratio_table(<model>,xval = TRUE)`
=======================================================================
Top-3 Hit Ratios:
k hit_ratio
1 1 0.966667
2 2 1.000000
3 3 1.000000
Cross-Validation Metrics Summary:
mean sd cv_1_valid
accuracy 0.9608333 0.043624822 1.0
err 0.039166667 0.043624822 0.0
err_count 0.3 0.32403705 0.0
logloss 0.16325581 0.15044996 0.01936583
max_per_class_error 0.08095238 0.09410035 0.0
mean_per_class_accuracy 0.97301584 0.031366784 1.0
mean_per_class_error 0.026984127 0.031366784 0.0
mse 0.038846087 0.035734665 6.680552E-4
r2 0.9322564 0.06199563 0.99895614
rmse 0.15172102 0.088957354 0.025846764
cv_2_valid cv_3_valid cv_4_valid
accuracy 0.9 0.875 1.0
err 0.1 0.125 0.0
err_count 1.0 1.0 0.0
logloss 0.15824482 0.6884255 0.083293706
max_per_class_error 0.14285715 0.33333334 0.0
mean_per_class_accuracy 0.95238096 0.8888889 1.0
mean_per_class_error 0.04761905 0.11111111 0.0
mse 0.056907933 0.12448429 0.025920173
r2 0.91244936 0.7957181 0.9278743
rmse 0.23855384 0.35282332 0.16099744
cv_5_valid cv_6_valid cv_7_valid
accuracy 1.0 0.8333333 1.0
err 0.0 0.16666667 0.0
err_count 0.0 1.0 0.0
logloss 0.01643106 0.44375053 0.05486414
max_per_class_error 0.0 0.33333334 0.0
mean_per_class_accuracy 1.0 0.8888889 1.0
mean_per_class_error 0.0 0.11111111 0.0
mse 8.837725E-4 0.14378169 0.003728628
r2 0.99893945 0.741193 0.9958053
rmse 0.029728312 0.37918556 0.061062492
cv_8_valid cv_9_valid cv_10_valid
accuracy 1.0 1.0 1.0
err 0.0 0.0 0.0
err_count 0.0 0.0 0.0
logloss 0.010296788 0.06945834 0.08842746
max_per_class_error 0.0 0.0 0.0
mean_per_class_accuracy 1.0 1.0 1.0
mean_per_class_error 0.0 0.0 0.0
mse 3.6618378E-4 0.011905115 0.019815048
r2 0.9995278 0.9787409 0.97335976
rmse 0.01913593 0.109110564 0.14076594
Scoring History:
timestamp duration number_of_trees training_rmse
1 2018-10-03 02:11:29 2.090 sec 0 NA
2 2018-10-03 02:11:29 2.092 sec 1 0.44745
3 2018-10-03 02:11:29 2.093 sec 2 0.39143
4 2018-10-03 02:11:29 2.095 sec 3 0.35182
5 2018-10-03 02:11:29 2.096 sec 4 0.32207
training_logloss training_classification_error
1 NA NA
2 5.71507 0.19355
3 3.93823 0.16364
4 3.18927 0.13235
5 2.38874 0.12987
---
timestamp duration number_of_trees training_rmse
46 2018-10-03 02:11:29 2.208 sec 45 0.16163
47 2018-10-03 02:11:29 2.210 sec 46 0.16142
48 2018-10-03 02:11:29 2.212 sec 47 0.16283
49 2018-10-03 02:11:29 2.214 sec 48 0.16269
50 2018-10-03 02:11:29 2.219 sec 49 0.16196
51 2018-10-03 02:11:29 2.221 sec 50 0.16051
training_logloss training_classification_error
46 0.08575 0.03333
47 0.08605 0.04444
48 0.08731 0.04444
49 0.08701 0.04444
50 0.08648 0.04444
51 0.08516 0.03333
Variable Importances: (Extract with `h2o.varimp`)
=================================================
Variable Importances:
variable relative_importance scaled_importance percentage
1 Petal.Length 1232.214844 1.000000 0.470131
2 Petal.Width 1107.685425 0.898939 0.422619
3 Sepal.Length 213.977249 0.173653 0.081639
4 Sepal.Width 67.124794 0.054475 0.025610
>
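The error figures in the training output can be reproduced by hand from the confusion matrix. The base R sketch below copies the out-of-bag training confusion matrix printed above and recomputes the per-class error, overall error, mean per-class error, and top-1 hit ratio. The object names are illustrative only.

```r
# Training (out-of-bag) confusion matrix copied from the output above.
classes <- c("setosa", "versicolor", "virginica")
cm <- matrix(
  c(28,  0,  0,
     0, 23,  1,
     0,  2, 36),
  nrow = 3, byrow = TRUE,
  dimnames = list(actual = classes, predicted = classes)
)

per_class_error      <- 1 - diag(cm) / rowSums(cm)   # misclassified / actual count
overall_error        <- 1 - sum(diag(cm)) / sum(cm)  # 3 / 90
mean_per_class_error <- mean(per_class_error)        # unweighted mean over classes
top1_hit_ratio       <- sum(diag(cm)) / sum(cm)      # share predicted correctly

round(per_class_error, 4)       # 0.0000 0.0417 0.0526
round(overall_error, 4)         # 0.0333
round(mean_per_class_error, 8)  # 0.03143275
round(top1_hit_ratio, 6)        # 0.966667
```

The mean per-class error weights each class equally, so it differs from the overall error whenever the classes have different sizes.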
Check predictive performance on the test set
> h2o.performance(m,iris.test)
H2OMultinomialMetrics: drf
Test Set Metrics:
=====================
MSE: (Extract with `h2o.mse`) 0.02854637
RMSE: (Extract with `h2o.rmse`) 0.1689567
Logloss: (Extract with `h2o.logloss`) 0.103123
Mean Per-Class Error: 0.01754386
Confusion Matrix: Extract with `h2o.confusionMatrix(<model>, <data>)`)
=========================================================================
Confusion Matrix: Row labels: Actual class; Column labels: Predicted class
setosa versicolor virginica Error Rate
setosa 9 0 0 0.0000 = 0 / 9
versicolor 0 18 1 0.0526 = 1 / 19
virginica 0 0 6 0.0000 = 0 / 6
Totals 9 18 7 0.0294 = 1 / 34
Hit Ratio Table: Extract with `h2o.hit_ratio_table(<model>, <data>)`
=======================================================================
Top-3 Hit Ratios:
k hit_ratio
1 1 0.970588
2 2 1.000000
3 3 1.000000
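To put the three evaluations side by side, the sketch below collects the MSE and logloss values printed in this section into one data frame and recomputes RMSE as sqrt(MSE). The numbers are copied from the output above; the data frame and its column names are only illustrative.

```r
# Metrics copied from the model output and the h2o.performance() output.
metrics <- data.frame(
  set     = c("train (OOB)", "10-fold CV", "test"),
  mse     = c(0.02576462, 0.03402869, 0.02854637),
  logloss = c(0.08515898, 0.1451198, 0.103123)
)

# RMSE is the square root of MSE, matching the RMSE lines in the output.
metrics$rmse <- sqrt(metrics$mse)

metrics
```

Here the test-set MSE and logloss fall between the out-of-bag and cross-validation estimates, though with only 34 test rows the differences should be read loosely.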
Grid search over parameters
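# Cartesian grid: every combination of the hyperparameters below is trained
g <- h2o.grid(
  "randomForest",
  hyper_params = list(
    ntrees = c(50, 100, 120),
    max_depth = c(40, 60),
    min_rows = c(1, 2)
  ),
  x = 1:4, y = 5, training_frame = iris.train, nfolds = 10
)
g
# Retrieve the grid results sorted by R^2, best first
g_re <- h2o.getGrid(g@grid_id, sort_by = "r2", decreasing = TRUE)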
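It can help to count how much work a full grid implies before launching it. The base R sketch below enumerates the same hyperparameter values with expand.grid and multiplies by the number of fits per model, assuming each cross-validated model trains one fit per fold plus one final fit on all the training data. The variable names are illustrative.

```r
# Same hyperparameter values as the grid above.
hyper_params <- list(
  ntrees    = c(50, 100, 120),
  max_depth = c(40, 60),
  min_rows  = c(1, 2)
)

# Every combination a Cartesian search would visit.
combos   <- expand.grid(hyper_params)
n_models <- nrow(combos)              # 3 * 2 * 2 = 12

# With nfolds = 10: ten fold models plus one final model per combination.
nfolds <- 10
n_fits <- n_models * (nfolds + 1)     # 132

combos
n_fits
```

On a dataset as small as iris this is cheap, but the count grows multiplicatively with every extra hyperparameter value.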
Choosing the grid search strategy
# x, y and train are assumed to be defined beforehand
g <- h2o.grid(
  "randomForest",
  search_criteria = list(
    strategy = "RandomDiscrete",   # sample combinations at random
    stopping_metric = "MSE",
    stopping_tolerance = 0.001,
    stopping_rounds = 10,
    max_runtime_secs = 120         # stop the search after two minutes
  ),
  hyper_params = list(
    ntrees = c(50, 100, 150, 200, 250),
    mtries = c(2, 3, 4, 5),
    sample_rate = c(0.5, 0.632, 0.8, 0.95),
    col_sample_rate_per_tree = c(0.5, 0.9, 1)
  ),
  x = x, y = y, training_frame = train, nfolds = 5, max_depth = 40,
  # per-model early stopping, scored every 3 trees
  stopping_metric = "deviance", stopping_tolerance = 0, stopping_rounds = 4,
  score_tree_interval = 3
)
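A random search makes sense here because the hyperparameter space is much larger than in the Cartesian example. The base R sketch below builds the full space and draws a handful of combinations without replacement, which is the general idea behind a random discrete search. In the H2O call the number of models is governed by the search_criteria (a 120-second budget and the stopping settings) rather than a fixed count, so the 10 used below is an arbitrary illustrative number, as are the variable names.

```r
set.seed(1)

# Full space spanned by the hyper_params list above.
space <- expand.grid(
  ntrees                   = c(50, 100, 150, 200, 250),
  mtries                   = c(2, 3, 4, 5),
  sample_rate              = c(0.5, 0.632, 0.8, 0.95),
  col_sample_rate_per_tree = c(0.5, 0.9, 1)
)
nrow(space)   # 5 * 4 * 4 * 3 = 240 combinations

# Draw 10 distinct combinations at random (10 is arbitrary).
picked <- space[sample(nrow(space), 10), ]
picked
```

With 5-fold cross-validation, training all 240 combinations would mean 240 * 6 = 1440 fits, which is why a time budget is a practical way to bound the search.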
Advanced version
g <-
h2o.grid(
"randomForest",
hyper_params = list(
ntrees = c(50, 100, 120),
max_depth = c(40, 60),
min_rows = c(1, 2)
),
x = x,
y = y,
training_frame = train,
validation_frame = valid,
nfolds = 10
)
g <- h2o.grid(
"randomForest",
search_criteria = list(
strategy = "RandomDiscrete",
stopping_metric = "MSE",
stopping_tolerance = 0.001,
stopping_rounds = 10,
max_runtime_secs = 120
),
hyper_params = list(
ntrees = c(50, 100, 150, 200, 250),
mtries = c(2, 3, 4, 5),
sample_rate = c(0.5, 0.632, 0.8, 0.95),
col_sample_rate_per_tree = c(0.5, 0.9, 1)
),
x = x,
y = y,
training_frame = train,
validation_frame = valid,
nfolds = 5,
max_depth = 40,
stopping_metric = "deviance",
stopping_tolerance = 0,
stopping_rounds = 4,
score_tree_interval = 3
)