R-多分类logistic回归(机器学习)

多分类logistic回归

在之前文章介绍了,如何在R里面处理多分类的回归模型,得到的是各个因素的系数及相对OR,但是解释性,比二元logistic回归方程要冗杂的多。

那么今天继续前面的基础上,用机器学习的方法来解释多分类问题。
其实最终回归到这类分类问题的本质:有了一系列的影响因素x,那么根据这些影响因素来判断最终y属于哪一类别。


image.png

1.数据案例

这里主要用到DALEX包里面包含的HR数据,里面记录了职工在工作岗位的状态与年龄,性别,工作时长,评价及薪水有关。根据7847条记录来评估,如果一个职工属于男性,68岁,薪水及评价处于3等级,那么该职工可能会处于什么状态。

library(DALEX)
library(iBreakDown)
library(car)
library(questionr)
try(data(package="DALEX"))
data(HR)

# split
set.seed(543)
ind = sample(2,nrow(HR),replace=TRUE,prob=c(0.9,0.1))
trainData = HR[ind==1,]
testData = HR[ind==2,]

# randforest
m_rf = randomForest(status ~ . , data = trainData)

2.随机森林模型

我们根据上述数据,分成训练集与测试集(Train and Test)测试集用来估计随机森林模型的效果。

2.1模型评估

通过对Train数据构建rf模型后,我们对Train数据进行拟合,看一下模型的效果,Accuracy : 0.9357 显示很好,kappa一致性为90%。
那再用该fit去预测test数据, Accuracy : 0.7166 , Kappa : 56% ,显示效果不怎么理想。

# Prediction and Confusion Matrix - Training data 
pred1 <- predict(m_rf, trainData)
head(pred1)
confusionMatrix(pred1, trainData$status)  #

pred2 <- predict(m_rf, testData)
head(pred2)
confusionMatrix(pred2, testData$status)  #

> confusionMatrix(pred1, trainData$status)  #
Confusion Matrix and Statistics

          Reference
Prediction fired   ok promoted
  fired     2478  194       49
  ok          43 1738       80
  promoted    25   64     2375

Overall Statistics
                                          
               Accuracy : 0.9354          
                 95% CI : (0.9294, 0.9411)
    No Information Rate : 0.3613          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.9024          
                                          
 Mcnemar's Test P-Value : < 2.2e-16       

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.9733    0.8707          0.9485
Specificity                0.9460    0.9756          0.9804
Pos Pred Value             0.9107    0.9339          0.9639
Neg Pred Value             0.9843    0.9502          0.9718
Prevalence                 0.3613    0.2833          0.3554
Detection Rate             0.3517    0.2467          0.3371
Detection Prevalence       0.3862    0.2641          0.3497
Balanced Accuracy          0.9596    0.9232          0.9644
> 
> pred2 <- predict(m_rf, testData)
> head(pred2)
    1    20    36    42    49    56 
fired fired fired fired fired    ok 
Levels: fired ok promoted
> confusionMatrix(pred2, testData$status)  #
Confusion Matrix and Statistics

          Reference
Prediction fired  ok promoted
  fired      246  62       19
  ok          37 117       37
  promoted    26  46      211

Overall Statistics
                                         
               Accuracy : 0.7166         
                 95% CI : (0.684, 0.7476)
    No Information Rate : 0.3858         
    P-Value [Acc > NIR] : < 2e-16        
                                         
                  Kappa : 0.5692         
                                         
 Mcnemar's Test P-Value : 0.03881        

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.7961    0.5200          0.7903
Specificity                0.8354    0.8715          0.8652
Pos Pred Value             0.7523    0.6126          0.7456
Neg Pred Value             0.8671    0.8230          0.8919
Prevalence                 0.3858    0.2809          0.3333
Detection Rate             0.3071    0.1461          0.2634
Detection Prevalence       0.4082    0.2385          0.3533
Balanced Accuracy          0.8157    0.6958          0.8277

2.2变量重要性

我们看到,对影响因素进行重要性排序,等同于P值。在预测时候,哪些因素对y占影响比重较大。这里的variable_importance(),可以有好几种方式对变量进行衡量,这里采用默认的MeanDecreaseGini.

# vip
vip(m_rf)
var=randomForest::importance(m_rf)
var
image.png

2.2边际效应

我们知道了hours,age比较重要,那么是如何重要的,譬如年龄在什么阶段,会导致升职或者开除。
当工作小时在45以内,被开除/离职的概率较大,当工作时常超过60以后,很有可能会被提升。得到升职加薪的机会。
当然了,也可以绘制2D的边际效应,两个因素相互作用的Partial plot

# partial plot
partialPlot(m_rf, HR, age)
head(partial(m_rf, pred.var = "age"))  # returns a data frame

# for all varibles
nm=rownames(var)
# Get partial depedence values for top predictors
pd_df <- partial_dependence(fit = m_rf,
                            vars = nm,
                            data = df_rf,
                            n = c(100, 200))
                        
# Plot partial dependence using edarf
plot_pd(pd_df)
image.png
image.png

2.3个体预测

现在假如有一个员工的信息如下,

      gender      age    hours evaluation salary   status
10000 female 57.96254 54.78624          4      4 promoted

去预测该职工最后的状态:
该预测结果显示,这个职工,有97%的可能性要升职加薪。而他的实际状态也是Promoted。

new_observation=tail(HR,1)
p_fun <- function(object, newdata){predict(object, newdata = newdata, type = "prob")}
bd_rf <- local_attributions(m_rf,
                            data = HR_test,
                            new_observation =  new_observation,
                            predict_function = p_fun)

bd_rf
plot(bd_rf)
image.png
> sessionInfo()
R version 3.6.2 (2019-12-12)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Mojave 10.14

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  utils     datasets  grDevices methods   base     

other attached packages:
 [1] edarf_1.1.1         ranger_0.12.1       questionr_0.7.0     car_3.0-7          
 [5] carData_3.0-3       nnet_7.3-14         DALEX_1.2.1         vip_0.2.2          
 [9] ggpubr_0.3.0        rstatix_0.5.0       caret_6.0-86        lattice_0.20-41    
[13] pdp_0.7.0           randomForest_4.6-14 iBreakDown_1.2.0    hrbrthemes_0.8.0   
[17] reshape2_1.4.4      RColorBrewer_1.1-2  forcats_0.5.0       stringr_1.4.0      
[21] dplyr_0.8.5         purrr_0.3.4         readr_1.3.1         tidyr_1.0.3        
[25] tibble_3.0.1        ggplot2_3.3.0       tidyverse_1.3.0    

参考

  1. iBreakDown plots for classification models
  2. prediction 预测结果输出为概率
  3. pdp 边际效应
  4. Partial dependence (PD) plots For Random Forests
  5. Explaining Black-Box Machine Learning Models
  6. Interpretable Machine Learning
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,254评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,875评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,682评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,896评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,015评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,152评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,208评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,962评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,388评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,700评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,867评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,551评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,186评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,901评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,142评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,689评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,757评论 2 351