56-caret包学习:模型训练与调优

1、模型训练与参数优化

在进行建模时,需对模型的参数进行优化,在caret包中其主要函数是train。
一旦定义了模型和调优参数值,就应该指定重采样的类型。目前,k折交叉验证重采样(一次或重复)、留一交叉验证重采样和 bootstrap (简单估计或632规则)重采样方法可以被train函数使用。 重采样后,过程中生成性能度量的概要,以指导用户选择哪些调优参数值。 默认情况下,函数会自动选择与最佳值相关联的调优参数。

重采样方法:
交叉验证重采样将样本分割,一部分作为训练样本,一部分作为测试样本,通过计算在测试样本上的误差率来估计测试误差,常见的交叉验证技术有留一交叉验证和K折交叉验证法;拔靴法(bootstrap)是利用有限的样本资料经由多次重复抽样,重新建立起足以代表母体样本分布的新样本,其主要特点是能够被广泛的应用到各种统计学习方法中,特别是在对难以估计或者统计软件不能直接给出结果的变量的估计中。

2、自定义调优过程

train()函数可以在模型拟合之前以各种方式对数据进行预处理, 为了指定需要进行哪些预处理,train函数有个preProcess的参数可供调整。

交替调优网格:
tuneGrid()函数可以为每个调整参数生成一个数据框,数据框的列名为拟合模型的参数,比如RDA模型,列名将为gamma和lambda。

> library(pacman)
> p_load(caret,mlbench,dplyr)
> # 使用Sonar数据集
> data("Sonar")
> 
> # 拆分为训练集和测试集
> # 默认情况下,该函数使用分层随机分割
> set.seed(123)
> ind <- createDataPartition(y = Sonar$Class,
+                            # 训练集所占比例
+                            p = 0.75,
+                            list = F)
> train <- Sonar[ind,]
> test <- Sonar[-ind,]
> ctrl <- trainControl(method = "repeatedcv",
+                      # number交叉验证折数或重采样迭代次数
+                      number = 10,
+                      # repeats确定了反复次数
+                      repeats = 10,
+                      # 是否显示训练过程
+                      verboseIter = T,
+                      # 是否将数据保存到trainingData
+                      returnData = F,
+                      # 训练百分比
+                      p = 0.75,
+                      classProbs = T,
+                      summaryFunction = twoClassSummary,
+                      allowParallel = T)

method确定多次交叉检验的抽样方法;
method可选:"boot", "cv", "LOOCV", "LGOCV", "repeatedcv", "timeslice", "none" 和 "oob";
"oob"袋外估计值,只能用于randomForest、袋外决策树、袋外earth、袋外柔性判别分析或者条件树森林模型,不适合GBM模型;
对时间序列method = "timeslice", 有三个参数initialWindow, horizon 和 fixedWindow;
有的模型预测结果为计算概率,例如“Prob”、“After”、“Response”、“Probability”或“RAW”,classProbs =TRUE则让其返回类别"class";
summaryFunction 指定模型性能统计的函数;
selectionFunction 选择最优参数和抽样的函数;
returnResamp 指定要保存多少性能指标,可为all,final,none;
allowParallel 是否使用并行计算。

> gbm.grid <- expand.grid(interaction.depth = c(1,5,9),
+                         n.trees = c(50, 100, 150, 200, 250, 300),
+                         shrinkage = 0.1,
+                         n.minobsinnode = 20)
> head(gbm.grid)

梯度提升机模型(GBM)有三个主要的参数:

  1. n.trees:树的迭代次数
  2. interaction.depth:树的复杂度
  3. n.minobsinnode:收敛一个节点中开始分割的最小训练集样本数
##   interaction.depth n.trees shrinkage n.minobsinnode
## 1                 1      50       0.1             20
## 2                 5      50       0.1             20
## 3                 9      50       0.1             20
## 4                 1     100       0.1             20
## 5                 5     100       0.1             20
## 6                 9     100       0.1             20
> nrow(gbm.grid)
## [1] 18
> set.seed(123)
> # method="svmRadial支持向量机
> # “rda正则判别分析模型
> # treebag装袋树
> fit.gbm <- train(Class ~ .,data = train,
+                  # 梯度提升树模型
+                  method = "gbm",
+                  trControl = ctrl,
+                  verbose = F,
+                  tuneGrid = gbm.grid,
+                  metric = "ROC")
>
> fit.gbm
## Stochastic Gradient Boosting 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 140, 142, 142, 141, 141, 141, ... 
## Resampling results across tuning parameters:
## 
##   interaction.depth  n.trees  ROC        Sens       Spec     
##   1                   50      0.8768105  0.8397222  0.7094643
##   1                  100      0.8934425  0.8737500  0.7714286
##   1                  150      0.8991369  0.8656944  0.7801786
##   1                  200      0.8993031  0.8619444  0.7773214
##   1                  250      0.8987153  0.8650000  0.7789286
##   1                  300      0.9034747  0.8700000  0.7803571
##   5                   50      0.8890377  0.8586111  0.7505357
##   5                  100      0.8994891  0.8694444  0.7746429
##   5                  150      0.9019147  0.8626389  0.7908929
##   5                  200      0.9028720  0.8637500  0.7876786
##   5                  250      0.9028844  0.8700000  0.7876786
##   5                  300      0.9028299  0.8695833  0.7880357
##   9                   50      0.8984772  0.8515278  0.7678571
##   9                  100      0.9083358  0.8659722  0.7864286
##   9                  150      0.9162029  0.8776389  0.7983929
##   9                  200      0.9153100  0.8905556  0.7971429
##   9                  250      0.9168998  0.8902778  0.7991071
##   9                  300      0.9188070  0.8843056  0.7957143
## 
## Tuning parameter 'shrinkage' was held constant at a value of
##  0.1
## Tuning parameter 'n.minobsinnode' was held constant at a value
##  of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees =
##  300, interaction.depth = 9, shrinkage = 0.1 and n.minobsinnode = 20.

最终选择的参数为n.trees = 300, interaction.depth = 9, shrinkage = 0.1, n.minobsinnode = 20。
plot函数可用于检查模型的性能估计值与调整参数之间的关系:

> trellis.par.set(caretTheme())
> plot(fit.gbm)
ROC与参数之间的关系

可以看到参数n.trees = 300, interaction.depth = 9时ROC达到最高点。
plot中使用metric参数可以查看其他性能指标,比如要查看“Kappa”指标,可以指定:metric = "Kappa"。本模型没有记录Kappa性能,所以无法查看。

查看拟合结果热力图,结果与上图一样,n.trees = 300, interaction.depth = 9时颜色最深:

> trellis.par.set(caretTheme())
> plot(fit.gbm, metric = "ROC", plotType = "level", scales = list(x = list(rot = 90)))
ROC热力图

也可以使用ggplot2包:

> ggplot(fit.gbm)
ggplot2画图

xyplot和stripplot可用于绘制针对(数值型)调整参数的重采样统计信息。
histogram和densityplot还可用于查看调整参数在调整参数之间的分布。

> trellis.par.set(caretTheme())
> densityplot(fit.gbm, pch = "|", resamples = "all")
ROC密度图

3、模型选择

tolerance()函数可用于找到不太复杂的模型,例如,要基于2%的性能损失选择参数值:

> which.pct <- tolerance(fit.gbm, metric = "ROC", tol = 2, maximize = T)
>
> fit.gbm$results[which.pct, 1:6]
##      shrinkage interaction.depth   n.minobsinnode   n.trees       ROC   Sens
##  8       0.1                5              20         100     0.9026612 0.8675

最后选择了一个不那么复杂的模型,在ROC曲线下的面积为0.9026612。

如果拟合了多个模型,可以通过resamples()函数对他们的性能差异做出统计报表:

> set.seed(123)
> fit.svm <- train(Class ~ .,data=train,
>                  method = "svmRadial", 
>                  trControl = ctrl, 
>                  preProc = c("center", "scale"),
>                  tuneLength = 8,
>                  metric = "ROC")
> 
> set.seed(123)
> fit.rda <- train(Class ~ ., data=train, 
>                  method = "rda", 
>                  trControl = ctrl, 
>                  tuneLength = 4,
>                  metric = "ROC")
>
> resamps <- resamples(list(GBM = fit.gbm,
>                           SVM = fit.svm,
>                           RDA = fit.rda))
>
> resamps
## Call:
## resamples.default(x = list(GBM = fit.gbm, SVM = fit.svm, RDA
##  = fit.rda))
## 
## Models: GBM, SVM, RDA 
## Number of resamples: 100 
## Performance metrics: ROC, Sens, Spec 
## Time estimates for: everything, final model fit 
> summary(resamps)
## Call:
## summary.resamples(object = resamps)
## 
## Models: GBM, SVM, RDA 
## Number of resamples: 100 
## 
## ROC 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.7142857 0.8700397 0.9186508 0.9107093 0.9683780    1    0
## SVM 0.7031250 0.8888889 0.9285714 0.9224578 0.9821429    1    0
## RDA 0.5781250 0.8571429 0.9206349 0.8978720 0.9598214    1    0
## 
## Sens 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.5000000 0.7777778 0.8888889 0.8798611 1.0000000    1    0
## SVM 0.5555556 0.8750000 0.8750000 0.8825000 0.9166667    1    0
## RDA 0.6250000 0.7777778 0.8750000 0.8747222 1.0000000    1    0
## 
## Spec 
##          Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
## GBM 0.4285714 0.7142857 0.8571429 0.7819643 0.8750000    1    0
## SVM 0.3750000 0.7142857 0.7500000 0.7835714 0.8750000    1    0
## RDA 0.2857143 0.5714286 0.7142857 0.7164286 0.8571429    1    0

画个图看看:

trellis.par.set(caretTheme())
dotplot(resamps, metric = "ROC")
ROC对比图

通过ROC性能对比,SVM模型高于GBM模型高于RDA模型。
由于设置了同样的随机数种子,使用同样的数据训练模型,因此对模型的差异进行推断是有意义的。通过这种方法,降低了存在内部重采样相关性的可能,然后使用简单的t检验来评估模型之间没有差异的零假设:

> dif.value <- diff(resamps)
> summary(dif.value)
## Call:
## summary.diff.resamples(object = dif.value)
## 
## p-value adjustment: bonferroni 
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
## 
## ROC 
##     GBM      SVM      RDA     
## GBM          -0.01175  0.01284
## SVM 0.614698           0.02459
## RDA 0.844504 0.005992         
## 
## Sens 
##     GBM SVM       RDA      
## GBM     -0.002639  0.005139
## SVM 1              0.007778
## RDA 1   1                  
## 
## Spec 
##     GBM      SVM       RDA      
## GBM          -0.001607  0.065536
## SVM 1.000000            0.067143
## RDA 0.029286 0.004865 
trellis.par.set(caretTheme())
dotplot(dif.value)
模型之间差异性检验

4、拟合最终模型

当最优模型和最优参数已经找到时,可以使用最优参数直接拟合模型:

> fit.ctrl <- trainControl(method = "none", classProbs = TRUE)
> 
> set.seed(123)
> fit.final <- train(Class ~ .,data=train,
>                    method = "svmRadial", 
>                    trControl = fit.ctrl, 
>                    preProc = c("center", "scale"),
>                    tuneGrid = fit.svm$bestTune,
>                    metric = "ROC")
> 
> fit.final
## Support Vector Machines with Radial Basis Function Kernel 
## 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## Pre-processing: centered (60), scaled (60) 
## Resampling: None 
> pred.final <- predict(fit.final,newdata=test)
> confusionMatrix(pred.final,test$Class)
## Stochastic Gradient Boosting 
## 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## No pre-processing
## Resampling: None
> pred.final <- predict(fit.gbm.final, newdata = test)
> confusionMatrix(pred.final, test$Class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  M  R
##          M 23  2
##          R  4 22
##                                           
##                Accuracy : 0.8824          
##                  95% CI : (0.7613, 0.9556)
##     No Information Rate : 0.5294          
##     P-Value [Acc > NIR] : 8.488e-08       
##                                           
##                   Kappa : 0.765           
##                                           
##  Mcnemar's Test P-Value : 0.6831          
##                                           
##             Sensitivity : 0.8519          
##             Specificity : 0.9167          
##          Pos Pred Value : 0.9200          
##          Neg Pred Value : 0.8462          
##              Prevalence : 0.5294          
##          Detection Rate : 0.4510          
##    Detection Prevalence : 0.4902          
##       Balanced Accuracy : 0.8843          
##                                           
##        'Positive' Class : M
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,542评论 6 504
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,822评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 163,912评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,449评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,500评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,370评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,193评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,074评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,505评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,722评论 3 335
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,841评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,569评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,168评论 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,783评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,918评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,962评论 2 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,781评论 2 354

推荐阅读更多精彩内容