代码阅读-deformable DETR (五)

这一篇我们来看一下损失函数的定义。

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

该类定义前的注释指出DETR的损失包含两步:

  1. 计算模型输出和gt之间的二分图匹配;
  2. 对于匹配成功的数据对监督其类别和box

在初始化函数的参数里有一个matcher需要说明一下,这个是用来计算二分图匹配的nn.Module类:

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

这种两个集合数目不同的二分图匹配问题一般选择少的一方作为匹配对的数目,其余表示匹配到背景。但二分图匹配还有一种选择方式是根据能量阈值选择匹配对数的方法。初始化函数的参数导入的是类别、box的L1差异以及giou差异在匹配能量中的占比。也就是说最终的匹配能量由这三部分组成

def forward(self, outputs, targets):  # Matcher的推理函数
      with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost.  # 采用的focal loss
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost betwen boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
                                             box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]  # batch中每个sample中目标的个数
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] # 相当于选择每个样本的sample与target的相似度矩阵进行二分匹配
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] # 长度为batchsize的元组list

这里首先需要注意的是整个推理过程是不参与梯度反向传导的。其次在刻画预测类别与gt的差异性时使用的是focal loss,且其参数\alpha,\gamma是固定的。最终对batch中每个样本使用匈牙利算法进行二分图匹配,获得对应的索引集合,输出格式是[(第一个样本配对的输出索引集合,第一个样本配对的gt索引集合), ...]

有个有意思的地方是,函数中没有采用循环方式分别针对每个样本计算能量矩阵,而是直接计算batch中所有的预测与所有的gt的能量矩阵,然后在通过索引的方式分别对每一个样本的能量矩阵块进行匈牙利匹配,不确定这种算法效率和循环比是否更有效

在使用Matcher获得匹配对之后,便可以对匹配对的回归和分类损失进行监督,SetCriterionforward函数主要是对最基础的输出(decoder的最后一层输出)计算了损失,可能的情况会计算辅助损失(decoder中每一个layer的输出)和two-stage的proposal损失(Encoder的最后一层对proposal的预测输出),不同情况下的损失计算方式是相同的,只是输入不同,这些损失包括labels,boxesmasks,我们主要看检测,所以我们这里忽略masks

loss_map中还有一个cardinality指标,注意到该值是不进行梯度反传的,只是用来作为模型性能度量的一个指标,表示预测的目标数与真实目标数的差异,其定义函数loss_cardinality中最重要的一句是

card_pred =  (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)  

这里card_pred表示预测为前景目标的mask,因为在制作target时,使用类别数表示背景,也即预测类别向量的最后一位表示其属于背景的概率。

在计算labels和box的损失时,出现一个函数_get_src_permutation_idx,这个函数主要是将Matcher返回的多个样本的匹配对索引拉平方便索引。举个例子,batch_size=2, query_num=4, 第一个样本的gt数位2, 第二个样本的gt数为3,那么matcher的返回可能是:
[([0,2], [0, 1]), ([1,3, 0], [2, 0, 1])], _get_src_permutation_idx的返回值idx为一个元组,即[0, 0,1,1,1](即每个匹配对对应的query所在的样本在batch中的索引)和[0, 2,1,3,0](即每个匹配对的query在每个样本所有query中的索引),这样的话
target_classes[idx] 表示选择对应的样本对应的query,进而进行gt赋值。
loss_label中需要注意的代码行:

        target_classes_onehot = target_classes_onehot[:,:,:-1]  # 最后一类是背景类
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]

表示针对于gt为背景的query,其gt是全零向量,因此采用的是sigmoid+F.binary_cross_entropy_with_logits 构建Focal loss,而不是softmax。
这里有个奇怪的地方是 loss_ce有一个系数query_num, 这是应为sigmoid_focal_loss输出有一个query_num上的mean操作,所以这里可以抵消。

loss_boxes操作类似,唯一需要注意的是调用box_ops.generalized_box_iou是返回的是m\times m的矩阵,即src和gt任意两两配对,因此需要diag操作。


以上就是DETR损失函数部分的定义,很容易阅读。下一篇我们来看看数据集定义部分。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,588评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,456评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,146评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,387评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,481评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,510评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,522评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,296评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,745评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,039评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,202评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,901评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,538评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,165评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,415评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,081评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,085评论 2 352

推荐阅读更多精彩内容