这一篇我们来看一下损失函数的定义。
class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss
"""
super().__init__()
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.losses = losses
self.focal_alpha = focal_alpha
该类定义前的注释指出DETR的损失包含两步:
- 计算模型输出和gt之间的二分图匹配;
- 对于匹配成功的数据对监督其类别和box
在初始化函数的参数里有一个matcher需要说明一下,这个是用来计算二分图匹配的nn.Module
类:
class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
def __init__(self,
cost_class: float = 1,
cost_bbox: float = 1,
cost_giou: float = 1):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super().__init__()
self.cost_class = cost_class
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
这种两个集合数目不同的二分图匹配问题一般选择少的一方作为匹配对的数目,其余表示匹配到背景。但二分图匹配还有一种选择方式是根据能量阈值选择匹配对数的方法。初始化函数的参数导入的是类别、box的L1差异以及giou差异在匹配能量中的占比。也就是说最终的匹配能量由这三部分组成
def forward(self, outputs, targets): # Matcher的推理函数
with torch.no_grad():
bs, num_queries = outputs["pred_logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. # 采用的focal loss
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
# Compute the giou cost betwen boxes
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
box_cxcywh_to_xyxy(tgt_bbox))
# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()
sizes = [len(v["boxes"]) for v in targets] # batch中每个sample中目标的个数
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] # 相当于选择每个样本的sample与target的相似度矩阵进行二分匹配
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] # 长度为batchsize的元组list
这里首先需要注意的是整个推理过程是不参与梯度反向传导的。其次在刻画预测类别与gt的差异性时使用的是focal loss,且其参数是固定的。最终对batch中每个样本使用匈牙利算法进行二分图匹配,获得对应的索引集合,输出格式是[(第一个样本配对的输出索引集合,第一个样本配对的gt索引集合), ...]
有个有意思的地方是,函数中没有采用循环方式分别针对每个样本计算能量矩阵,而是直接计算batch中所有的预测与所有的gt的能量矩阵,然后在通过索引的方式分别对每一个样本的能量矩阵块进行匈牙利匹配,不确定这种算法效率和循环比是否更有效
在使用Matcher获得匹配对之后,便可以对匹配对的回归和分类损失进行监督,SetCriterion
的forward
函数主要是对最基础的输出(decoder的最后一层输出)计算了损失,可能的情况会计算辅助损失(decoder中每一个layer的输出)和two-stage的proposal损失(Encoder的最后一层对proposal的预测输出),不同情况下的损失计算方式是相同的,只是输入不同,这些损失包括labels
,boxes
和masks
,我们主要看检测,所以我们这里忽略masks
。
loss_map中还有一个cardinality
指标,注意到该值是不进行梯度反传的,只是用来作为模型性能度量的一个指标,表示预测的目标数与真实目标数的差异,其定义函数loss_cardinality
中最重要的一句是
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
这里card_pred表示预测为前景目标的mask,因为在制作target时,使用类别数表示背景,也即预测类别向量的最后一位表示其属于背景的概率。
在计算labels和box的损失时,出现一个函数_get_src_permutation_idx
,这个函数主要是将Matcher返回的多个样本的匹配对索引拉平方便索引。举个例子,batch_size=2, query_num=4, 第一个样本的gt数位2, 第二个样本的gt数为3,那么matcher的返回可能是:
[([0,2], [0, 1]), ([1,3, 0], [2, 0, 1])], _get_src_permutation_idx
的返回值idx为一个元组,即[0, 0,1,1,1](即每个匹配对对应的query所在的样本在batch中的索引)和[0, 2,1,3,0](即每个匹配对的query在每个样本所有query中的索引),这样的话
target_classes[idx]
表示选择对应的样本对应的query,进而进行gt赋值。
loss_label
中需要注意的代码行:
target_classes_onehot = target_classes_onehot[:,:,:-1] # 最后一类是背景类
loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
表示针对于gt为背景的query,其gt是全零向量,因此采用的是sigmoid+F.binary_cross_entropy_with_logits 构建Focal loss,而不是softmax。
这里有个奇怪的地方是 loss_ce
有一个系数query_num, 这是应为sigmoid_focal_loss
输出有一个query_num上的mean操作,所以这里可以抵消。
loss_boxes
操作类似,唯一需要注意的是调用box_ops.generalized_box_iou
是返回的是的矩阵,即src和gt任意两两配对,因此需要diag操作。
以上就是DETR损失函数部分的定义,很容易阅读。下一篇我们来看看数据集定义部分。