1. What problem does it solve?
- Class imbalance between different categories
- Imbalance between hard and easy samples
- It reduces the weight of easily classified samples and increases the loss contribution of hard samples (see the worked example below)
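For instance, with $\gamma = 2$, an easy sample whose predicted probability for the true class is 0.9 is scaled by the modulating factor $(1 - 0.9)^2 = 0.01$, while a hard sample at probability 0.3 is scaled by $(1 - 0.3)^2 = 0.49$; the easy sample therefore contributes roughly 50 times less to the loss than it would under plain cross entropy.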
2. Basic concepts
- Soft gamma: increasing gamma in stages during training may yield a further performance gain.
- alpha is related to the frequency of each class in the training data.
- F.nll_loss(torch.log(F.softmax(inputs, dim=1)), target) is functionally equivalent to F.cross_entropy(inputs, target). F.nll_loss effectively one-hot encodes the target into a tensor with the same shape as its first input and takes an element-wise product with that input (the log-probabilities), which picks out the log-probability of the target class for each sample (see the quick check below).
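A quick numerical check of this equivalence (a minimal sketch using standard PyTorch):

```python
import torch
import torch.nn.functional as F

inputs = torch.randn(4, 3)                 # 4 samples, 3 classes (random logits)
target = torch.tensor([0, 2, 1, 2])

loss_ce  = F.cross_entropy(inputs, target)
loss_nll = F.nll_loss(torch.log(F.softmax(inputs, dim=1)), target)
print(torch.allclose(loss_ce, loss_nll))   # True (up to floating-point error)
```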
3. Formulas
The standard Cross Entropy is:

$$\mathrm{CE}(p_t) = -\log(p_t)$$

Focal Loss is:

$$\mathrm{FL}(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t)$$

where

$$p_t = \begin{cases} p, & y = 1 \\ 1 - p, & \text{otherwise} \end{cases}$$

and $\alpha_t$ is defined analogously ($\alpha$ for the positive class, $1 - \alpha$ for the negative class).
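As a sanity check, setting $\gamma = 0$ and $\alpha_t = 1$ makes the modulating factor equal to 1, so Focal Loss degenerates to the standard cross entropy; a minimal PyTorch sketch of this:

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.9, 0.6, 0.2])                    # predicted probability of the positive class
y = torch.ones(3)                                    # all-positive targets, so p_t = p
ce = F.binary_cross_entropy(p, y, reduction='none')  # -log(p_t)
gamma, alpha = 0.0, 1.0
fl = alpha * (1 - p) ** gamma * ce                   # focal loss with gamma = 0, alpha = 1
print(torch.allclose(fl, ce))                        # True
```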
4. Code implementation
PyTorch implementation from Kaggle (based on binary cross entropy)
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits  # True if `inputs` are raw logits rather than probabilities
        self.reduce = reduce

    def forward(self, inputs, targets):
        # Per-element BCE (reduction='none' keeps one loss value per element)
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # probability assigned to the true class
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss  # down-weight easy examples
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
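A brief usage sketch (the shapes and values below are illustrative assumptions, not from the original Kaggle kernel):

```python
import torch

criterion = FocalLoss(alpha=1, gamma=2, logits=True)  # inputs are raw logits
inputs = torch.randn(8, requires_grad=True)           # one logit per sample
targets = torch.randint(0, 2, (8,)).float()           # binary labels in {0, 1}
loss = criterion(inputs, targets)
loss.backward()
```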
TensorFlow reference
huanghuidmml/cail2019_track2: third-place solution for the CAIL2019 (Challenge of AI in Law) element extraction task (github.com)
import tensorflow as tf
from tensorflow.python.ops import array_ops
# FLAGS (providing the default alpha/gamma) is defined elsewhere in the original repo

def focal_loss(prediction_tensor, target_tensor, alpha=FLAGS.alpha, gamma=FLAGS.gamma):
    r"""Compute focal loss for predictions.
    Multi-label focal loss formula:
        FL = -alpha * (z-p)^gamma * log(p) - (1-alpha) * p^gamma * log(1-p)
    where alpha = 0.25, gamma = 2, p = sigmoid(x), z = target_tensor.
    Args:
        prediction_tensor: A float tensor of shape [batch_size, num_anchors,
            num_classes] representing the predicted logits for each class
        target_tensor: A float tensor of shape [batch_size, num_anchors,
            num_classes] representing one-hot encoded classification targets
        alpha: A scalar tensor for the focal loss alpha hyper-parameter
        gamma: A scalar tensor for the focal loss gamma hyper-parameter
    Returns:
        loss: A (scalar) tensor representing the value of the loss function
    """
    sigmoid_p = tf.nn.sigmoid(prediction_tensor)
    zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)
    # For positive targets, only the first term contributes (the second is 0);
    # target_tensor > zeros <=> z = 1, so the positive coefficient is z - p.
    pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - sigmoid_p, zeros)
    # For negative targets, only the second term contributes (the first is 0);
    # target_tensor <= zeros <=> z = 0, so the negative coefficient is p.
    neg_p_sub = array_ops.where(target_tensor > zeros, zeros, sigmoid_p)
    per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.log(tf.clip_by_value(sigmoid_p, 1e-8, 1.0)) \
                          - (1 - alpha) * (neg_p_sub ** gamma) * tf.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0))
    return tf.reduce_sum(per_entry_cross_ent), per_entry_cross_ent, sigmoid_p
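A hypothetical call under TF 1.x graph mode, continuing from the imports above (the toy shapes and values are assumptions):

```python
# Toy inputs of shape [batch_size=2, num_anchors=1, num_classes=3]
logits = tf.constant([[[2.0, -1.0, 0.5]], [[-0.3, 1.2, 0.1]]])
labels = tf.constant([[[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
loss, per_entry, probs = focal_loss(logits, labels, alpha=0.25, gamma=2)
with tf.Session() as sess:   # TF 1.x session
    print(sess.run(loss))
```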
A PyTorch version adapted from the TensorFlow implementation
# PyTorch has no built-in clip_by_value function (torch.clamp serves the same purpose)
# The helper below follows a common online reference; the focal_loss function itself is unchanged
def clip_by_tensor(t, t_min, t_max):
    """
    clip_by_tensor
    :param t: tensor
    :param t_min: min
    :param t_max: max
    :return: clipped tensor
    """
    # t = t.float()
    # t_min = t_min.float()
    # t_max = t_max.float()
    # Keep values within [t_min, t_max] by masking with comparison results
    result = (t >= t_min).float() * t + (t < t_min).float() * t_min
    result = (result <= t_max).float() * result + (result > t_max).float() * t_max
    return result
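Since the focal_loss function itself is described as unchanged, a minimal PyTorch sketch of it (assuming clip_by_tensor above, with alpha = 0.25 and gamma = 2 as illustrative defaults) could look like:

```python
import torch

def focal_loss(prediction_tensor, target_tensor, alpha=0.25, gamma=2):
    # Hypothetical PyTorch port of the TensorFlow focal_loss above
    sigmoid_p = torch.sigmoid(prediction_tensor)
    zeros = torch.zeros_like(sigmoid_p)
    # z = 1: positive coefficient is (z - p); z = 0: negative coefficient is p
    pos_p_sub = torch.where(target_tensor > zeros, target_tensor - sigmoid_p, zeros)
    neg_p_sub = torch.where(target_tensor > zeros, zeros, sigmoid_p)
    per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * torch.log(clip_by_tensor(sigmoid_p, 1e-8, 1.0)) \
                          - (1 - alpha) * (neg_p_sub ** gamma) * torch.log(clip_by_tensor(1.0 - sigmoid_p, 1e-8, 1.0))
    return per_entry_cross_ent.sum()
```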