RetinaNet: Focal loss在目标检测网络中的应用

介绍

RetinaNet是2018年Facebook AI团队在目标检测领域新的贡献。它的重要作者名单中Ross Girshick与Kaiming He赫然在列。来自Microsoft的Sun Jian团队与现在Facebook的Ross/Kaiming团队在当前视觉目标分类、检测领域有着北乔峰、南慕容一般的独特地位。这两个实验室的文章多是行业里前进方向的提示牌。

RetinaNet只是原来FPN网络与FCN网络的组合应用,因此在目标网络检测框架上它并无特别亮眼创新。文章中最大的创新来自于Focal loss的提出及在单阶段目标检测网络RetinaNet(实质为Resnet + FPN + FCN)的成功应用。Focal loss是一种改进了的交叉熵(cross-entropy, CE)loss,它通过在原有的CE loss上乘了个使易检测目标对模型训练贡献削弱的指数式,从而使得Focal loss成功地解决了在目标检测时,正负样本区域极不平衡而目标检测loss易被大批量负样本所左右的问题。此问题是单阶段目标检测框架(如SSD/Yolo系列)与双阶段目标检测框架(如Faster-RCNN/R-FCN等)accuracy gap的最大原因。在Focal loss提出之前,已有的目标检测网络都是通过像Boot strapping/Hard example mining等方法来解决此问题的。作者通过后续实验成功表明Focal loss可在单阶段目标检测网络中成功使用,并最终能以更快的速率实现与双阶段目标检测网络近似或更优的效果。

类别不平衡问题

常规的单阶段目标检测网络像SSD一般在模型训练时会先大密度地在模型终端的系列feature maps上生成出10,000甚至100,0000个目标候选区域。然后再分别对这些候选区域进行分类与位置回归识别。而在这些生成的数万个候选区域中,绝大多数都是不包含待检测目标的图片背景,这样就造成了机器学习中经典的训练样本正负不平衡的问题。它往往会造成最终算出的training loss为占绝对多数但包含信息量却很少的负样本所支配,少样正样本提供的关键信息却不能在一般所用的training loss中发挥正常作用,从而无法得出一个能对模型训练提供正确指导的loss。

常用的解决此问题的方法就是负样本挖掘。或其它更复杂的用于过滤负样本从而使正负样本数维持一定比率的样本取样方法。而在此篇文章中作者提出了可通过候选区域包含潜在目标概率进而对最终的training loss进行较正的方法。实验表明这种新提出的focal loss在单阶段目标检测任务上表现突出,有效地解决了此领域里面潜在的类别不平衡问题。

Focal loss

  • CE(cross-entropy) loss

以下为典型的交叉熵loss,它广泛用于当下的图像分类、检测CNN网络当中。

Cross-entropy_loss
  • Balanced CE loss

考虑到上节中提到的类别不平衡问题对最终training loss的不利影响,我们自然会想到可通过在loss公式中使用与目标存在概率成反比的系数对其进行较正。如下公式即是此朴素想法的体现。它也是作者最终Focus loss的baseline。

Balanced_CE-loss
  • Focal loss定义

以下是作者提出的focal loss的想法。

Focal_loss定义

下图为focal loss与常规CE loss的对比。从中,我们易看出focal loss所加的指数式系数可对正负样本对loss的贡献自动调节。当某样本类别比较明确些,它对整体loss的贡献就比较少;而若某样本类别不易区分,则对整体loss的贡献就相对偏大。这样得到的loss最终将集中精力去诱导模型去努力分辨那些难分的目标类别,于是就有效提升了整体的目标检测准度。不过在此focus loss计算当中,我们引入了一个新的hyper parameter即γ。一般来说新参数的引入,往往会伴随着模型使用难度的增加。在本文中,作者有试者对其进行调节,线性搜索后得出将γ设为2时,模型检测效果最好。

几种loss的对比

在最终所用的focal loss上,作者还引入了α系数,它能够使得focal loss对不同类别更加平衡。实验表明它会比原始的focal loss效果更好。

最终所用的Focal_loss

模型的初始化参数选择

一般我们初始化CNN网络模型时都会使用无偏的参数对其初始化,比如Conv的kernel 参数我们会以bias 为0,variance为0.01的某分布来对其初始化。但是如果我们的模型要去处理类别极度不平衡的情况,那么就会考虑到这样对训练数据分布无任选先验假设的初始化会使得在训练过程中,我们的参数更偏向于拥有更多数量的负样本的情况去进化。作者观察下来发现它在训练时会出现极度的不稳定。于是作者在初始化模型最后一层参数时考虑了数据样本分布的不平衡性,这样使得初始训练时最终得出的loss不会对过多的负样本数量所惊讶到,从而有效地规避了初始训练时模型的震荡与不稳定。

RetinaNet检测框架

RetinaNet本质上是Resnet + FPN + 两个FCN子网络。
以下为RetinaNet目标框架框架图。有了之前blog里面提到的FPN与FCN的知识后,我们很容易理解此框架的设计含义。

RetinaNet目标检测框架

一般主干网络可选用任一有效的特征提取网络如vgg16或resnet系列,此处作者分别尝试了resnet-50与resnet-101。而FPN则是对resnet-50里面自动形成的多尺度特征进行了强化利用,从而得到了表达力更强、包含多尺度目标区域信息的feature maps集合。最后在FPN所吐出的feature maps集合上,分别使用了两个FCN子网络(它们有着相同的网络结构却各自独立,并不share参数)用来完成目标框类别分类与位置回归任务。

模型的推理与训练

  • 模型推理

一旦我们有了训练好的模型,在正式部署时,只需对其作一次forward,然后对最终生成的目标区域进行过渡。然后只对每个FPN level上目标存在概率最高的前1000个目标框进一步地decoding处理。接下来再将所有FPN level上得到的目标框汇集起来,统一使用极大值抑制的方法进一步过渡(其中极大值抑制时所用的阈值为0.5)。这样,我们就得到了最终的目标与其位置框架。

  • 模型训练

模型训练中主要在后端Loss计算时采用了Focal loss,另外也在模型初始化时考虑到了正负样本极度不平衡的情况进而对模型最后一个conv layer的bias参数作了有偏初始化。

训练时用了SGD,mini batch size为16,在8个GPU上一块训练,每个GPU上local batch size为2。最大iterations数目为90,000;模型初始lr为0.01,接下来随着训练进行分step wisely 降低。真正的training loss则为表达目标类别的focus loss与表达目标框位置回归信息的L1 loss的和。

下图为RetinaNet模型的检测准度与性能。

RetinaNet的检测准度与性能

代码实例

以下函数用于从FPN的各个level的feature maps上提取各种scale的anchor box。

def _create_cell_anchors():
    """
    Generate all types of anchors for all fpn levels/scales/aspect ratios.
    This function is called only once at the beginning of inference.
    """
    k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
    scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE
    aspect_ratios = cfg.RETINANET.ASPECT_RATIOS
    anchor_scale = cfg.RETINANET.ANCHOR_SCALE
    A = scales_per_octave * len(aspect_ratios)
    anchors = {}
    for lvl in range(k_min, k_max + 1):
        # create cell anchors array
        stride = 2. ** lvl
        cell_anchors = np.zeros((A, 4))
        a = 0
        for octave in range(scales_per_octave):
            octave_scale = 2 ** (octave / float(scales_per_octave))
            for aspect in aspect_ratios:
                anchor_sizes = (stride * octave_scale * anchor_scale, )
                anchor_aspect_ratios = (aspect, )
                cell_anchors[a, :] = generate_anchors(
                    stride=stride, sizes=anchor_sizes,
                    aspect_ratios=anchor_aspect_ratios)
                a += 1
        anchors[lvl] = cell_anchors
    return anchors

下面函数则描述了如何使用train好的RetinaNet来进行图片目标检测。

def im_detect_bbox(model, im, timers=None):
    """Generate RetinaNet detections on a single image."""
    if timers is None:
        timers = defaultdict(Timer)
    # Although anchors are input independent and could be precomputed,
    # recomputing them per image only brings a small overhead
    anchors = _create_cell_anchors()
    timers['im_detect_bbox'].tic()
    k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
    A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
    inputs = {}
    inputs['data'], im_scale, inputs['im_info'] = \
        blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
    cls_probs, box_preds = [], []
    for lvl in range(k_min, k_max + 1):
        suffix = 'fpn{}'.format(lvl)
        cls_probs.append(core.ScopedName('retnet_cls_prob_{}'.format(suffix)))
        box_preds.append(core.ScopedName('retnet_bbox_pred_{}'.format(suffix)))
    for k, v in inputs.items():
        workspace.FeedBlob(core.ScopedName(k), v.astype(np.float32, copy=False))

    workspace.RunNet(model.net.Proto().name)
    cls_probs = workspace.FetchBlobs(cls_probs)
    box_preds = workspace.FetchBlobs(box_preds)

    # here the boxes_all are [x0, y0, x1, y1, score]
    boxes_all = defaultdict(list)

    cnt = 0
    for lvl in range(k_min, k_max + 1):
        # create cell anchors array
        stride = 2. ** lvl
        cell_anchors = anchors[lvl]

        # fetch per level probability
        cls_prob = cls_probs[cnt]
        box_pred = box_preds[cnt]
        cls_prob = cls_prob.reshape((
            cls_prob.shape[0], A, int(cls_prob.shape[1] / A),
            cls_prob.shape[2], cls_prob.shape[3]))
        box_pred = box_pred.reshape((
            box_pred.shape[0], A, 4, box_pred.shape[2], box_pred.shape[3]))
        cnt += 1

        if cfg.RETINANET.SOFTMAX:
            cls_prob = cls_prob[:, :, 1::, :, :]

        cls_prob_ravel = cls_prob.ravel()
        # In some cases [especially for very small img sizes], it's possible that
        # candidate_ind is empty if we impose threshold 0.05 at all levels. This
        # will lead to errors since no detections are found for this image. Hence,
        # for lvl 7 which has small spatial resolution, we take the threshold 0.0
        th = cfg.RETINANET.INFERENCE_TH if lvl < k_max else 0.0
        candidate_inds = np.where(cls_prob_ravel > th)[0]
        if (len(candidate_inds) == 0):
            continue

        pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds))
        inds = np.argpartition(
            cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:]
        inds = candidate_inds[inds]

        inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose()
        classes = inds_5d[:, 2]
        anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4]
        scores = cls_prob[:, anchor_ids, classes, y, x]

        boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32)
        boxes *= stride
        boxes += cell_anchors[anchor_ids, :]

        if not cfg.RETINANET.CLASS_SPECIFIC_BBOX:
            box_deltas = box_pred[0, anchor_ids, :, y, x]
        else:
            box_cls_inds = classes * 4
            box_deltas = np.vstack(
                [box_pred[0, ind:ind + 4, yi, xi]
                 for ind, yi, xi in zip(box_cls_inds, y, x)]
            )
        pred_boxes = (
            box_utils.bbox_transform(boxes, box_deltas)
            if cfg.TEST.BBOX_REG else boxes)
        pred_boxes /= im_scale
        pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape)
        box_scores = np.zeros((pred_boxes.shape[0], 5))
        box_scores[:, 0:4] = pred_boxes
        box_scores[:, 4] = scores

        for cls in range(1, cfg.MODEL.NUM_CLASSES):
            inds = np.where(classes == cls - 1)[0]
            if len(inds) > 0:
                boxes_all[cls].extend(box_scores[inds, :])
    timers['im_detect_bbox'].toc()

    # Combine predictions across all levels and retain the top scoring by class
    timers['misc_bbox'].tic()
    detections = []
    for cls, boxes in boxes_all.items():
        cls_dets = np.vstack(boxes).astype(dtype=np.float32)
        # do class specific nms here
        keep = box_utils.nms(cls_dets, cfg.TEST.NMS)
        cls_dets = cls_dets[keep, :]
        out = np.zeros((len(keep), 6))
        out[:, 0:5] = cls_dets
        out[:, 5].fill(cls)
        detections.append(out)

    # detections (N, 6) format:
    #   detections[:, :4] - boxes
    #   detections[:, 4] - scores
    #   detections[:, 5] - classes
    detections = np.vstack(detections)
    # sort all again
    inds = np.argsort(-detections[:, 4])
    detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :]

    # Convert the detections to image cls_ format (see core/test_engine.py)
    num_classes = cfg.MODEL.NUM_CLASSES
    cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
    for c in range(1, num_classes):
        inds = np.where(detections[:, 5] == c)[0]
        cls_boxes[c] = detections[inds, :5]
    timers['misc_bbox'].toc()

    return cls_boxes

以下为RetinaNet中training loss的具体计算。可以看出它包含了两个部分分别为反映位置信息的L1 loss与反映类别信息的focus loss。

def add_fpn_retinanet_losses(model):
    loss_gradients = {}
    gradients, losses = [], []

    k_max = cfg.FPN.RPN_MAX_LEVEL  # coarsest level of pyramid
    k_min = cfg.FPN.RPN_MIN_LEVEL  # finest level of pyramid

    model.AddMetrics(['retnet_fg_num', 'retnet_bg_num'])
    # ==========================================================================
    # bbox regression loss - SelectSmoothL1Loss for multiple anchors at a location
    # ==========================================================================
    for lvl in range(k_min, k_max + 1):
        suffix = 'fpn{}'.format(lvl)
        bbox_loss = model.net.SelectSmoothL1Loss(
            [
                'retnet_bbox_pred_' + suffix,
                'retnet_roi_bbox_targets_' + suffix,
                'retnet_roi_fg_bbox_locs_' + suffix, 'retnet_fg_num'
            ],
            'retnet_loss_bbox_' + suffix,
            beta=cfg.RETINANET.BBOX_REG_BETA,
            scale=model.GetLossScale() * cfg.RETINANET.BBOX_REG_WEIGHT
        )
        gradients.append(bbox_loss)
        losses.append('retnet_loss_bbox_' + suffix)

    # ==========================================================================
    # cls loss - depends on softmax/sigmoid outputs
    # ==========================================================================
    for lvl in range(k_min, k_max + 1):
        suffix = 'fpn{}'.format(lvl)
        cls_lvl_logits = 'retnet_cls_pred_' + suffix
        if not cfg.RETINANET.SOFTMAX:
            cls_focal_loss = model.net.SigmoidFocalLoss(
                [
                    cls_lvl_logits, 'retnet_cls_labels_' + suffix,
                    'retnet_fg_num'
                ],
                ['fl_{}'.format(suffix)],
                gamma=cfg.RETINANET.LOSS_GAMMA,
                alpha=cfg.RETINANET.LOSS_ALPHA,
                scale=model.GetLossScale(),
                num_classes=model.num_classes - 1
            )
            gradients.append(cls_focal_loss)
            losses.append('fl_{}'.format(suffix))
        else:
            cls_focal_loss, gated_prob = model.net.SoftmaxFocalLoss(
                [
                    cls_lvl_logits, 'retnet_cls_labels_' + suffix,
                    'retnet_fg_num'
                ],
                ['fl_{}'.format(suffix), 'retnet_prob_{}'.format(suffix)],
                gamma=cfg.RETINANET.LOSS_GAMMA,
                alpha=cfg.RETINANET.LOSS_ALPHA,
                scale=model.GetLossScale(),
                num_classes=model.num_classes
            )
            gradients.append(cls_focal_loss)
            losses.append('fl_{}'.format(suffix))

    loss_gradients.update(blob_utils.get_loss_gradients(model, gradients))
    model.AddLosses(losses)
    return loss_gradients

参考文献

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,185评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,445评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,684评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,564评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,681评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,874评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,025评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,761评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,217评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,545评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,694评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,351评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,988评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,778评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,007评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,427评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,580评论 2 349

推荐阅读更多精彩内容

  • 1. 简介 在本篇文章中, 我会详细的记录我学习MPEG2TSExtractor的全部过程. 与其说是一篇技术博客...
    TankWitch阅读 3,691评论 1 5
  • 0 工作中遇到很多事,医院里的事,大部分都不会很好看,都不会很美好,在生死面前,可以看到人性最美好的一面,也...
    mok是个精神病患阅读 243评论 0 0
  • 今天更深入地了解了李笑来老师说的“元认知能力”,有了更为深刻的领悟,也让自己的操作系统不断迭新。从而更为深...
    巧青阅读 168评论 0 0
  • 初春 呼出春风到,风来众树摇...
    晨光雪影阅读 124评论 0 1
  • LLDB调试命令初探 如果你在平时的开发中从未使用过调试器,那你恐怕不知道一个调试器的作用有多大。你可能只满足于通...
    欢博阅读 275评论 0 0