#1. 安装
> install.packages("yardstick")
#2. 调用包
> library(yardstick)
> library(dplyr)
> library(ggplot2)
#3. ROC 分析
##3.1 示例数据
> data("hpc_cv")
> hpc_cv <- as_tibble(hpc_cv)
> hpc_cv
#> # A tibble: 3,467 x 7
#> obs pred VF F M L Resample
#> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr>
#> 1 VF VF 0.914 0.0779 0.00848 0.0000199 Fold01
#> 2 VF VF 0.938 0.0571 0.00482 0.0000101 Fold01
#> 3 VF VF 0.947 0.0495 0.00316 0.00000500 Fold01
#> 4 VF VF 0.929 0.0653 0.00579 0.0000156 Fold01
#> 5 VF VF 0.942 0.0543 0.00381 0.00000729 Fold01
#> 6 VF VF 0.951 0.0462 0.00272 0.00000384 Fold01
#> 7 VF VF 0.914 0.0782 0.00767 0.0000354 Fold01
#> 8 VF VF 0.918 0.0744 0.00726 0.0000157 Fold01
#> 9 VF VF 0.843 0.128 0.0296 0.000192 Fold01
#> 10 VF VF 0.920 0.0728 0.00703 0.0000147 Fold01
#> # … with 3,457 more rows
##3.2 roc_auc()计算 ROC AUC
hpc_cv %>%
group_by(Resample) %>%
roc_auc(obs, VF:L)
#> # A tibble: 10 x 4
#> Resample .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 Fold01 roc_auc hand_till 0.831
#> 2 Fold02 roc_auc hand_till 0.817
#> 3 Fold03 roc_auc hand_till 0.869
#> 4 Fold04 roc_auc hand_till 0.849
#> 5 Fold05 roc_auc hand_till 0.811
#> 6 Fold06 roc_auc hand_till 0.836
#> 7 Fold07 roc_auc hand_till 0.825
#> 8 Fold08 roc_auc hand_till 0.846
#> 9 Fold09 roc_auc hand_till 0.836
#> 10 Fold10 roc_auc hand_till 0.820
##3.3 roc_curve() 获取ROC曲线
hpc_cv %>%
group_by(Resample) %>%
roc_curve(obs, VF:L) %>%
autoplot()
#4. precision recall curve
##4.1 示例数据
> head(two_class_example)
truth Class1 Class2 predicted
1 Class2 0.003589243 0.9964107574 Class2
2 Class1 0.678621054 0.3213789460 Class1
3 Class2 0.110893522 0.8891064779 Class2
4 Class1 0.735161703 0.2648382969 Class1
5 Class2 0.016239960 0.9837600397 Class2
6 Class1 0.999275071 0.0007249286 Class1
##4.2 pr_curve() 计算precision recall curve
# Two class - a tibble is returned
pr_curve(two_class_example, truth, Class1)
#> # A tibble: 501 x 3
#> .threshold recall precision
#> <dbl> <dbl> <dbl>
#> 1 Inf 0 NA
#> 2 1.000 0.00388 1
#> 3 1.000 0.00775 1
#> 4 1.000 0.0116 1
#> 5 1.000 0.0155 1
#> 6 1.000 0.0194 1
#> 7 1.000 0.0233 1
#> 8 1.000 0.0271 1
#> 9 1.000 0.0310 1
#> 10 1.000 0.0349 1
#> # … with 491 more rows
##4.3 使用ggplot2 可视化precision recall curve
pr_curve(two_class_example, truth, Class1) %>%
ggplot(aes(x = recall, y = precision)) +
geom_path() +
coord_equal() +
theme_bw()
##4.4 使用autoplot()可视化precision recall curve
> autoplot(pr_curve(two_class_example, truth, Class1))
##4.5 多个水平的precision recall curve
- obs下每个水平的precision recall curve
#obs 下多个水平
hpc_cv %>%
filter(Resample == "Fold01") %>%
count(obs)
# A tibble: 4 x 2
obs n
<fct> <int>
1 VF 177
2 F 108
3 M 41
4 L 21
#作图
hpc_cv %>%
filter(Resample == "Fold01") %>%
pr_curve(obs, VF:L) %>%
autoplot()
- 依据Resample下各个水平分组,obs下每个水平precision recall curve
hpc_cv %>%
group_by(Resample) %>%
pr_curve(obs, VF:L) %>%
autoplot()
One curve per level group_by.png
#5 参考
yardstick package | R Documentation
系列文章:
使用yardstick包进行ROC分析
使用 pROC包进行ROC分析