样本不均衡-Focal loss,GHM

Ref:

  1. https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf
  2. https://zhuanlan.zhihu.com/p/80594704
  3. https://arxiv.org/pdf/1811.05181.pdf

背景

工作中处理二分类问题,数据大多是长尾分布,即正样本远小于负样本。一般来说,通过调整阈值(置信度),就可以满足上线需求。但总是有一些正样本,得分较低,希望找到一些办法,提高这些得分很低的正例分数,且负样本得分不被拉高太多。

模型通过梯度更新进行训练,实际应用中,大部分的样本是容易区分的,而这些样本贡献了主要的loss,模型偏向于这些样本,在部分难区分的样本上效果不好。

所以,为提高模型效果,要解决两个问题:

  1. 如何处理样本不均衡问题?
  2. 如何有效处理{正难,负难}的样本?

Focal Loss

主要应用在目标检测,实际应用范围很广。
分类问题中,常见的loss是cross-entropy:
L_{CE} = \begin{cases} -log(p), & y = 1 \\ -log(1 - p), & y = otherwise \end{cases}

为了解决正负样本不均衡,乘以权重\alpha
L_{FL} = \begin{cases}-\alpha log(p), & y = 1 \\ -(1-\alpha)log(1 - p), & y = 0 \end{cases}

一般根据各类别数据占比,对\alpha进行取值,即当class_1占比为30%时,\alpha = 0.3

我们希望模型能更关注容易错分的数据,反向思考,就是让模型别那么关注容易分类的样本。因此,Focal Loss的思路就是,把高置信度的样本损失降低
L_{FL} = \begin{cases} -\alpha(1-p)^{\gamma} log(p), & y = 1 \\ -(1-\alpha)p^{\gamma} log(1 - p), & y = 0\\ \end{cases}

多分类样本:
L_{FL} = -\alpha(1-p)^{\gamma}log(p)

\gamma不同取值情况如下图:

from paper

模型是如何通过(1-p)^{\gamma}控制损失的衰减的呢?

当样本被误分类时,p很小,(1-p)^{\gamma}很大,loss不怎么受影响。当样本被正确分类,p很大,(1-p)^{\gamma}变小,loss衰减。
比如:当\alpha = 1\gamma=2,p为0.9时,L_{FL} = -(1-0.9)^2 * log(0.9) = 0.01*L_{CE},这个容易分类的样本,损失和cross-entropy相比,衰减了100倍。

代码

# 二分类
class BCEFocalLoss(torch.nn.Module):
    """
    https://github.com/louis-she/focal-loss.pytorch/blob/master/focal_loss.py
    二分类的Focalloss alpha 固定
    """
    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
 
    def forward(self, preds, targets):
        "preds:[B,C],targets:[B]"
        pt = torch.sigmoid(preds)
        pt = pt.clamp(min=0.0001,max = 1.0) # 概率过低,logpt后,loss返回nan
        # 我在gpu上使用时,不加.to(targets.device),报错
        targets = torch.zeros(targets.size(0),2).to(targets.device).scatter_(1,targets.view(-1,1),1) 
        loss = - self.alpha * (1 - pt) ** self.gamma * targets * torch.log(pt) - \
               (1 - self.alpha) * pt ** self.gamma * (1 - targets) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss

# 多分类
class FocalLoss(nn.Module):
    """ 
        Ref: https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py
        FL(pt) = -alpha_t(1-pt)^gamma log(pt)
        alpha: 类别权重,常数时,类别权重为:[alpha,1-alpha,1-alpha,...];列表时,表示对应类别权重
        gamma: 难易分类的样本权重,使得模型更关注难分类的样本
        优点:帮助区分难分类的不均衡样本数据
    """
    def __init__(self, num_classes, alpha=0.25,gamma=2,reduce=True):

        super(FocalLoss,self).__init__()

        self.num_classes = num_classes
        self.gamma = gamma
        self.reduce = reduce 

        if alpha is None:
            self.alpha = torch.ones(self.num_classes,1)
        else:
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] = alpha 
            self.alpha[1:] += (1-alpha)
    
    def forward(self,preds,targets):
        "preds:[B,C],targets:[B]"
        preds = preds.view(-1,preds.size(-1)) #[B,C]
        self.alpha = self.alpha.to(preds.device)
        logpt = F.log_softmax(preds,dim=1) 
        pt = F.softmax(preds).clamp(min=0.0001,max=1.0) 

        logpt = logpt.gather(1,targets.view(-1,1)) # 对应类别值
        pt = pt.gather(1,targets.view(-1,1)) 
        self.alpha = self.alpha.gather(0,targets.view(-1))

        loss = -(1-pt) **self.gamma *logpt
        loss = self.alpha*loss.t()

        if self.reduce:
            return loss.mean()
        else:
            return loss.sum()

GHM - gradient harmonizing mechanism

Focal Loss对容易分类的样本进行了损失衰减,让模型更关注难分样本,并通过\alpha\gamma进行调参。

GHM提到:

  1. 有一部分难分样本就是离群点,不应该给他太多关注;
  2. 梯度密度可以直接统计得到,不需要调参。

GHM认为,类别不均衡可总结为难易分类样本的不均衡,而这种难分样本的不均衡又可视为梯度密度分布的不均衡。假设一个正样本被正确分类,它就是正易样本,损失不大,模型不能从中获益。而一个错误分类的样本,更能促进模型迭代。实际应用中,大量的样本都是属于容易分类的类型,这种样本一个起不了太大作用,但量级过大,在模型进行梯度更新时,起主要作用,使得模型朝这类数据更新。

from paper
  • 图示左,样本梯度分布。
    梯度模长(gradient norm)在很小和很大时,密度较大。前者,表示了大量容易分类的样本,所以梯度很低。而后者,文中认为是离群点,即便模型收敛,损失仍然很大。
  • 图示中,经过修正后的梯度分布。
    和CE,FL相比,GHM-C根据梯度密度,大量容易分类的样本和离群点的累计梯度被降级,达到样本均衡,使得模型更加有效稳定。
  • 图示右,样本集梯度贡献。
    经过GHM-C的梯度密度调整,各种难易分类的样本分布更加平滑。

简而言之:Focal Loss是从置信度p来调整loss,GHM通过一定范围置信度p的样本数来调整loss。

梯度模长

梯度模长:原文中用p^*表示真实标签,这里统一符号,用y表示:
g = |p-y|= \begin{cases} 1-p, & y = 1 \\ p, & y = 0\\ \end{cases}

推理:
p = sigmoid(x)
\frac { \partial p}{ \partial x} = p(1-p)
\frac { \partial L_{CE}}{ \partial p} = \begin{cases} -\frac {\partial logp}{\partial p}= -\frac{1}{p} , & y = 1 \\ -\frac {\partial log(1-p)}{\partial p}= \frac{1}{1 - p} , &y = 0 \end{cases}
则:
\frac {\partial L_{CE}}{\partial x} = \frac {\partial L_{CE}}{\partial p} \frac {\partial p}{\partial x} = \begin{cases} p-1 , & y = 1 \\ p, & y = 0 \end{cases} = p-y

g = |p-y| = |\frac {\partial L_{CE}}{\partial x} |

梯度密度(Gradient Density)

梯度模长分布不均,引入梯度密度:
GD(g)=\frac{1}{l_{ \epsilon} (g)} \sum_k^N \delta_{ \epsilon}(g_k,g)

在N个样本中,梯度模长分布在(g-\epsilon/2,g+\epsilon/2)范围的个数:
\delta_{ \epsilon}(x,y) = \begin{cases} 1, if&y-\frac{\epsilon} {2} \leq x <y + \frac{\epsilon} {2}\\ 0, &otherwise \end{cases}
区间长度: l_{ \epsilon} (g) = min(g+\epsilon/2,1) - max(g-\epsilon/2,0)
梯度密度协调参数:\beta_i = \frac {N}{GD(g_i)} = \frac {1}{GD(g_i)/N}
上式分母,可视为对g_i附近样本进行归一化。如果梯度分布均匀,则\beta_i = 1,如果密度过高,则意味着要降级处理。

GHM loss计算

L_{GHM-C} = \frac{1}{N}\sum_i^N \beta_i{L_{CE}(p_i,y_i)} = \sum_i^N \frac{L_{CE}(p_i,y_i)}{GD(g_i)}

代码

def _expand_binary_labels(labels,label_weights,label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels),0)
    inds = torch.nonzero(labels>=1).squeeze()
    if inds.numel() >0:
        bin_labels[inds,labels[inds]] = 1
    bin_label_weights = label_weights.view(-1,1).expand(label_weights.size(0),label_channels)
    return bin_labels, bin_label_weights
class GHMC(nn.Module):
    """GHM Classification Loss.
    Ref:https://github.com/libuyu/mmdetection/blob/master/mmdet/models/losses/ghm_loss.py
    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector".
    https://arxiv.org/abs/1811.05181

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.
    """

    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0,alpha=None):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] += 1e-6
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight

        self.label_weight = alpha

    def forward(self, pred, target, label_weight =None, *args, **kwargs):
        """Calculate the GHM-C loss.
          
        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
        Returns:
            The gradient harmonized loss.
        """
        # the target should be binary class label

        # if pred.dim() != target.dim():
        #     target, label_weight = _expand_binary_labels(
        #     target, label_weight, pred.size(-1))

        # 我的pred输入为[B,C],target输入为[B]
        target = torch.zeros(target.size(0),2).to(target.device).scatter_(1,target.view(-1,1),1)
        
        # 暂时不清楚这个label_weight输入形式,默认都为1
        if label_weight is None:
            label_weight = torch.ones([pred.size(0),pred.size(-1)]).to(target.device)

        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length
        # sigmoid梯度计算
        g = torch.abs(pred.sigmoid().detach() - target)
        # 有效的label的位置
        valid = label_weight > 0
        # 有效的label的数量
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            # 将对应的梯度值划分到对应的bin中, 0-1
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            # 该bin中存在多少个样本
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # moment计算num bin
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    # 权重等于总数/num bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            # scale系数
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,039评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,223评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,916评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,009评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,030评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,011评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,934评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,754评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,202评论 1 309
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,433评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,590评论 1 346
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,321评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,917评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,568评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,738评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,583评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,482评论 2 352

推荐阅读更多精彩内容