源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理

Date: 2020/06/28

Author: CW

前言:

本文会对模型训练部分的代码进行解析,主要把训练过程的pipeline过一下,其中有些部分的具体实现会放到后面的篇章中讲解,这部分源码对应于项目的 main.py 文件。


Outline

1、训练pipeline

2、一个训练周期的过程

3、数据处理


训练pipeline

首先是解析运行脚本时用户输入的参数,然后创建记录结果的目录,最后根据用户指定的参数进行训练。

pipeline整体过程

get_args_parser() 方法中设置了用户可以指定的参数项,想要了解的朋友们可以去参考源码,这里就不再阐述了,接下来主要看看 main() 方法,其中的内容就是训练过程的pipeline。

init_distributed_mode() 方法是与分布式训练相关的设置,在该方法里,是通过环境变量来判断是否使用分布式训练,如果是,那么就设置相关参数,具体可参考 util/misc.py 文件中的源码,这里不作解析。

训练pipeline(i)

参数项 frozen_weights 代表是否固定住参数的权重,类似于迁移学习的微调。如果是,那么需要同时指定 masks 参数,代表这种条件仅适用于分割任务。上图最后部分是固定随机种子,以便复现结果。

然后就是构造模型、loss函数以及后处理方法、输出可训练的参数数量。

训练pipeline(ii)

下图的部分包括设置优化器、学习率策略以及构建训练和验证集。

训练pipeline(iii)

从上图可以看到,这里将backbone和其它部分的参数分开,以便使用不同的初始学习率进行训练。构造数据集使用的 build_dataset() 方法调用了COCO数据集的api,其中的内容具体会在后文展示。

构造了数据集后,设置数据集的采样器,并且装在到 DataLoader,以进行批次训练。

训练pipeline(iv)

注意到以上使用了 collate_fn 方法来重新组装一个batch的数据,具体细节会在后面数据处理部分一并讲解。

训练pipeline(v)

下图部分主要是用于从历史的某个训练阶段中恢复过来,包括加载当时的模型权重、优化器和学习率等参数。

训练pipeline(vi)
训练pipeline(vii)

接下来真正开始一个个周期地训练,每个周期后根据学习率策略调整下学习率。

训练pipeline(viii)

下图部分是将训练结果和相关参数记录到指定文件。

训练pipeline(ix)
训练pipeline(x)

下图中的内容是将训练和验证的结果记录到(分布式)主节点中指定的文件。

训练pipeline(xi)

最后计算训练的总共耗时并且打印,整个训练流程就此结束。

训练pipeline(xii)

一个训练周期的过程

这部分对应的代码在 detr/engine.py 中的 train_one_epoch() 方法,在上一节的图中也能看到。顾名思义,这部分内容就是模型在一个训练周期中的操作,下面就来一起瞄瞄里面有啥值得学习的地方。

惯用套路,首先将模型设置为训练模式,这样梯度才能进行反向传播,从而更新模型参数的权重。注意到这里同时将 criterion 对象也设为train模式,它是 SetCriterion 类的一个对象实例,代表loss函数,看了下相关代码发现里面并没有需要学习的参数,因此感觉之类可以将这行代码去掉,后面我会亲自实践看看,朋友们也可一试。

train_one_epoch(i)

这里用到了一个类 MetricLogger(位于 detr/util/misc.py),它主要用于log输出,其中使用了一个defaultdict来记录各种数据的历史值,这些数据为 SmoothValue(位于 detr/util/misc.py) 类型,该类型通过指定的窗口大小(上图中的 window_size)来存储数据的历史步长(比如1就代表不存储历史记录,每次新的值都会覆盖旧的),并且可以格式化输出。另外 SmoothValue 还实现了统计中位数、均值等方法,并且能够在各进程间同步数据。

MetricLogger 除了通过key来存储SmoothValue以外,最重要的就是其实现了一个log_every的方法,这个方法是一个生成器,用于将每个batch的数据取出(yeild),然后该方法内部会暂停在此处,待模型训练完一次迭代后再执行剩下的内容,进行各项统计,然后再yeild下一个batch的数据,暂停在那里,以此重复,直至所有batch都训练完。这种方式在其它项目中比较少见,感兴趣的炼丹者们可以一试,找些新鲜感~

train_one_epoch(ii)

在计算出loss后,若采用了分布式训练,那么就在各个进程间进行同步。另外,若梯度溢出了,那么此时会产生梯度爆炸,于是就直接结束训练。

train_one_epoch(iii)

于是,为避免梯度爆炸,在训练过程中,对梯度进行裁剪,裁剪方式有很多种,可以直接对梯度值处理,这里的方式是对梯度的范式做截断,默认是第二范式,即所有参数的梯度平方和开方后与一个指定的最大值(下图中max_norm)相比,若比起大,则按比例对所有参数的梯度进行缩放。

train_one_epoch(iv)

最后,将 MetricLogger 统计的各项数据在进程间进行同步,同时返回它们的历史均值,对于这个历史均值的解释见下图注释。

train_one_epoch(v)

关于 MetricLogger 和 SmoothValue 的具体实现这里就不作解析了,这只是作者的个人喜好,用于训练过程中数据的记录与展示,和模型的工作原理及具体实现无关,大家如果想要将 DETR 用到自己的项目上,完全可以不care这部分。对于 MetricLogger 和 SmoothValue 的这种做法,我们可以学习下里面的技巧,抽象地继承,而不必生搬硬套。


数据处理

先来讲解下第一部分中拉下的collate_fn(),它的作用是将一个batch的数据重新组装为自定义的形式,输入参数batch就是原始的一个batch数据,通常在Pytorch中的Dataloader中,会将一个batch的数据组装为((data1, label1), (data2, label2), ...)这样的形式,于是第一行代码的作用就是将其变为[(data1, data2, data3, ...), (label1, label2, label3, ...)]这样的形式,然后取出batch[0]即一个batch的图像输入到nested_tensor_from_tensor_list()方法中进行处理,最后将返回结果替代原始的这一个batch图像数据。

collate_fn

接着来看看nested_tensor_from_tensor_list()是如何操作的。首先,为了能够统一batch中所有图像的尺寸,以便形成一个batch,我们需要得到其中的最大尺度(在所有维度上),然后对尺度较小的图像进行填充(padding),同时设置mask以指示哪些部分是padding得来的,以便后续模型能够在有效区域内去学习目标,相当于加入了一部分先验知识。

nested_tensor_from_tensor_list

下图演示了如何得到batch中每张图像在每个维度上的最大值,代码已经show得很明白了,CW无需多言。

_max_by_axis

构建数据集使用的是 build_dataset() 这个方法,该方法位于 datasets/__init__.py 文件。方法内部根据用户参数来构造用于目标检测/全景分割的数据集。image_set 是一个字符类型的参数,代表要构造的是训练集还是验证集。

build_dataset

针对目标检测任务,我们来看看 build_coco() 这个方法的内容,该方法位于 datasets/coco.py

build

这个方法首先检查数据文件路径的有效性,然后构造一个字典类型的 PATHS 变量来映射训练集与验证集的路径,最后实例化一个 CocoDetection() 对象,CocoDetection 这个类继承了torchvision.datasets.CocoDetection。

CocoDetection

在类的初始化方法中,首先调用父类的初始化方法,将图像文件及标注文件的路径传进去。transforms 是用于数据增强的方法;根据名字来看,ConvertCocoPolysToMask() 这个对象是将数据标注的多边形坐标转换为掩码,但其实不仅仅是这样,或者说不一定是这样,因为需要根据传进去的参数 return_masks 来确定。

另外,需要提下COCO数据集中标注字段annotation的格式,对于目标检测任务,其格式如下:

annotation

当 "iscrowd" 字段为0时,segmentation就是polygon的形式,比如这时的 "segmentation" 的值可能为 [[510.66, 423.01, 511.72, 420.03, 510.45......], ..],其中是一个个polygon即多边形,,这些数按序两两组成多边形各个点的横、纵坐标,也就是说,表示polygon的list中如果有n个数(必定是偶数),那么就代表了 n/2 个点坐标。

至于取数据用到的 __getitem__ 方法,首先也是调用父类的这个方法获得图像和对应的标签,然后 prepare 就是调用 ConvertCocoPolysToMask() 这个对象对图像和标签进行处理,之后若有指定数据增强,则进一步进行对应的处理,最后返回这一系列处理后的图像和对应的标签。

现在我们来看看 ConvertCocoPolysToMask 这个类内部究竟玩了些什么东东。

ConvertCocoPolysToMask(i)

这里的 target 是一个list,其中包含了多个字典类型的annotation,每个annotation的格式如上一部分的图中所示。这里将 "iscrowd" 为1的数据(即一组对象,如一群人)过滤掉了,仅保留标注为单个对象的数据。

另外这里对bbox的形式做了转换,将"xywh"转换为"x1y1x2y2"的形式,并且将它们控制图像尺寸范围内。

ConvertCocoPolysToMask(ii)

通过上图可以了解到,若传进来的 return_masks 值不为True,那么实质上是没有做 "convert_poly_to_mask" 这个操作的,这也是为何我在上述提到 ConvertCocoPolysToMask() 这个对象的实际操作可能和其命名有所差异。

下图中,keep 代表那些有效的bbox,即左上角坐标小于右下角坐标那些,过滤掉无效的那批。

ConvertCocoPolysToMask(iii)

在进行完处理和过滤操作后,更新annotation里各个字段的值,同时新增 "orig_size" 和 "size" 两个 key,最后返回处理后的图像和标签。

ConvertCocoPolysToMask(iv)

综上所述,ConvertCocoPolysToMask() 仅在传入的参数 return_masks 为True时做了将多边形转换为掩码的操作,该对象的主要工作其实是过滤掉标注为一组对象的数据,以及筛选掉bbox坐标不合法的那批数据。

现在我们来看看 convert_coco_poly_to_mask() 这个方法即将多边形坐标转换为掩码是如何操作的。

convert_coco_poly_to_mask

该方法中调用的 frPyObjects 和 decode 都是 coco api(pycocotools)中的方法,将每个多边形结合图像尺寸解码为掩码,然后将掩码增加至3维(若之前不足3维)。

这里有个实现上的细节——为何要加一维呢?因为我们希望的是这个mask能够在图像尺寸范围(h, w)中指示每个点为0或1,在解码后,mask的shape应该是 (h,w),加一维变为 (h,w,1),然后在最后一个维度使用any()后才能维持原来的维度即(h,w);如果直接在(h,w)的最后一维使用any(),那么得到的shape会是(h,),各位可以码码试试。

最后,将一个个多边形转换得到的掩码添加至列表,堆叠起来形成张量后返回。

在本系列第一篇文中我就提到过,说 DETR 的整体工作很solid,没有使用骚里骚气的数据增强,那么我们就来看看它究竟在数据增强方面做了啥。

make_coco_transforms(i)

可以看到,真的是很“老土”!就是归一化、随机反转、缩放、裁剪,除此之外,没有了,可谓大道至简~

make_coco_transforms(ii)

另外,提及下,上图中的 T 是项目中的 datatsets/transforms.py 模块,以上各个数据增强的方法在该模块中的实现和 torchvision.transforms 中的差不多,其中ToTensor()会将图像的通道维度排列在第一个维度,并且像素值归一化到0-1范围内;而Normalize()则会根据指定的均值和标准差对图像进行归一化,同时将标签的bbox转换为c_{x} c_{y} wh形式后归一化到0-1,此处不再进行解析,感兴趣的可以去参考源码。


@最后

通常,很多项目在数据处理部分都会相对复杂,一方面固然是因为数据处理好了模型才能进行有效训练与学习,而另一方面则是为了适应任务需求而“不得已”处理成这样,其中还可能会使用到一些算法技巧,但是在 DETR中,真的太简单了,coco api 几乎搞定了一切,然后搞几个超级老土的 data augmentation,完事,666!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
禁止转载,如需转载请通过简信或评论联系作者。
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,047评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,807评论 3 386
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,501评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,839评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,951评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,117评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,188评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,929评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,372评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,679评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,837评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,536评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,168评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,886评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,129评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,665评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,739评论 2 351