

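Once the model and the tuning parameter values have been defined, the type of resampling should be specified. Currently, train() can use k-fold cross-validation (once or repeated), leave-one-out cross-validation, and the bootstrap (simple estimation or the .632 rule). After resampling, the process produces a profile of performance measures to guide the choice of tuning parameter values. By default, the function automatically picks the tuning parameters associated with the best value of the chosen metric.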



train() can pre-process the data in various ways before the model is fitted. Its preProcess argument specifies which pre-processing steps to apply.
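As a rough illustration of what centering and scaling do, the sketch below applies caret's standalone preProcess() function to the Sonar predictors outside of train(). The object names predictors, pp and scaled are made up for this example. Passing the same method names to train() through its preProcess argument should apply an equivalent transformation during model fitting.

```r
library(caret)
library(mlbench)

data("Sonar")
# Keep only the 60 numeric predictors (drop the Class column)
predictors <- Sonar[, names(Sonar) != "Class"]

# Estimate the centering and scaling parameters from the data
pp <- preProcess(predictors, method = c("center", "scale"))

# Apply the transformation; each column should end up with mean 0 and sd 1
scaled <- predict(pp, predictors)
summary(scaled$V1)
```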


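> library(pacman)
> p_load(caret, mlbench, dplyr)
> # Use the Sonar data set
> data("Sonar")
> # Split into training and test sets
> # By default, createDataPartition() uses stratified random sampling
> set.seed(123)
> ind <- createDataPartition(y = Sonar$Class,
+                            # proportion of rows assigned to the training set
+                            p = 0.75,
+                            list = FALSE)
> train <- Sonar[ind,]
> test <- Sonar[-ind,]
> ctrl <- trainControl(method = "repeatedcv",
+                      # number: number of folds or resampling iterations
+                      number = 10,
+                      # repeats: how many times the k-fold CV is repeated
+                      repeats = 10,
+                      # whether to print a training log
+                      verboseIter = TRUE,
+                      # whether to save the data into trainingData
+                      returnData = FALSE,
+                      # training percentage (only used by method = "LGOCV")
+                      p = 0.75,
+                      # compute class probabilities (needed for ROC)
+                      classProbs = TRUE,
+                      summaryFunction = twoClassSummary,
+                      allowParallel = TRUE)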

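method can be, among others, "boot", "cv", "LOOCV", "LGOCV", "repeatedcv", "timeslice", "none" or "oob";
for time series, method = "timeslice" takes three further arguments: initialWindow, horizon and fixedWindow;
classProbs = TRUE computes class probabilities (in addition to the predicted class) in each resample, which twoClassSummary needs in order to compute ROC;
summaryFunction is the function that computes the performance metrics;
selectionFunction is the function that picks the optimal tuning parameters from the resampling results;
returnResamp controls how many of the resampled performance values are saved: "all", "final" or "none";
allowParallel controls whether parallel computation is used.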
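To make the three timeslice arguments concrete, here is a small sketch using caret's createTimeSlices() on a toy series of ten points. The series and the object names slices and ts.ctrl are illustrative only. With a fixed window, each training slice holds initialWindow consecutive points and is followed by horizon test points.

```r
library(caret)

# Toy "time series" of ten ordered observations
slices <- createTimeSlices(1:10,
                           initialWindow = 5,   # points in each training slice
                           horizon = 2,         # points in each test slice
                           fixedWindow = TRUE)  # training window slides, does not grow

# Inspect the first and last training/test index sets
slices$train[[1]]
slices$test[[1]]
slices$train[[length(slices$train)]]
slices$test[[length(slices$test)]]

# The same settings passed to trainControl() for use with train()
ts.ctrl <- trainControl(method = "timeslice",
                        initialWindow = 5,
                        horizon = 2,
                        fixedWindow = TRUE)
```

Setting fixedWindow = FALSE would instead let every training slice start at the first observation, so the training set grows over time.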

> gbm.grid <- expand.grid(interaction.depth = c(1,5,9),
+                         n.trees = c(50, 100, 150, 200, 250, 300),
+                         shrinkage = 0.1,
+                         n.minobsinnode = 20)
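
  1. n.trees: number of boosting iterations, i.e. trees
  2. interaction.depth: complexity of each tree
  3. n.minobsinnode: minimum number of training set samples in a node to commence splitting

> head(gbm.grid)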
##   interaction.depth n.trees shrinkage n.minobsinnode
## 1                 1      50       0.1             20
## 2                 5      50       0.1             20
## 3                 9      50       0.1             20
## 4                 1     100       0.1             20
## 5                 5     100       0.1             20
## 6                 9     100       0.1             20
> nrow(gbm.grid)
## [1] 18
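> set.seed(123)
> # Other possible method values include "svmRadial" (support vector machine),
> # "rda" (regularized discriminant analysis) and "treebag" (bagged trees)
> fit.gbm <- train(Class ~ ., data = train,
+                  # stochastic gradient boosting
+                  method = "gbm",
+                  trControl = ctrl,
+                  verbose = FALSE,
+                  tuneGrid = gbm.grid,
+                  metric = "ROC")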
> fit.gbm
## Stochastic Gradient Boosting 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 140, 142, 142, 141, 141, 141, ... 
## Resampling results across tuning parameters:
##   interaction.depth  n.trees  ROC        Sens       Spec     
##   1                   50      0.8768105  0.8397222  0.7094643
##   1                  100      0.8934425  0.8737500  0.7714286
##   1                  150      0.8991369  0.8656944  0.7801786
##   1                  200      0.8993031  0.8619444  0.7773214
##   1                  250      0.8987153  0.8650000  0.7789286
##   1                  300      0.9034747  0.8700000  0.7803571
##   5                   50      0.8890377  0.8586111  0.7505357
##   5                  100      0.8994891  0.8694444  0.7746429
##   5                  150      0.9019147  0.8626389  0.7908929
##   5                  200      0.9028720  0.8637500  0.7876786
##   5                  250      0.9028844  0.8700000  0.7876786
##   5                  300      0.9028299  0.8695833  0.7880357
##   9                   50      0.8984772  0.8515278  0.7678571
##   9                  100      0.9083358  0.8659722  0.7864286
##   9                  150      0.9162029  0.8776389  0.7983929
##   9                  200      0.9153100  0.8905556  0.7971429
##   9                  250      0.9168998  0.8902778  0.7991071
##   9                  300      0.9188070  0.8843056  0.7957143
## Tuning parameter 'shrinkage' was held constant at a value of
##  0.1
## Tuning parameter 'n.minobsinnode' was held constant at a value
##  of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees =
##  300, interaction.depth = 9, shrinkage = 0.1 and n.minobsinnode = 20.

The parameters finally selected are n.trees = 300, interaction.depth = 9, shrinkage = 0.1 and n.minobsinnode = 20.

> trellis.par.set(caretTheme())
> plot(fit.gbm)

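The plot shows that ROC peaks at n.trees = 300, interaction.depth = 9.
The metric argument of plot() selects another performance measure, for example metric = "Kappa". This model was summarized with twoClassSummary, which records only ROC, Sens and Spec, so Kappa cannot be plotted here.

A level plot (heat map) of the fit agrees with the plot above: the cell for n.trees = 300, interaction.depth = 9 is the darkest: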

> trellis.par.set(caretTheme())
> plot(fit.gbm, metric = "ROC", plotType = "level", scales = list(x = list(rot = 90)))


> ggplot(fit.gbm)


> trellis.par.set(caretTheme())
> densityplot(fit.gbm, pch = "|", resamples = "all")



> which.pct <- tolerance(fit.gbm$results, metric = "ROC", tol = 2, maximize = TRUE)
> fit.gbm$results[which.pct, 1:6]
##   shrinkage interaction.depth n.minobsinnode n.trees       ROC   Sens
## 8       0.1                 5             20     100 0.9026612 0.8675
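The idea behind tolerance() can be sketched on a small, made-up results table. The numbers below are hypothetical and are not taken from the fit above. Assuming the rows are ordered from the simplest to the most complex model, the function returns the index of the first row whose metric is within tol percent of the best value. The manual calculation is only meant to mirror that logic.

```r
library(caret)

# Hypothetical resampling results, ordered from simplest to most complex model
toy.results <- data.frame(n.trees = c(50, 100, 150, 200),
                          ROC     = c(0.85, 0.89, 0.90, 0.91))

# Percent loss of each row relative to the best ROC
best.roc <- max(toy.results$ROC)
pct.loss <- (best.roc - toy.results$ROC) / best.roc * 100

# First (simplest) row whose loss is below 2 percent
manual.pick <- min(which(pct.loss < 2))

# The same selection using caret's helper on the results data frame
caret.pick <- tolerance(toy.results, metric = "ROC", tol = 2, maximize = TRUE)

toy.results[caret.pick, ]
```

Note that tolerance() is given the results data frame, not the train object itself.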



> set.seed(123)
> fit.svm <- train(Class ~ ., data = train,
+                  method = "svmRadial",
+                  trControl = ctrl,
+                  preProcess = c("center", "scale"),
+                  tuneLength = 8,
+                  metric = "ROC")
> set.seed(123)
> fit.rda <- train(Class ~ ., data = train,
+                  method = "rda",
+                  trControl = ctrl,
+                  tuneLength = 4,
+                  metric = "ROC")
> resamps <- resamples(list(GBM = fit.gbm,
+                           SVM = fit.svm,
+                           RDA = fit.rda))
> resamps
## Call:
## resamples.default(x = list(GBM = fit.gbm, SVM = fit.svm, RDA
##  = fit.rda))
## Models: GBM, SVM, RDA 
## Number of resamples: 100 
## Performance metrics: ROC, Sens, Spec 
## Time estimates for: everything, final model fit 
> summary(resamps)
## Call:
## summary.resamples(object = resamps)
## Models: GBM, SVM, RDA 
## Number of resamples: 100 
## ROC 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.7142857 0.8700397 0.9186508 0.9107093 0.9683780    1    0
## SVM 0.7031250 0.8888889 0.9285714 0.9224578 0.9821429    1    0
## RDA 0.5781250 0.8571429 0.9206349 0.8978720 0.9598214    1    0
## Sens 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.5000000 0.7777778 0.8888889 0.8798611 1.0000000    1    0
## SVM 0.5555556 0.8750000 0.8750000 0.8825000 0.9166667    1    0
## RDA 0.6250000 0.7777778 0.8750000 0.8747222 1.0000000    1    0
## Spec 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.4285714 0.7142857 0.8571429 0.7819643 0.8750000    1    0
## SVM 0.3750000 0.7142857 0.7500000 0.7835714 0.8750000    1    0
## RDA 0.2857143 0.5714286 0.7142857 0.7164286 0.8571429    1    0


> dotplot(resamps, metric = "ROC")


> dif.value <- diff(resamps)
> summary(dif.value)
## Call:
## summary.diff.resamples(object = dif.value)
## p-value adjustment: bonferroni 
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
## ROC 
##     GBM      SVM      RDA     
## GBM          -0.01175  0.01284
## SVM 0.614698           0.02459
## RDA 0.844504 0.005992         
## Sens 
##     GBM SVM       RDA      
## GBM     -0.002639  0.005139
## SVM 1              0.007778
## RDA 1   1                  
## Spec 
##     GBM      SVM       RDA      
## GBM          -0.001607  0.065536
## SVM 1.000000            0.067143
## RDA 0.029286 0.004865 
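The p-values above come from pairwise comparisons of the models on the same resamples. A minimal base-R sketch of that idea is a paired t-test per model pair, followed by a Bonferroni adjustment. It uses simulated per-resample ROC values; all numbers and object names here are invented for illustration, and it approximates what diff() reports rather than reproducing its actual implementation.

```r
set.seed(123)

# Simulated ROC values for three hypothetical models on the same 100 resamples
base.roc <- runif(100, min = 0.80, max = 0.95)
roc <- data.frame(A = base.roc,
                  B = base.roc + rnorm(100, mean = 0.01, sd = 0.02),
                  C = base.roc - rnorm(100, mean = 0.02, sd = 0.02))

# All pairs of model names
model.pairs <- combn(names(roc), 2, simplify = FALSE)

# Paired t-test on each pair, since the models share the same resamples
raw.p <- sapply(model.pairs, function(p) {
  t.test(roc[[p[1]]], roc[[p[2]]], paired = TRUE)$p.value
})
names(raw.p) <- sapply(model.pairs, paste, collapse = " vs ")

# Bonferroni adjustment: multiply by the number of comparisons, capped at 1
adj.p <- p.adjust(raw.p, method = "bonferroni")
adj.p
```

Pairing matters because every model is evaluated on identical resamples, so the per-resample differences are less noisy than the raw scores.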



> fit.ctrl <- trainControl(method = "none", classProbs = TRUE)
> set.seed(123)
> fit.final <- train(Class ~ ., data = train,
+                    method = "svmRadial",
+                    trControl = fit.ctrl,
+                    preProcess = c("center", "scale"),
+                    tuneGrid = fit.svm$bestTune,
+                    metric = "ROC")
> fit.final
## Support Vector Machines with Radial Basis Function Kernel 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## Pre-processing: centered (60), scaled (60) 
## Resampling: None 
> pred.final <- predict(fit.final,newdata=test)
> confusionMatrix(pred.final,test$Class)
## Stochastic Gradient Boosting 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## No pre-processing
## Resampling: None
> pred.final <- predict(fit.gbm.final, newdata = test)
> confusionMatrix(pred.final, test$Class)
## Confusion Matrix and Statistics
##           Reference
## Prediction  M  R
##          M 23  2
##          R  4 22
##                Accuracy : 0.8824          
##                  95% CI : (0.7613, 0.9556)
##     No Information Rate : 0.5294          
##     P-Value [Acc > NIR] : 8.488e-08       
##                   Kappa : 0.765           
##  Mcnemar's Test P-Value : 0.6831          
##             Sensitivity : 0.8519          
##             Specificity : 0.9167          
##          Pos Pred Value : 0.9200          
##          Neg Pred Value : 0.8462          
##              Prevalence : 0.5294          
##          Detection Rate : 0.4510          
##    Detection Prevalence : 0.4902          
##       Balanced Accuracy : 0.8843          
##        'Positive' Class : M
