AnchorDETR代码学习I

主要是研究DETR的代码来理解transformer和目标检测的方式,另外可以学习torch的使用,站在巨人的肩膀学学优秀开源框架,当然主要还要比较Deformable 和Anchor 以及DETR看看如何改进的,对于理解原理和实验调参有帮助,乃至后期可能魔改自己的DETR部分。

首先要看下DETR的组成,再次基础上看Anchor DETR,横向对比Deformable DETR

DETR 结构

主要可以通过项目结构看出,有models就是模型部分,主要是backbone网络的定义,detr结构,matcher的计算部分,也就是二部图匹配算法,位置编码部分以及transformer编解码等部分

datasets就是数据处理部分的代码,主要是coco数据结构的数据,还有进行数据的transform增强等,另外也有coco评估的代码

util部分是通用的工具类和类别定义,比如包围盒计算,画图等功能

主要的函数 main.py engine.py 即函数主要的流程,包括每次的训练和验证过程等,函数的参数默认设置等部分。

detr.png

对于机器学习等代码我们除了了解结构可能更关注的是数据的流动和转换,所以先理解main和engine可以有大致的了解。

Deformable.png

Deformable-DETR 主要改变的是数据处理部分,增加了data_prefetcher和samplers

模型部分则改为deformable_transformer 和 Deformable detr

而Anchor DETR在Defomable的数据改变基础上 ,改变detr和增加了行列分离的注意力模块

AnchorDETR.png

3 main流程

开始就是各种通过输入得到的超参数和参数,也有默认值。

main.png

get_args_parser() 方法设置了用户可以指定的参数,这里传入args

接着设置输出路径,如果参数有设,则将结果输出到指定output_dir

接着就是main的流程,也即是训练的pipeline

先看这些各类的参数:

主要超参数


def get_args_parser():    
     parser = argparse.ArgumentParser('AnchorDETR Detector', add_help=False)    
     parser.add_argument('--lr', default=1e-4, type=float)    #模型的学习率    
     parser.add_argument('--lr_backbone_names', default=["backbone"], type=str, nargs='+')   #骨架网络    
    parser.add_argument('--lr_backbone', default=1e-5, type=float)    #骨架网络的学习率    
    parser.add_argument('--lr_linear_proj_names', default=[], type=str, nargs='+')  #线性映射的方法,deformable和Anchordert增加的,    # 可以是参考点也可以是采样偏置    
    parser.add_argument('--lr_linear_proj_mult', default=0.1, type=float) #线性投射权重,就是多大概率进行线性采样    
    parser.add_argument('--batch_size', default=1, type=int)    #每个gpu输入图像个数    
    parser.add_argument('--weight_decay', default=1e-4, type=float)   #权重衰减    
    parser.add_argument('--epochs', default=50, type=int)    #总训练代数    
    parser.add_argument('--lr_drop', default=40, type=int)  #开始降低学习率的代数    
    parser.add_argument('--lr_drop_epochs', default=None, type=int, nargs='+')   #学习率减少代    
    parser.add_argument('--clip_max_norm', default=0.1, type=float,   #梯度裁剪参数,大于零则会进行梯度裁剪                        help='gradient clipping max norm')
    parser.add_argument('--sgd', action='store_true')   #是否使用随机梯度下降法

weight_decay 是权重衰减,类似L2正则项惩罚,对于梯度下降计算式中每个权重,都用一个0到1的值相乘缩减。现多用weight_decay 而少用dropout

模型参数


    # Model parameters    
    parser.add_argument('--frozen_weights', type=str, default=None,   #是否固定住参数的权重,类似于迁移学习的微调                        #这里注释也是看出要给出预训练权重文件的路径,主要是为分割使用,只有mask head会训练                        
   help="Path to the pretrained model. If set, only the mask head will be trained")
    # Backbone 网络模型参数   
     parser.add_argument('--backbone', default='resnet50', type=str,   #卷积骨架网络                        
    help="Name of the convolutional backbone to use")    
    parser.add_argument('--dilation', default=True,   #卷积核膨胀,区分是否DC5                        
    help="If true, we replace stride with dilation in the last convolutional block (DC5)")    
    parser.add_argument('--num_feature_levels', default=1, type=int, help='number of feature levels')                        #特征层个数也是Deformable和anchordetr增加的

transformer 网络参数

  parser.add_argument('--enc_layers', default=6, type=int,  #编码层个数                        
help="Number of encoding layers in the transformer")    parser.add_argument('--dec_layers', default=6, type=int,   #解码层个数                        
help="Number of decoding layers in the transformer")    parser.add_argument('--dim_feedforward', default=1024, type=int,   #前馈层维度                        
help="Intermediate size of the feedforward layers in the transformer blocks")    
parser.add_argument('--hidden_dim', default=256, type=int,      #隐藏层维度                       
 help="Size of the embeddings (dimension of the transformer)")    parser.add_argument('--dropout', default=0., type=float,    #对于神经网络单元,按照一定的概率将其暂时从网络中丢弃,避免过拟合。                        
help="Dropout applied in the transformer")    parser.add_argument('--nheads', default=8, type=int,   #检测头个数                        
help="Number of attention heads inside the transformer's attentions")    
parser.add_argument('--num_query_position', default=300, type=int,  #查询位置个数,即多少个目标框                        help="Number of query positions")    
parser.add_argument('--num_query_pattern', default=3, type=int,   #查询模式个数,anchor DETR特有                        
help="Number of query patterns")    
parser.add_argument('--spatial_prior', default='learned', choices=['learned', 'grid'], #使用空间位置偏好,学习的还是网格                        type=str,help="Number of query patterns")    parser.add_argument('--attention_type',  #注意力机制                        # default='nn.MultiheadAttention',                        
default="RCDA",   #anchor detr的行列分离注意力                        choices=['RCDA', 'nn.MultiheadAttention'],                        type=str,help="Type of attention module")                           # Segmentation 分割特有参数    
parser.add_argument('--masks', action='store_true',                        help="Train segmentation head if the flag is provided")

Anchor DETR主要增加的就是num_query_pattern 也就是查询的几种模式,可以适应多个尺度,

attention_type 表示注意力的类别,detr的是多头注意力,而Anchor DETR是使用了RCDA 行列分离的注意力。其他基本和detr一致。

接着是最为重要的也就是损失函数,匹配函数的参数

# Loss     
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',  #是否在解码器中使用辅助损耗                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Matcher 就是二部图匹配的部分    
parser.add_argument('--set_cost_class', default=2, type=float,   #类别损失权重                        
help="Class coefficient in the matching cost")    parser.add_argument('--set_cost_bbox', default=5, type=float,    #包围盒L1 损失函数权重                        
help="L1 box coefficient in the matching cost")    parser.add_argument('--set_cost_giou', default=2, type=float,  #giou权重                        
help="giou box coefficient in the matching cost")
    # * Loss coefficients  损失函数系数    
parser.add_argument('--mask_loss_coef', default=1, type=float)   #mask 分割使用    
parser.add_argument('--dice_loss_coef', default=1, type=float)  #dice_loss 轮廓区域的损失 分割使用    
parser.add_argument('--cls_loss_coef', default=2, type=float)  
#类别损失    
parser.add_argument('--bbox_loss_coef', default=5, type=float)  
#包围盒损失    
parser.add_argument('--giou_loss_coef', default=2, type=float)  #giou损失    
parser.add_argument('--focal_alpha', default=0.25, type=float)   #Focal Loss的, 解决难易样本数量不平衡

主要有类别损失,包围盒损失和GIou的损失函数。L1 Loss整体不如Giou

数据集和训练相关的超参数

 # dataset parameters    
parser.add_argument('--dataset_file', default='coco')   #数据集文件类型    
parser.add_argument('--coco_path', default='/data/coco', type=str)  #检测数据集路径    
parser.add_argument('--coco_panoptic_path', type=str)   #全景分割   
 parser.add_argument('--remove_difficult', action='store_true')  #是否移除difficult    parser.add_argument('--output_dir', default='/data/detr-workdir/r50-dc5',  #默认模型输出位置                        help='path where to save, empty for no saving')    parser.add_argument('--device', default='cuda',   #设备,默认是cuda                        
help='device to use for training / testing')    
parser.add_argument('--seed', default=42, type=int)  #随机种子    parser.add_argument('--resume', default='', help='resume from checkpoint')  #模型权重位置,从哪个检测点继续开始训练    parser.add_argument('--auto_resume', default=False, action='store_true', help='whether to resume from last checkpoint')    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',  #训练开始代数                        
help='start epoch')    
parser.add_argument('--eval', action='store_true')    #是否验证评估    parser.add_argument('--num_workers', default=2, type=int) #线程数    
parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images on memory') #是否保持图片到缓存

dataset_file 指定了数据操作的py文件,coco就是说明是coco.py

coco_path 指定数据集位置,一定要和指定的格式一致。

接下来就是main函数流程


 def main(args):
    #分布式训练配置设置,根据用户参数选择是否进行分布式训练
    #主要是rank 节点id 和world_size有多少个节点
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha())) #获取远端git的sha从而得出文件库和分支版本

    if args.frozen_weights is not None: #如果froze_weights有设置开始分割任务,验证数据格式是否有适合
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)   #分配设备

    # fix the seed for reproducibility   为了实验复现,固定每次的随机种子,从而是训练所有随机数固定
    seed = args.seed + utils.get_rank()   #通过用户设置seed 和 机架编号固定初始随机种子
    torch.manual_seed(seed)  #设置神经网络等随机初始化的随机种子
    np.random.seed(seed)  #numpy库随机生成使用随机种子
    random.seed(seed)  #用生成的seed生成随机数使用

    model, criterion, postprocessors = build_model(args)  #初始化模型,损失函数类和后处理
    model.to(device)  #模型从内存存入设备,CPU或GPU

    #保持副本用于非ddp 就是 DistributedDataParallel
    model_without_ddp = model
    #输出打印模型参数个数,numel返回数组中元素的个数
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

初始化后接着是数据载入

  #通过文件名构建训练集验证集
    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        if args.cache_mode: #缓存模式
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                   pin_memory=True)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                 pin_memory=True)



这里先是通过定义的build_dataset 从文件名生成图片数据集,这里构造数据集也包括调用了coco数据集的API,并经过变换等预处理,所以是coco格式。但是对于大批量的数据torch需要进行shuffle和batch化,这里就是用了sampler类方法。

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。

这里训练集就是随机采样,而验证集每次都是全量所有是顺序采样。

最后使用pytorch的dataloader方法具体实现打乱和批次化到实时计算

dataset:加载的数据集(Dataset对象)

batch_size:batch sizeshuffle::是否将数据打乱

sampler: 样本抽样,

num_workers:使用多进程加载的进程数,0代表不使用多进程

collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些

drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

这里特别的是collate_fn,是utils/misc.py里定义的 collate_fn 方法来重新组装一个batch的数据

通过参数名将参数区分,方便分别调整学习率,这里比DETR 里多了线性变换的参数lr_linear_proj_names,另外就是backbone ,anchor DETR,lr_linear_proj_names 也就是骨架网络,模型网络和线性网络部分。

 def match_name_keywords(n, name_keywords):   #通过名称匹配参数
        out = False
        for b in name_keywords:
            if b in n:
                out = True
                break
        return out

    for n, p in model_without_ddp.named_parameters():   #输出参数名
        print(n)

    param_dicts = [
        {
            "params":
                [p for n, p in model_without_ddp.named_parameters()
                 if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
            "lr": args.lr,
        },
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
            "lr": args.lr_backbone,
        },
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
            "lr": args.lr * args.lr_linear_proj_mult,
        }
    ]

其他训练超参数

  if args.sgd:   #是否使用随机梯度
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:  
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    if args.distributed:  #分布式
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.dataset_file == "coco_panoptic":  #是否是全景分割
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None: #是否有固定参数,有从检查点加载参数
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)   #模型输出路径

resume 主要是从训练的某个阶段恢复过来,加载某个检测点的模型参数,学习率和优化器参数等。

     if args.auto_resume:   #自动恢复,默认的checkpoint文件名
        if not args.resume:
            args.resume = os.path.join(args.output_dir, 'checkpoint.pth')
        if not os.path.isfile(args.resume):
            args.resume=''

    if args.resume:
        if args.resume.startswith('https'):   #根据路径名判断是从网络获取模型还是本地获取
            checkpoint = torch.hub.load_state_dict_from_url(   
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
            #获取缺失和不正确的参数名
        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
        if len(missing_keys) > 0:
            print('Missing Keys: {}'.format(missing_keys))
        if len(unexpected_keys) > 0:
            print('Unexpected Keys: {}'.format(unexpected_keys))
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:   #训练则需要加载学习率和优化器
            import copy
            p_groups = copy.deepcopy(optimizer.param_groups)
            optimizer.load_state_dict(checkpoint['optimizer'])
            for pg, pg_old in zip(optimizer.param_groups, p_groups):
                pg['lr'] = pg_old['lr']
                pg['initial_lr'] = pg_old['initial_lr']

            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
            args.override_resumed_lr_drop = True   #更新lr_drop参数
            if args.override_resumed_lr_drop:
                print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.')
                lr_scheduler.step_size = args.lr_drop
                lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
            lr_scheduler.step(lr_scheduler.last_epoch)
            args.start_epoch = checkpoint['epoch'] + 1   #在检测的代数基础上开始加1
        # check the resumed model
        if not args.eval:    #构建验证评估类
            test_stats, coco_evaluator = evaluate(
                model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
            )

测评部分,如果设置了只测试不训练

  if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:  #输出评估结果
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

最后是每一代的训练过程


    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(  #每代的训练
            model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm)
        lr_scheduler.step()   #更新学习率
        if args.output_dir:   #是否输出模型参数
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:  #超过学习率开始下降的代或每100代保存一个模型
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        test_stats, coco_evaluator = evaluate(
            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
        )  #评估模型

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}
        
        print(args.output_dir)  #输出日志
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

最后打印输出和统计时间等,main的流程就完成了。
只是学习记录
大致了解了整体过程后,就可以对每个部分进一步深入学习。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,752评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,100评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,244评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,099评论 1 286
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,210评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,307评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,346评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,133评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,546评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,849评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,019评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,702评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,331评论 3 319
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,030评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,260评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,871评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,898评论 2 351

推荐阅读更多精彩内容