ROC计算流程简述与实现

计算流程

输入shape都是N的preds和targets

一、获取分数的降序索引desc_score_indices:

desc_score_indices = torch.argsort(preds, descending=True)

二、使用分数的降序索引 desc_score_indices 获取降序排序后的 preds 和 targets :

preds = preds[desc_score_indices]
targets = targets[desc_score_indices]

三、使用降序排序的 preds ,获取分数下降位置的索引distinct_value_indices:

distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]

四、使用在 distinct_value_indices 末尾添加 N-1,构建阈值索引 threshold_idxs :

threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)

五、获取各个阈值对应的 tp:
累加 targets,使用 threshold_idxs 获取各个阈值对应的 tp。

tps = torch.cumsum(targets, dim=0)[threshold_idxs]

六、获取各个阈值对应的fp:
方法1: 累加 1 - targets,使用 threshold_idxs 获取各个阈值对应的 fp。

fps = torch.cumsum((1 - targets), dim=0)[threshold_idxs]

方法2: 通过 threshold_idxs 与 tps 计算,因为 threshold_idxs + 1 可以表示有多少样本大于等于对应阈值,即有多少个样本被判断为正样本。

fps = 1 + threshold_idxs - tps

七、获取阈值列表 thresholds:

thresholds = preds[threshold_idxs]

八、增加额外的阈值位,保证ROC曲线从(0,0)开始:

tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

九、判断 tps 与 fps 是否有效,并计算 tpr 与 fpr:

if tps[-1] <= 0:
    raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]

if fps[-1] <= 0:
    raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]

整体实现

使用 pytorch 实现,参考 torchmetrics,源代码中包含 pos_label 和 sample_weights 参数,这里没有使用只是简单实现。

import torch
from torch import Tensor
from torch.nn import functional as F

def roc_compute_single_class(preds, targets):

    #获取分数的降序索引desc_score_indices
    desc_score_indices = torch.argsort(preds, descending=True)

    #使用分数的降序索引 desc_score_indices 获取降序排序后的 preds 和 targets 
    preds = preds[desc_score_indices]
    targets = targest[desc_score_indices]

    #使用降序排序的 preds ,获取分数下降位置的索引distinct_value_indices
    distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]

    #使用在 distinct_value_indices 末尾添加 N-1,构建阈值索引 threshold_idxs 
    threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)

    #获取各个阈值对应的 tp
    tps = torch.cumsum(targets, dim=0)[threshold_idxs]

    #获取各个阈值对应的fp
    fps = 1 + threshold_idxs - tps

    #获取阈值列表 thresholds
    thresholds = preds[threshold_idxs]

    #增加额外的阈值位,保证ROC曲线从(0,0)开始
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

    #判断 tps 与 fps 是否有效,并计算 tpr 与 fpr
    if fps[-1] <= 0:
        raise ValueError("No negative samples in targets, false positive value should be meaningless")
    fpr = fps / fps[-1]

    if tps[-1] <= 0:
        raise ValueError("No positive samples in targets, true positive value should be meaningless")
    tpr = tps / tps[-1]

    return fpr, tpr, thresholds
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容