源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法

Date: 2020/07/17

Coder: CW

Foreword:

本文将对 loss函数的实现进行解析,由于 DETR 是预测结果是集合的形式,因此在计算loss的时候有个关键的前置步骤就是将预测结果和GT进行匹配,这里的GT类别是不包括背景的,未被匹配的预测结果就自动被归类为背景。匹配使用的是匈牙利算法,该算法主要用于解决与二分图匹配相关的问题,对这部分感兴趣的朋友们可以参考下这篇文:匈牙利算法


Outline

I. Loss Function

    i). 分类loss

    ii). 回归loss

II. Hungarian algorithm(匈牙利算法)


Loss Function

先来看看与loss函数相关的一些参数:matcher就是将预测结果与GT进行匹配的匈牙利算法,这部分的实现会在下一节解析。weight_dict是为各部分loss设置的权重,主要包括分类与回归损失,分类使用的是交叉熵损失,而回归损失包括bbox的 L1 Loss(计算x、y、w、h的绝对值误差)与 GIoU Loss。若设置了masks参数,则代表分割任务,那么还需加入对应的loss类型。另外,若设置了aux_loss,即代表需要计算解码器中间层预测结果对应的loss,那么也要设置对应的loss权重。

与loss函数实现相关的初始化参数

loss函数是通过实例化SetCriterion对象来构建。

构建loss函数

losses变量指示需要计算哪些类型的loss,其中cardinality仅用作log,并不涉及反向传播梯度。

loss_cardinality

可以先来看下SetCriterion这个类的doc string,了解下各部分参数的意义。

SetCriterion(i)

CW 也作了对应的注释:

SetCriterion(ii)

接下来看下其前向过程,从而知悉loss的计算。

这里一定要先搞清楚模型的输出(outputs)和GT(targets)的形式,对于outputs可参考CW在下图中的注释;而targets是一个包含多个dict的list,长度与batch size相等,其中每个dict的形式如COCO数据集的标注,具体可参考该系列的第二篇文章: 源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理 中的数据处理部分。

SetCriterion(iii)

如CW在前言部分所述,计算loss的一个关键的前置步骤就是将模型输出的预测结果与GT进行匹配,对应下图中self.matcher()的部分,返回的indices的形式已在注释中说明。

SetCriterion(iv)

接下来是计算各种类型的loss,并将对应结果存到一个dict中(如下图losses变量),self.get_loss()方法返回loss计算结果。

SetCriterion(v)
SetCriterion(vi)

get_loss方法中并不涉及具体loss的计算,其仅仅是将不同类型的loss计算映射到对应的方法,最后将计算结果返回。

get_loss

接下来,我们就对分类和回归损失的计算过程分别进行解析。

i). 分类loss

首先说明下,doc string里写的是NLL Loss,但实际调用的是CE Loss,这是因为在Pytorch实现中,CE Loss实质上就是将Log-Softmax操作和NLL Loss封装在了一起,如果直接使用NLL Loss,那么需要先对预测结果作Log-Softmax操作,而使用CELoss则直接免去了这一步。

loss_labels(i)

其次,要理解红框部分的_get_src_permutation_idx()在做什么。输入参数indices是匹配的预测(query)索引与GT的索引,其形式在上述SetCriterion(iv)图中注释已有说明。该方法返回一个tuple,代表所有匹配的预测结果的batch index(在当前batch中属于第几张图像)和 query index(图像中的第几个query对象)。

_get_src_permutation_idx

类似地,我们可以获得当前batch中所有匹配的GT所属的类别(target_classes_o),然后通过src_logitstarget_classes_o就可以设置预测结果对应的GT了,这就是下图中的target_classes。target_classes的shape和src_logits一致,代表每个query objects对应的GT,首先将它们全部初始化为背景,然后根据匹配的索引(idx)设置匹配的GT(target_classes_p)类别。

loss_labels(ii) 

“热身活动”做完后,终于可以开始计算loss了,注意在使用Pytorch的交叉熵损失时,需要将预测类别的那个维度转换到通道这个维度上(dim1)。

loss_labels(iii) 

另外,class_error计算的是Top-1精度(百分数),即预测概率最大的那个类别与对应被分配的GT类别是否一致,这部分仅用于log,并不参与模型训练。

accuracy

ii). 回归loss

回归loss的计算包括预测框与GT的中心点和宽高的L1 loss以及GIoU loss

注意在下图注释中,num_matched_queries1+num_matched_queries2+..., 和 num_matched_objs1+num_matched_objs2+... 是相等的,在前面 SetCriterion(iv) 图中matcher的返回结果注释中有说明。

loss_boxes(i)

以下就是loss的计算。注意下 reduction 参数,若不显式进行设置,在Pytorch的实现中默认是'mean',即返回所有涉及误差计算的元素的均值。

loss_boxes(ii)

另外,在计算GIoU loss时,使用了torch.diag()获取对角线元素,这是因为generalized_box_iou()方法返回的是所有预测框与所有GT的GIoU,比如预测框有N个,GT有M个,那么返回结果就是NxM个GIoU。而如 loss_boxes(i) 图中所示,我们预先对匹配的预测框和GT进行了排列,即N个预测框中的第1个匹配M个GT中的第1个,N中第2个匹配M中第2个,..,N中第i个匹配M中第i个,于是我们要取相互匹配的那一项来计算loss。

generalized_box_iou(i)
generalized_box_iou(ii)

Hungarian algorithm(匈牙利算法)

build_matcher()方法返回HungarianMatcher对象,其实现了匈牙利算法,在这里用于预测集(prediction set)和GT的匹配,最终匹配方案是选取“loss总和”最小的分配方式。注意CW对loss总和这几个字用了引号,其与loss函数中计算的loss并不完全一致,在这里是作为度量cost/metric)的角色,度量的值决定了匹配的结果,接下来我们看代码实现就会一清二楚。

build_matcher

如doc string所述,GT是不包含背景类的,通常预测集中的物体数量(默认为100)会比图像中实际存在的目标数量多,匈牙利算法按1对1的方式进行匹配,没有被匹配到的预测物体就自动被归类为背景(non-objects)。

HungarianMatcher(i)

以下cost_xx代表各类型loss的相对权重,在匈牙利算法中,描述为各种度量的相对权重会更合适,因此,这里命名使用的是'cost'。

HungarianMatcher(ii)

现在来看看前向过程,注意这里是不需要梯度的

HungarianMatcher(iii)

首先将预测结果和GT进行reshape,并对应起来,方便进行计算。

HungarianMatcher(iv)

注:以上tgt_bbox等式右边的torch.cat()方法中应加上参数dim=0

然后就可以对各种度量(各类型loss)进行计算。

如代码所示,这里的cost与上一节解析的loss并不完全一样,比如对于分类来说,loss计算使用的是交叉熵,而这里为了更加简便,直接采用1减去预测概率的形式,同时由于1是常数,于是作者甚至连1都省去了,有够机智(懒)的...

HungarianMatcher(v)

另外,在计算bbox的L1误差时,使用了torch.cdist(),其中设置参数p=1代表L1范式(默认是p=2,即L2范式),这个方法会对每对预测框与GT都进行误差计算:比如预测框有N个,GT有M个,结果就会有NxM个值。

接着对各部分度量加权求和,得到一个总度量。然后,统计当前batch中每张图像的GT数量,这个操作是为什么呢?接着看,你会发现这招很妙!

C.split()在最后一维按各张图像的目标数量进行分割,这样就可以在各图像中将预测结果与GT进行匹配了。

HungarianMatcher(vi)

匹配方法使用的是scipy优化模块中的linear_sum_assignment(),其输入是二分图的度量矩阵,该方法是计算这个二分图度量矩阵的最小权重分配方式,返回的是匹配方案对应的矩阵行索引和列索引。

linear_sum_assignment

结尾日常吹水

吾以为,loss函数的设计是DL项目中最重要的部分之一。CW每次看项目的源码时,最打起精神的就是这一part了。

从数学的角度来看,DL本质上是一个优化问题,loss是模型学习目标在数学上的表达形式,我们期望模型朝着loss最小的方向发展,因此,loss函数的设计关系到优化的可行性及难易程度,可谓成败之关键。因此,这部分其实很考验炼丹师的功力,也最能体现一个人考虑和解决问题的思想。

如今,我们是站在前人(一堆大佬,不,是巨佬!)的肩膀上,日常无脑地来来去去都用那几种loss,真是幸福的新生儿呐!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
禁止转载,如需转载请通过简信或评论联系作者。
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,588评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,456评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,146评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,387评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,481评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,510评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,522评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,296评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,745评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,039评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,202评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,901评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,538评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,165评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,415评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,081评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,085评论 2 352