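We are all familiar with logistic regression: the outcome y is binary, and to evaluate the predictions we use a 2×2 confusion matrix, from which we compute sensitivity, specificity and the ROC curve to judge how accurately the model predicts.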
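As a reminder of the binary case, here is a minimal base-R sketch of sensitivity and specificity computed from a 2×2 confusion matrix. The counts are made up purely for illustration, and the "pos"/"neg" labels are hypothetical.

```r
# Hypothetical 2x2 confusion matrix (made-up counts, for illustration only)
# rows = actual class, columns = predicted class
cm2 <- matrix(c(40,  5,
                10, 45),
              nrow = 2, byrow = TRUE,
              dimnames = list(actual = c("pos", "neg"),
                              predicted = c("pos", "neg")))

TP <- cm2["pos", "pos"]; FN <- cm2["pos", "neg"]
FP <- cm2["neg", "pos"]; TN <- cm2["neg", "neg"]

sensitivity <- TP / (TP + FN)        # share of actual positives predicted positive
specificity <- TN / (TN + FP)        # share of actual negatives predicted negative
accuracy    <- (TP + TN) / sum(cm2)  # share of all cases classified correctly
c(sensitivity = sensitivity, specificity = specificity, accuracy = accuracy)
```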
But if y has three classes, we get a 3×3 matrix instead.
Which metrics should we use to evaluate it?
Answer: macro-average and micro-average.
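For reference, with K classes and one-vs-all counts TP_k, FP_k and FN_k for class k, the two averaging schemes can be written as below. Macro F1 is taken here as the mean of the per-class F1 scores, which is how the code in Section 3 computes it; some sources instead define it from macro precision and macro recall.

```latex
\begin{aligned}
\text{Precision}_k &= \frac{TP_k}{TP_k + FP_k}, \quad
\text{Recall}_k = \frac{TP_k}{TP_k + FN_k}, \quad
F1_k = \frac{2\,\text{Precision}_k\,\text{Recall}_k}{\text{Precision}_k + \text{Recall}_k} \\
\text{Precision}_{\text{macro}} &= \frac{1}{K}\sum_{k=1}^{K} \text{Precision}_k, \quad
\text{Recall}_{\text{macro}} = \frac{1}{K}\sum_{k=1}^{K} \text{Recall}_k, \quad
F1_{\text{macro}} = \frac{1}{K}\sum_{k=1}^{K} F1_k \\
\text{Precision}_{\text{micro}} &= \frac{\sum_{k} TP_k}{\sum_{k} (TP_k + FP_k)}, \quad
\text{Recall}_{\text{micro}} = \frac{\sum_{k} TP_k}{\sum_{k} (TP_k + FN_k)}
\end{aligned}
```

When every instance receives exactly one predicted label, both denominators of the micro averages equal the number of instances n, so micro precision, micro recall and micro F1 all reduce to overall accuracy.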
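Next, we show how to build a model that predicts a three-class outcome and how to evaluate its accuracy.
1. Model building
Using the three-class variable Species in the iris dataset, we fit a multinomial logistic regression model that predicts the species from the flower measurements. We also add a new variable, xv (a random Yes/No factor).
First, we split the data (iris plus xv) into a Training set (70%) and a Testing set (30%); the Training set is used to build the model.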
# https://stackoverflow.com/questions/59205776/random-forest-svm-and-multinomial-logistic-regression-with-r
library(tidyverse)
library(randomForest)
set.seed(123)
head(iris)
# new predictor: a random (uninformative) Yes/No factor xv
df=iris %>% mutate(xv=as.factor(ifelse(rnorm(150,3,4)<3,"Yes","No")))
## split data: 0 = training (70%), 1 = testing (30%)
split1= sample(c(rep(0, 0.7 * nrow(df)), rep(1, 0.3 * nrow(df))))
train <- df[split1 == 0, ]
test <- df[split1 == 1, ]
## Model: multinomial logistic regression
library(nnet)
fit1 = multinom(Species~.,data=train)   # Species ~ all other columns, including xv
summary(fit1)
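Interpreting fit1 is similar to interpreting a binary logistic regression, except that the output contains one more set of coefficients: each non-reference class (versicolor and virginica) gets its own equation, compared against the reference class setosa. As in the binary case, the exponentiated coefficients are read as odds ratios (OR).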
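To make the odds-ratio reading concrete, the sketch below refits a multinomial model on the full iris data and exponentiates the coefficients. To keep the block self-contained it assumes the full dataset and drops the random xv column, so the numbers will not match fit1. Each row of exp(coef(fit)) compares one class with the reference class setosa. Because setosa is linearly separable from the other two species, the estimates are very large and unstable, so read them as showing the mechanics rather than meaningful effect sizes.

```r
library(nnet)  # multinom()

# Refit on all 150 rows of iris (the random xv column from the text is omitted)
fit <- multinom(Species ~ ., data = iris, trace = FALSE)

# coef() has one row per non-reference class (versicolor, virginica);
# the first factor level, setosa, is the reference category.
b  <- coef(fit)

# exp(b): multiplicative change in the odds of each class vs. setosa
# for a one-unit increase in the corresponding predictor
or <- exp(b)
round(or, 3)
```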
2. Observed vs. predicted: the confusion matrix
After building fit1, we predict on the testing data and then create a matrix of actual versus predicted values.
## Model Prediction
pre=predict(fit1,test)
dfpre=tibble(actual=test$Species,predicted=pre)
table(dfpre)
            predicted
actual       setosa versicolor virginica
  setosa         13          0         0
  versicolor      0         13         0
  virginica       0          1        18
3. Performance Measures
Next we analyze this matrix. We first compute some basic quantities from it, which are then used to obtain accuracy, precision, recall, F1 and so on.
Source: https://www.r-bloggers.com/2016/03/computing-classification-evaluation-metrics-in-r/
## basic variables
cm = as.matrix(table(dfpre)) # confusion matrix from Section 2: rows = actual, columns = predicted
n = sum(cm) # number of instances
nc = nrow(cm) # number of classes
diag = diag(cm) # number of correctly classified instances per class
rowsums = apply(cm, 1, sum) # number of instances per class
colsums = apply(cm, 2, sum) # number of predictions per class
p = rowsums / n # distribution of instances over the actual classes
q = colsums / n # distribution of instances over the predicted classes
## Accuracy
accuracy = sum(diag) / n
accuracy
## Per-class precision, recall and F1
precision = diag / colsums
recall = diag / rowsums
f1 = 2 * precision * recall / (precision + recall)
data.frame(precision, recall, f1)
## Macro
macroPrecision = mean(precision)
macroRecall = mean(recall)
macroF1 = mean(f1)
data.frame(macroPrecision, macroRecall, macroF1)
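The code above only computes the macro averages. A minimal base-R sketch of the micro averages follows: the one-vs-all counts are pooled over all classes before the metric is computed once. The confusion matrix is typed in from Section 2 so the block runs on its own.

```r
# Confusion matrix from Section 2 (rows = actual, columns = predicted)
cls <- c("setosa", "versicolor", "virginica")
cm <- matrix(c(13,  0,  0,
                0, 13,  0,
                0,  1, 18),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual = cls, predicted = cls))

# One-vs-all counts for every class k
tp <- diag(cm)
fp <- colSums(cm) - tp   # predicted as k, but actually another class
fn <- rowSums(cm) - tp   # actually k, but predicted as another class

# Micro-average: sum the counts over classes, then compute each metric once
microPrecision <- sum(tp) / sum(tp + fp)
microRecall    <- sum(tp) / sum(tp + fn)
microF1        <- 2 * microPrecision * microRecall / (microPrecision + microRecall)
data.frame(microPrecision, microRecall, microF1)
```

All three values equal the overall accuracy (44/45), because each test case has exactly one actual and one predicted label.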
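The manual computation above is rather tedious. Is there a function that produces everything in one call? Yes, as shown next.
3.1 Performance measures in one call
Here we use the Evaluate function. Its code is available at the link below (or on request by private message). Source: https://github.com/saidbleik/Evaluation/blob/master/eval.R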
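source("eval.R") # local copy of eval.R downloaded from the link above (file name assumed)
results = Evaluate(actual=dfpre$actual, predicted=dfpre$predicted)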
results
## output
$ConfusionMatrix
            Predicted
Actual       setosa versicolor virginica
  setosa         13          0         0
  versicolor      0         13         0
  virginica       0          1        18

$Metrics
                                 setosa versicolor  virginica
Accuracy                      0.9777778  0.9777778  0.9777778
Precision                     1.0000000  0.9285714  1.0000000
Recall                        1.0000000  1.0000000  0.9473684
F1                            1.0000000  0.9629630  0.9729730
MacroAvgPrecision             0.9761905  0.9761905  0.9761905
MacroAvgRecall                0.9824561  0.9824561  0.9824561
MacroAvgF1                    0.9786453  0.9786453  0.9786453
AvgAccuracy                   0.9851852  0.9851852  0.9851852
MicroAvgPrecision             0.9777778  0.9777778  0.9777778
MicroAvgRecall                0.9777778  0.9777778  0.9777778
MicroAvgF1                    0.9777778  0.9777778  0.9777778
MajorityClassAccuracy         0.4222222  0.4222222  0.4222222
MajorityClassPrecision        0.0000000  0.0000000  0.4222222
MajorityClassRecall           0.0000000  0.0000000  1.0000000
MajorityClassF1               0.0000000  0.0000000  0.5937500
Kappa                         0.9662162  0.9662162  0.9662162
RandomGuessAccuracy           0.3333333  0.3333333  0.3333333
RandomGuessPrecision          0.2888889  0.2888889  0.4222222
RandomGuessRecall             0.3333333  0.3333333  0.3333333
RandomGuessF1                 0.3095238  0.3095238  0.3725490
RandomWeightedGuessAccuracy   0.3451852  0.3451852  0.3451852
RandomWeightedGuessPrecision  0.2888889  0.2888889  0.4222222
RandomWeightedGuessRecall     0.2888889  0.2888889  0.4222222
RandomWeightedGuessF1         0.2888889  0.2888889  0.4222222
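Two of the summary rows, Kappa and AvgAccuracy, follow directly from the p and q vectors that the manual code in Section 3 computes but never uses. The base-R sketch below reproduces both values from the same confusion matrix: Cohen's kappa corrects the observed agreement for the agreement expected by chance (the sum of p times q), and AvgAccuracy averages the one-vs-all accuracies of the three classes.

```r
# Same confusion matrix as in Section 2 (rows = actual, columns = predicted)
cm <- matrix(c(13,  0,  0,
                0, 13,  0,
                0,  1, 18), nrow = 3, byrow = TRUE)

n <- sum(cm)
p <- rowSums(cm) / n   # distribution over the actual classes
q <- colSums(cm) / n   # distribution over the predicted classes

# Cohen's kappa: observed agreement po vs. chance agreement pe
po <- sum(diag(cm)) / n
pe <- sum(p * q)
kappaStat <- (po - pe) / (1 - pe)   # named kappaStat to avoid masking base::kappa()

# Average one-vs-all accuracy: (TP_k + TN_k) / n for each class k, then the mean
tp <- diag(cm)
fp <- colSums(cm) - tp
fn <- rowSums(cm) - tp
tn <- n - tp - fp - fn
avgAccuracy <- mean((tp + tn) / n)

c(Kappa = kappaStat, AvgAccuracy = avgAccuracy)
```

This gives Kappa = 1287/1332 ≈ 0.9662162 and AvgAccuracy = 133/135 ≈ 0.9851852, matching the Evaluate output above.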
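4. ROC Curves Across Multi-Class Classifications
We can also plot the micro-average and macro-average ROC curves, which indicate the overall discriminative ability of the three-class classifier. This takes a few steps:
- Our earlier predictions were the predicted Species labels; here we need the predicted probability of each class instead.
- Dummy coding: convert the Species column of the testing set into 0/1 dummy variables, one per class.
- Compute the macro/micro averages and plot the ROC curves.
Source: https://mran.microsoft.com/snapshot/2018-02-12/web/packages/multiROC/vignettes/my-vignette.html
One concept is needed here: one-vs-all (one-vs-rest) confusion matrices.
Each class is contrasted with all the others, e.g. setosa vs. non-setosa, which yields an ROC curve for setosa; repeating this for every class gives one curve per class.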
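multiROC handles the one-vs-all step internally. To show what it means in code, the base-R sketch below computes a per-class AUC with the rank (Mann-Whitney) formulation on a small made-up three-class example (classes "a", "b", "c" are hypothetical). It also computes a micro-average AUC by pooling all class-vs-rest pairs, and a macro-average AUC as the simple mean of the per-class AUCs. multiROC builds its macro curve by averaging ROC curves, so its macro AUC may differ slightly from this simple mean.

```r
# AUC = P(score of a random positive > score of a random negative); ties count 1/2
ova_auc <- function(truth01, score) {
  pos <- score[truth01 == 1]
  neg <- score[truth01 == 0]
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

# Made-up example: true classes and predicted class probabilities (rows sum to 1)
truth <- factor(c("a", "a", "b", "b", "c", "c"), levels = c("a", "b", "c"))
prob <- rbind(c(0.8, 0.1, 0.1),
              c(0.5, 0.3, 0.2),
              c(0.2, 0.7, 0.1),
              c(0.4, 0.4, 0.2),
              c(0.1, 0.2, 0.7),
              c(0.3, 0.5, 0.2))
colnames(prob) <- levels(truth)

# Dummy-code the truth: one 0/1 column per class (class k vs. the rest)
truth01 <- sapply(levels(truth), function(k) as.integer(truth == k))

per_class <- sapply(levels(truth), function(k) ova_auc(truth01[, k], prob[, k]))
macro_auc <- mean(per_class)                               # mean of per-class AUCs
micro_auc <- ova_auc(as.vector(truth01), as.vector(prob))  # pooled class-vs-rest pairs
per_class; macro_auc; micro_auc
```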
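library(multiROC)
# dummy-code the true classes: one 0/1 column per Species level
# (base R model.matrix() replaces dummies::dummy.data.frame(), as the dummies package is no longer on CRAN)
actual=model.matrix(~ Species - 1, data = test)
predicted=predict(fit1,test,type = "prob")# with probability
test_data=data.frame(actual,predicted)
# multiROC expects column names of the form <class>_true and <class>_pred_<method>
colnames(test_data)=c("setosa_true","versicolor_true" ,"virginica_true",
"setosa_pred_m1","versicolor_pred_m1","virginica_pred_m1")
res <- multi_roc(test_data, force_diag=TRUE)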
res
res contains the information we need. Next, we extract the Specificity and Sensitivity of each group from res and plot the ROC curves.
#### ggplot ROC
n_method <- length(unique(res$Methods))
n_group <- length(unique(res$Groups))
# empty data frame collecting the ROC points of every group and method
res_df <- data.frame(Specificity= numeric(0), Sensitivity= numeric(0), Group = character(0), AUC = numeric(0), Method = character(0))
for (i in 1:n_method) {
  # one-vs-all curve for each class
  for (j in 1:n_group) {
    temp_data_1 <- data.frame(Specificity=res$Specificity[[i]][j],
                              Sensitivity=res$Sensitivity[[i]][j],
                              Group=unique(res$Groups)[j],
                              AUC=res$AUC[[i]][j],
                              Method = unique(res$Methods)[i])
    colnames(temp_data_1) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
    res_df <- rbind(res_df, temp_data_1)
  }
  # element n_group + 1: macro-average curve
  temp_data_2 <- data.frame(Specificity=res$Specificity[[i]][n_group+1],
                            Sensitivity=res$Sensitivity[[i]][n_group+1],
                            Group= "Macro",
                            AUC=res$AUC[[i]][n_group+1],
                            Method = unique(res$Methods)[i])
  # element n_group + 2: micro-average curve
  temp_data_3 <- data.frame(Specificity=res$Specificity[[i]][n_group+2],
                            Sensitivity=res$Sensitivity[[i]][n_group+2],
                            Group= "Micro",
                            AUC=res$AUC[[i]][n_group+2],
                            Method = unique(res$Methods)[i])
  colnames(temp_data_2) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
  colnames(temp_data_3) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
  res_df <- rbind(res_df, temp_data_2)
  res_df <- rbind(res_df, temp_data_3)
}
ggplot(res_df, aes(x = 1-Specificity, y=Sensitivity)) +
  geom_path(aes(color = Group, linetype=Method)) +
  geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), colour='grey', linetype = 'dotdash') +  # diagonal = random classifier
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        legend.justification=c(1, 0),
        legend.position=c(.95, .05),
        legend.title=element_blank(),
        legend.background = element_rect(fill=NULL, size=0.5, linetype="solid", colour ="black"))
ggsave("ROC-multinom.pdf",width = 16,height = 12,dpi=500)  # the curves above come from the multinom model fit1
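Finally, here are the SVM and RF models for comparison.
#### 2.SVM
library(e1071)
# fitted on the same training data as fit1; probability=TRUE so class probabilities can be predicted later
fitsvm = svm(Species ~ ., data = train, probability=TRUE)
summary(fitsvm)
#### 3.RF
library(randomForest)
fitrf = randomForest(Species ~ .,
data = train,
ntree = 300, # parameter setting
mtry = 2, # must not exceed the number of predictors (5 here); 2 = floor(sqrt(5)) is the classification default
importance = TRUE,
proximity = TRUE)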
Reference:
Performance Measures for Multi-Class Problems--
https://www.datascienceblog.net/post/machine-learning/performance-measures-multi-class-problems/