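The loss consists of three terms: the probability map loss Ls, the binary map loss Lb, and the threshold map loss Lt. The total loss is their weighted sum:
L = Ls + α × Lb + β × Lt
where α and β are weighting coefficients; the DB paper sets them to 1.0 and 10, respectively.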
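Probability map loss
The probability map loss is a binary cross-entropy (BCE) loss. To balance positive and negative samples, it applies online hard example mining (OHEM) with a positive-to-negative ratio of 1:3.
Code implementation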
import torch
import torch.nn as nn


class BalanceCrossEntropyLoss(nn.Module):
    '''
    Balanced cross entropy loss: online hard example mining with a fixed
    positive-to-negative sample ratio.
    Shape:
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, same spatial shape as the input
        - Output: scalar.
    Note: gt * mask is an elementwise product. With a (N, 1, H, W) gt and a
    (N, H, W) mask, broadcasting yields (N, N, H, W) when N > 1, so pass
    pred, gt and mask with matching shapes.
    '''

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        """
        :param negative_ratio: maximum number of negatives kept per positive
        :param eps: epsilon, a small constant that avoids division by zero
        """
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        '''
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of the network
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)`, marks the pixels that contribute
                to the loss
        '''
        positive = (gt * mask).byte()  # cast to torch.uint8
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        # Keep at most negative_ratio negatives per positive.
        negative_count = min(int(negative.float().sum()),
                             int(positive_count * self.negative_ratio))
        loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
        positive_loss = loss * positive.float()
        negative_loss = loss * negative.float()
        # Hard negatives are the negative pixels with the largest loss.
        negative_loss, _ = negative_loss.view(-1).topk(negative_count)
        balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
            positive_count + negative_count + self.eps)
        if return_origin:
            return balance_loss, loss
        return balance_loss
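To make the mining step concrete, here is a minimal plain-Python sketch of the same arithmetic on a flat list of pixels. The helper `balanced_bce` and the sample values are hypothetical and only for illustration; it assumes every prediction lies strictly between 0 and 1 and is not a replacement for the tensor version above.

```python
import math


def balanced_bce(pred, gt, mask, negative_ratio=3.0, eps=1e-6):
    """Hypothetical helper: hard-negative-mined BCE over flat lists.

    Assumes 0 < pred < 1 for every pixel.
    """
    # Per-pixel binary cross entropy.
    loss = [-(g * math.log(p) + (1 - g) * math.log(1 - p))
            for p, g in zip(pred, gt)]
    # Split the valid pixels (mask == 1) into positives and negatives.
    pos_losses = [v for v, g, m in zip(loss, gt, mask) if m and g == 1]
    neg_losses = [v for v, g, m in zip(loss, gt, mask) if m and g == 0]
    # Keep at most negative_ratio negatives per positive ...
    neg_count = min(len(neg_losses), int(len(pos_losses) * negative_ratio))
    # ... and among them keep the hardest ones (largest loss).
    hard_negs = sorted(neg_losses, reverse=True)[:neg_count]
    total = sum(pos_losses) + sum(hard_negs)
    return total / (len(pos_losses) + neg_count + eps), neg_count


pred = [0.9, 0.8, 0.6, 0.3, 0.2, 0.1, 0.05]
gt = [1, 0, 0, 0, 0, 0, 0]
mask = [1, 1, 1, 1, 1, 1, 1]
loss_value, kept = balanced_bce(pred, gt, mask)
print(kept)  # 3: one positive, so only 3 of the 6 negatives are kept
```

With one positive pixel and a ratio of 3, the three negatives with predictions 0.8, 0.6 and 0.3 are kept because they are the most confidently wrong; the easy negatives do not contribute to the loss.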
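Threshold map loss
The threshold map loss is an L1 loss, averaged over the pixels selected by the mask.
Code implementation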
class MaskL1Loss(nn.Module):
    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask):
        # Sum of absolute errors inside the mask, divided by the mask area.
        loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
        return loss
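A small worked example may help show the effect of the mask. The sketch below repeats the same arithmetic in plain Python on three pixels; the helper `masked_l1` and the numbers are made up for illustration.

```python
def masked_l1(pred, gt, mask, eps=1e-6):
    """Hypothetical helper: mean absolute error over pixels where mask is nonzero."""
    weighted = sum(abs(p - g) * m for p, g, m in zip(pred, gt, mask))
    return weighted / (sum(mask) + eps)


pred = [0.5, 0.9, 0.25]
gt = [0.75, 0.4, 0.25]
mask = [1, 0, 1]  # the second pixel is ignored
value = masked_l1(pred, gt, mask)
print(round(value, 3))  # 0.125
```

The second pixel has the largest error (0.5) but is masked out, so the result is (0.25 + 0) / 2, up to the small eps in the denominator.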
Binary map loss
Formula
Code implementation
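Consistent with the DiceLoss implementation in this section, the loss can be written as follows. Here p is the predicted map, g the target, m the mask (multiplied by the per-pixel weights when they are given), and ε the small constant eps; the pixel index i is notation introduced here for readability.

```latex
L_b = 1 - \frac{2 \sum_i p_i \, g_i \, m_i}{\sum_i p_i \, m_i + \sum_i g_i \, m_i + \epsilon}
```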
class DiceLoss(nn.Module):
    '''
    Loss function from https://arxiv.org/abs/1707.03237,
    where IoU computation is introduced in a heatmap manner to measure the
    diversity between two heatmaps.
    '''

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask, weights=None):
        '''
        pred: one or two heatmaps of shape (N, 1, H, W),
            the losses of two heatmaps are added together.
        gt: (N, 1, H, W)
        mask: (N, H, W)
        '''
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        # Drop the channel dimension so pred and gt match mask: (N, H, W).
        if pred.dim() == 4:
            pred = pred[:, 0, :, :]
            gt = gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask
        intersection = (pred * gt * mask).sum()
        union = (pred * mask).sum() + (gt * mask).sum() + self.eps
        loss = 1 - 2.0 * intersection / union
        assert loss <= 1
        return loss
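To get a feel for the range of the Dice loss, the following plain-Python sketch evaluates the same expression on a four-pixel example. The helper `dice_loss` and the inputs are hypothetical and only mirror the arithmetic of `_compute` without weights.

```python
def dice_loss(pred, gt, mask, eps=1e-6):
    """Hypothetical helper: 1 - 2 * |pred ∩ gt| / (|pred| + |gt|), inside the mask."""
    intersection = sum(p * g * m for p, g, m in zip(pred, gt, mask))
    union = (sum(p * m for p, m in zip(pred, mask))
             + sum(g * m for g, m in zip(gt, mask)) + eps)
    return 1 - 2.0 * intersection / union


gt = [1, 1, 0, 0]
mask = [1, 1, 1, 1]
perfect = dice_loss([1, 1, 0, 0], gt, mask)   # close to 0: prediction equals target
disjoint = dice_loss([0, 0, 1, 1], gt, mask)  # exactly 1: no overlap at all
half = dice_loss([1, 0, 1, 0], gt, mask)      # close to 0.5: half of the target is hit
```

The loss moves from about 0 for a perfect prediction to 1 for a prediction with no overlap, which is why the class asserts `loss <= 1`.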