一、Focal loss损失函数
Focal Loss的引入主要是为了解决**难易样本数量不平衡****(注意,有区别于正负样本数量不平衡)的问题,实际可以使用的范围非常广泛。
本文的作者认为,易分样本(即,置信度高的样本)对模型的提升效果非常小,模型应该主要关注与那些难分样本。一个简单的思想:把高置信度(p)样本的损失再降低一些不就好了吗!
focal loss函数公式:
其中,(1)为类别权重,用来权衡正负样本不均衡问题,倘若负样本越多,给负样本的
权重就越小,这样就可以降低负样本的影响。加一个小于1的超参数,相当于把Loss曲线整体往下拉一些,使得当样本概率较大的时候影响减小。;
(2) 表示难分样本权重,用来衡量难分样本和易分样本,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的gamma次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。即正样本:概率越小,表示hard example,损失越大; 负样本:概率越大,表示hard example,损失越大。γ 起到了平滑的作用,作者的实验中,论文采用α=0.25,γ=2效果最好。。针对hard example,Pt比较小,则权重比较大,让网络倾向于利用这样的样本来进行参数的更新
Focal loss缺点(腾讯面试):
(1) 对异常样本敏感: 假如训练集中有个样本label标错了,那么focal loss会一直放大这个样本的loss(模型想矫正回来),导致网络往错误方向学习。
(2)对分类边界异常点处理不理想:由于边界样本表示相似性较高,对于不同异常值表示,每次损失更新时,都会有反复在分类决策面(0.5)上反复横跳的点,导致模型收敛速度下降,退化成交叉熵损失。
二、Focal loss损失函数代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
def __init__(self, class_num, alpha=0.20, gamma=1.5, use_alpha=False, size_average=True):
super(FocalLoss, self).__init__()
self.class_num = class_num
self.alpha = alpha
self.gamma = gamma
if use_alpha:
self.alpha = torch.tensor(alpha).cuda()
# self.alpha = torch.tensor(alpha)
self.softmax = nn.Softmax(dim=1)
self.use_alpha = use_alpha
self.size_average = size_average
def forward(self, pred, target):
prob = self.softmax(pred.view(-1,self.class_num))
prob = prob.clamp(min=0.0001,max=1.0)
target_ = torch.zeros(target.size(0),self.class_num).cuda()
# target_ = torch.zeros(target.size(0),self.class_num)
target_.scatter_(1, target.view(-1, 1).long(), 1.)
if self.use_alpha:
batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
else:
batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
batch_loss = batch_loss.sum(dim=1)
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
三、Focal loss损失函数引用及使用
# 函数引用(focal_loss为模型文件名)
from focal_loss import FocalLoss
#...
# 损失函数初始化
criterion = FocalLoss(class_num=3)
#...
# 获得损失函数
loss = criterion(outputs, targets)