数据分析:Stochastic Gradient Boosting(随机梯度boosting)

介绍

Boosting是机器学习常用的方法,其中随机梯度boosting更是常见的机器学习算法,可用于构建分类器和回归分析。

加载数据

library(tidyverse)
library(ISLR)
library(caret)
library(pROC)

ml_data <- College
ml_data %>%
  glimpse()
Rows: 777
Columns: 18
$ Private     <fct> Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, Yes, No, Yes, …
$ Apps        <dbl> 1660, 2186, 1428, 417, 193, 587, 353, 1899, 1038, 582, 1732, 2652, 1179, 1267, 494, 1420, 4302, 1216, 11…
$ Accept      <dbl> 1232, 1924, 1097, 349, 146, 479, 340, 1720, 839, 498, 1425, 1900, 780, 1080, 313, 1093, 992, 908, 704, 2…
$ Enroll      <dbl> 721, 512, 336, 137, 55, 158, 103, 489, 227, 172, 472, 484, 290, 385, 157, 220, 418, 423, 322, 1016, 252,…
$ Top10perc   <dbl> 23, 16, 22, 60, 16, 38, 17, 37, 30, 21, 37, 44, 38, 44, 23, 9, 83, 19, 14, 24, 25, 20, 20, 24, 46, 12, 2…
$ Top25perc   <dbl> 52, 29, 50, 89, 44, 62, 45, 68, 63, 44, 75, 77, 64, 73, 46, 22, 96, 40, 23, 54, 44, 63, 51, 49, 74, 52, …
$ F.Undergrad <dbl> 2885, 2683, 1036, 510, 249, 678, 416, 1594, 973, 799, 1830, 1707, 1130, 1306, 1317, 1018, 1593, 1819, 15…
$ P.Undergrad <dbl> 537, 1227, 99, 63, 869, 41, 230, 32, 306, 78, 110, 44, 638, 28, 1235, 287, 5, 281, 326, 1512, 23, 1035, …
$ Outstate    <dbl> 7440, 12280, 11250, 12960, 7560, 13500, 13290, 13868, 15595, 10468, 16548, 17080, 9690, 12572, 8352, 870…
$ Room.Board  <dbl> 3300, 6450, 3750, 5450, 4120, 3335, 5720, 4826, 4400, 3380, 5406, 4440, 4785, 4552, 3640, 4780, 5300, 35…
$ Books       <dbl> 450, 750, 400, 450, 800, 500, 500, 450, 300, 660, 500, 400, 600, 400, 650, 450, 660, 550, 900, 500, 400,…
$ Personal    <dbl> 2200, 1500, 1165, 875, 1500, 675, 1500, 850, 500, 1800, 600, 600, 1000, 400, 2449, 1400, 1598, 1100, 132…
$ PhD         <dbl> 70, 29, 53, 92, 76, 67, 90, 89, 79, 40, 82, 73, 60, 79, 36, 78, 93, 48, 62, 60, 69, 83, 55, 88, 79, 57, …
$ Terminal    <dbl> 78, 30, 66, 97, 72, 73, 93, 100, 84, 41, 88, 91, 84, 87, 69, 84, 98, 61, 66, 62, 82, 96, 65, 93, 88, 60,…
$ S.F.Ratio   <dbl> 18.1, 12.2, 12.9, 7.7, 11.9, 9.4, 11.5, 13.7, 11.3, 11.5, 11.3, 9.9, 13.3, 15.3, 11.1, 14.7, 8.4, 12.1, …
$ perc.alumni <dbl> 12, 16, 30, 37, 2, 11, 26, 37, 23, 15, 31, 41, 21, 32, 26, 19, 63, 14, 18, 5, 35, 14, 25, 5, 24, 5, 30, …
$ Expend      <dbl> 7041, 10527, 8735, 19016, 10922, 9727, 8861, 11487, 11644, 8991, 10932, 11711, 7940, 9305, 8127, 7355, 2…
$ Grad.Rate   <dbl> 60, 56, 54, 59, 15, 55, 63, 73, 80, 52, 73, 76, 74, 68, 55, 69, 100, 59, 46, 34, 48, 70, 65, 48, 54, 48,…

训练模型

  • 构建训练集和测试集

  • 训练模型:使用3次5折交叉验证方法并预处理数据

set.seed(123)
index <- createDataPartition(ml_data$Private, p = 0.7, list = FALSE)
train_data <- ml_data[index, ]
test_data  <- ml_data[-index, ]

model_gbm <- train(Private ~ .,
                          data = train_data,
                          method = "gbm",
                          preProcess = c("scale", "center"),
                          trControl = trainControl(method = "repeatedcv", 
                                                  number = 5, 
                                                  repeats = 3, 
                                                  verboseIter = FALSE),
                          verbose = 0)
model_gbm
Stochastic Gradient Boosting 

545 samples
 17 predictor
  2 classes: 'No', 'Yes' 

Pre-processing: scaled (17), centered (17) 
Resampling: Cross-Validated (5 fold, repeated 3 times) 
Summary of sample sizes: 436, 436, 436, 436, 436, 437, ... 
Resampling results across tuning parameters:

  interaction.depth  n.trees  Accuracy   Kappa    
  1                   50      0.9369957  0.8368376
  1                  100      0.9394369  0.8453525
  1                  150      0.9376299  0.8417065
  2                   50      0.9430954  0.8552244
  2                  100      0.9437293  0.8556455
  2                  150      0.9424612  0.8528115
  3                   50      0.9400314  0.8476074
  3                  100      0.9406488  0.8490041
  3                  150      0.9412773  0.8508960

Tuning parameter 'shrinkage' was held constant at a value of 0.1
Tuning parameter 'n.minobsinnode' was held constant at
 a value of 10
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were n.trees = 100, interaction.depth = 2, shrinkage = 0.1 and n.minobsinnode = 10.

结果:模型在n.trees = 100, interaction.depth = 2, shrinkage = 0.1 and n.minobsinnode = 10时获得最佳Accuracy=0.9437293。另外也可以使用summary(model_gbm)查看重要变量重要性分布(按照相对重要性排序:百分比相对标准化)。

summary(model_gbm)
                    var    rel.inf
F.Undergrad F.Undergrad 41.5488790
Outstate       Outstate 37.4947348
P.Undergrad P.Undergrad  5.5553944
S.F.Ratio     S.F.Ratio  3.2261838
Room.Board   Room.Board  2.3599418
Enroll           Enroll  1.8459618
Accept           Accept  1.2306723
PhD                 PhD  1.1096188
Terminal       Terminal  1.0970409
Expend           Expend  0.8743070
Grad.Rate     Grad.Rate  0.8085252
perc.alumni perc.alumni  0.7778578
Top25perc     Top25perc  0.6229050
Top10perc     Top10perc  0.4310016
Apps               Apps  0.4217785
Personal       Personal  0.3608742
Books             Books  0.2343231

预测结果

predict函数在预测predictors是可以选择type类型,通常分类predictors的有两类type:默认是raw值,在使用pROC包的rocauc函数计算时候,需要使用probability值,通常选择某类的probability值计算即可。

  • raw: 测试样本最后预测的分类label

  • prob:测试样本最后预测为各个分类label的概率

confusionMatrix

caret::confusionMatrix(
  data = predict(model_gbm, test_data),
  reference = test_data$Private
  )
Confusion Matrix and Statistics

          Reference
Prediction  No Yes
       No   52   8
       Yes  11 161
                                        
               Accuracy : 0.9181        
                 95% CI : (0.8751, 0.95)
    No Information Rate : 0.7284        
    P-Value [Acc > NIR] : 3.803e-13     
                                        
                  Kappa : 0.7899        
                                        
 Mcnemar's Test P-Value : 0.6464        
                                        
            Sensitivity : 0.8254        
            Specificity : 0.9527        
         Pos Pred Value : 0.8667        
         Neg Pred Value : 0.9360        
             Prevalence : 0.2716        
         Detection Rate : 0.2241        
   Detection Prevalence : 0.2586        
      Balanced Accuracy : 0.8890        
                                        
       'Positive' Class : No 

confusionMatrix函数给出分类变量的预测值和真实值混淆矩阵和对应的测试样本在模型预测过程的统计结果,如 Accuracy=0.9181等值。

type="raw"

predict(model_gbm, test_data, type = "raw")
# predict(model_gbm, test_data) # 默认type="raw"
 [1] No  Yes Yes Yes Yes Yes Yes Yes Yes No  No  Yes Yes Yes No  Yes Yes Yes Yes No  Yes Yes No  Yes Yes Yes Yes Yes No  Yes Yes Yes
 [33] Yes Yes Yes Yes Yes No  Yes Yes Yes No  No  Yes No  Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes No  Yes No  Yes No  No  Yes
 [65] Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes No  Yes Yes Yes Yes Yes Yes Yes No  Yes Yes Yes Yes Yes Yes Yes Yes Yes
 [97] Yes Yes Yes No  Yes No  Yes No  Yes Yes No  Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes No  Yes Yes Yes No  No  Yes No  Yes No  Yes
[129] Yes Yes Yes Yes Yes No  Yes Yes Yes Yes Yes No  Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes No  No  Yes No  Yes Yes Yes Yes
[161] Yes No  No  Yes Yes Yes Yes No  Yes Yes No  No  Yes No  Yes Yes Yes No  No  Yes No  Yes No  Yes No  No  Yes No  Yes No  No  No 
[193] No  No  Yes Yes Yes Yes No  No  No  No  No  No  Yes Yes No  No  No  Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes No 
[225] Yes Yes Yes Yes Yes Yes Yes No 
Levels: No Yes

type="prob"

predict(model_gbm, test_data, type = "prob")
head(predict(model_gbm, test_data, type = "prob"))
           No       Yes
1 0.734880487 0.2651195
2 0.006084374 0.9939156
3 0.004985657 0.9950143
4 0.062989176 0.9370108
5 0.005712712 0.9942873
6 0.005905355 0.9940946

ROC曲线

  • 获取ROC曲线:通过pROC的roc和auc函数分别获取roc对象和auc值
 library(pROC)
library(ggplot2)

rocobj <- roc(test_data$Private, predict(model_gbm, newdata = test_data, type = "prob")[, "No"])
auc <- round(auc(test_data$Private, predict(model_gbm, newdata = test_data, type = "prob")[, "Yes"]),4)

ggroc(rocobj, color = "red", linetype = 1, size = 1, alpha = 1, legacy.axes = T)+
                geom_abline(intercept = 0, slope = 1, color="grey", size = 1, linetype=1)+
              labs(x = "False Positive Rate (1 - Specificity)",
                   y = "True Positive Rate (Sensivity or Recall)")+
              annotate("text",x = .75, y = .25,label=paste("AUC =", auc),
                       size = 5, family="serif")+
              coord_cartesian(xlim = c(0, 1), ylim = c(0, 1))+
              theme_bw()+
              theme(panel.background = element_rect(fill = 'transparent'),
                    axis.ticks.length = unit(0.4, "lines"), 
                    axis.ticks = element_line(color='black'),
                    axis.line = element_line(size=.5, colour = "black"),
                    axis.title = element_text(colour='black', size=12,face = "bold"),
                    axis.text = element_text(colour='black',size=10,face = "bold"),
                    text = element_text(size=8, color="black", family="serif"))

问题

问题:为什么模型对测试样本处理时,pROC计算出来的AUC和模型给的Accuracy值是不一样的呢?

答:AUC是ROC下的面积,ROC折线每个点对应的阈值确定了该点的Accuracy、Precision和Recall等等的度量,所以AUC是一系列Accuracy的综合。 AUC衡量模型好坏,Accuracy衡量模型在某个特定阈值下的预测准确度。

首先,AUC对应的不是一个accuracy,而是一系列accuracy。AUC是ROC的"线下面积",而ROC是以FPR-TPR为坐标的一条线,实际上是连接一系列散点的一条折线。这条折线上的每一个点,对应了一个threshold,以及由这个threshold确定的预测值及其accuracy、precision、recall等等的度量。所以说,AUC衡量的是一个模型的好坏,是它给所有sample排序的合理程度(是不是正确地把负例排在了正例的前面);而accuracy衡量的是一个模型在一个特定threshold(比如,logistic regression模型在阈值1/2)下的预测准确度(是不是正确地把负例排在了阈值之前,正例排在了阈值之后)。因此,AUC高而accuracy低或者accuracy高AUC低的情况有没有可能?有。一个模型定了,它的AUC就定了。但我可以取一个threshold,使得它的accuracy尽量低或者尽量高(有上限和下限)。

R Information

sessionInfo()
R version 4.0.3 (2020-10-10)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur 10.16

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] ISLR_1.2            forcats_0.5.0       stringr_1.4.0       purrr_0.3.4         readr_1.4.0         tidyr_1.1.2        
 [7] tidyverse_1.3.0     xgboost_1.3.1.1     mlbench_2.1-1       survminer_0.4.8     ggpubr_0.4.0        survcomp_1.40.0    
[13] prodlim_2019.11.13  survival_3.2-7      caretEnsemble_2.0.1 pROC_1.16.2         caret_6.0-86        ggplot2_3.3.3      
[19] lattice_0.20-41     data.table_1.13.6   tibble_3.0.4        dplyr_1.0.2        

loaded via a namespace (and not attached):
  [1] readxl_1.3.1         backports_1.2.0      plyr_1.8.6           splines_4.0.3        digest_0.6.27       
  [6] SuppDists_1.1-9.5    foreach_1.5.1        htmltools_0.5.0      fansi_0.4.1          magrittr_1.5        
 [11] openxlsx_4.2.3       recipes_0.1.15       modelr_0.1.8         gower_0.2.2          colorspace_2.0-0    
 [16] rvest_0.3.6          haven_2.3.1          xfun_0.19            crayon_1.3.4         jsonlite_1.7.1      
 [21] libcoin_1.0-7        zoo_1.8-8            iterators_1.0.13     glue_1.4.2           gtable_0.3.0        
 [26] ipred_0.9-9          questionr_0.7.3      car_3.0-10           kernlab_0.9-29       abind_1.4-5         
 [31] scales_1.1.1         mvtnorm_1.1-1        DBI_1.1.0            rstatix_0.6.0        miniUI_0.1.1.1      
 [36] Rcpp_1.0.5           xtable_1.8-4         Cubist_0.2.3         foreign_0.8-80       km.ci_0.5-2         
 [41] Formula_1.2-4        stats4_4.0.3         lava_1.6.8.1         httr_1.4.2           ellipsis_0.3.1      
 [46] pkgconfig_2.0.3      farver_2.0.3         nnet_7.3-14          dbplyr_2.0.0         utf8_1.1.4          
 [51] tidyselect_1.1.0     labeling_0.4.2       rlang_0.4.8          reshape2_1.4.4       later_1.1.0.1       
 [56] munsell_0.5.0        cellranger_1.1.0     tools_4.0.3          cli_2.1.0            generics_0.1.0      
 [61] broom_0.7.3          evaluate_0.14        fastmap_1.0.1        yaml_2.2.1           bootstrap_2019.6    
 [66] ModelMetrics_1.2.2.2 knitr_1.30           fs_1.5.0             zip_2.1.1            survMisc_0.5.5      
 [71] caTools_1.18.0       randomForest_4.6-14  pbapply_1.4-3        nlme_3.1-150         mime_0.9            
 [76] xml2_1.3.2           compiler_4.0.3       rstudioapi_0.12      curl_4.3             e1071_1.7-4         
 [81] ggsignif_0.6.0       reprex_0.3.0         klaR_0.6-15          stringi_1.5.3        highr_0.8           
 [86] Matrix_1.2-18        gbm_2.1.8            ggsci_2.9            survivalROC_1.0.3    KMsurv_0.1-5        
 [91] vctrs_0.3.4          pillar_1.4.6         lifecycle_0.2.0      combinat_0.0-8       cowplot_1.1.1       
 [96] bitops_1.0-6         httpuv_1.5.4         R6_2.5.0             promises_1.1.1       KernSmooth_2.23-18  
[101] gridExtra_2.3        C50_0.1.3.1          rio_0.5.16           codetools_0.2-18     MASS_7.3-53         
[106] assertthat_0.2.1     withr_2.3.0          parallel_4.0.3       hms_0.5.3            grid_4.0.3          
[111] rpart_4.1-15         labelled_2.7.0       timeDate_3043.102    class_7.3-17         rmarkdown_2.5       
[116] inum_1.0-1           carData_3.0-4        partykit_1.2-11      shiny_1.5.0          lubridate_1.7.9     
[121] rmeta_3.0 

参考

  1. 在机器学习中AUC和accuracy有什么内在关系?
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
禁止转载,如需转载请通过简信或评论联系作者。
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,384评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,845评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,148评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,640评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,731评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,712评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,703评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,473评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,915评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,227评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,384评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,063评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,706评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,302评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,531评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,321评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,248评论 2 352

推荐阅读更多精彩内容