计算流程
输入shape都是N的preds和targets
一、获取分数的降序索引desc_score_indices:
desc_score_indices = torch.argsort(preds, descending=True)
二、使用分数的降序索引 desc_score_indices 获取降序排序后的 preds 和 targets :
preds = preds[desc_score_indices]
targets = targets[desc_score_indices]
三、使用降序排序的 preds ,获取分数下降位置的索引distinct_value_indices:
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
四、使用在 distinct_value_indices 末尾添加 N-1,构建阈值索引 threshold_idxs :
threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)
五、获取各个阈值对应的 tp:
累加 targets,使用 threshold_idxs 获取各个阈值对应的 tp。
tps = torch.cumsum(targets, dim=0)[threshold_idxs]
六、获取各个阈值对应的fp:
方法1: 累加 1 - targets,使用 threshold_idxs 获取各个阈值对应的 fp。
fps = torch.cumsum((1 - targets), dim=0)[threshold_idxs]
方法2: 通过 threshold_idxs 与 tps 计算,因为 threshold_idxs + 1 可以表示有多少样本大于等于对应阈值,即有多少个样本被判断为正样本。
fps = 1 + threshold_idxs - tps
七、获取阈值列表 thresholds:
thresholds = preds[threshold_idxs]
八、增加额外的阈值位,保证ROC曲线从(0,0)开始:
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
九、判断 tps 与 fps 是否有效,并计算 tpr 与 fpr:
if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]
if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]
整体实现
使用 pytorch 实现,参考 torchmetrics,源代码中包含 pos_label 和 sample_weights 参数,这里没有使用只是简单实现。
import torch
from torch import Tensor
from torch.nn import functional as F
def roc_compute_single_class(preds, targets):
#获取分数的降序索引desc_score_indices
desc_score_indices = torch.argsort(preds, descending=True)
#使用分数的降序索引 desc_score_indices 获取降序排序后的 preds 和 targets
preds = preds[desc_score_indices]
targets = targest[desc_score_indices]
#使用降序排序的 preds ,获取分数下降位置的索引distinct_value_indices
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
#使用在 distinct_value_indices 末尾添加 N-1,构建阈值索引 threshold_idxs
threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)
#获取各个阈值对应的 tp
tps = torch.cumsum(targets, dim=0)[threshold_idxs]
#获取各个阈值对应的fp
fps = 1 + threshold_idxs - tps
#获取阈值列表 thresholds
thresholds = preds[threshold_idxs]
#增加额外的阈值位,保证ROC曲线从(0,0)开始
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
#判断 tps 与 fps 是否有效,并计算 tpr 与 fpr
if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]
if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]
return fpr, tpr, thresholds