结合源码分析YOLOv3的训练过程(二)

在上篇博客中https://www.jianshu.com/p/2f89e74b9b3c
介绍了YOLOv3的网络模型及前向传播过程,知道了网络在不同层的feature map进行预测以获得对大、中、小型目标的检测。对于一幅图片最终会预产生10647个anchor box,得到维度为[1,10647,85]的输出。但是并没有说明YOLOv3的损失函数及训练过程。在论文中也没有给出损失函数的公式,需要通过源码去分析损失函数及训练过程。下图为根据源码得到的损失函数。

YOLOv3的损失函数

由于上篇博客中给出的代码没有包含训练的部分,所以给出YOLOv3含训练的pytorch实现:https://github.com/eriklindernoren/PyTorch-YOLOv3

YOLOLayer检测层

YOLOLayer负责在13x13,26x26,52x52这三层feature map上进行预测,可以看到其forward前向传播的代码。若为测试集则直接返回预测,否则返回该层的预测和损失值。其中重要的方法为build_targets将在下面讲解。

        if targets is None:  # 测试集则直接返回预测
            return output, 0
        else:  #训练集返回预测和损失
            iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
                pred_boxes=pred_boxes,
                pred_cls=pred_cls,
                target=targets,
                anchors=self.scaled_anchors,
                ignore_thres=self.ignore_thres,
            )

            # Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
            loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
            loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
            loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
            loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
            loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
            loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])
            loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
            loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])
            total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls
            return output, total_loss

以在13x13的feature map上计算为例。预测bounding box的pred_boxes维度为[1,3,13,13,4],预测类别的pred_cls维度为[1,3,13,13,80],假设图片中有两个检测目标,target维度为[2,6],这6维分别表示[img_index,cls,x1,y1,x2,y2],第一个数表示读入图片的编号,anchors为该层anchor box的大小,由于每层有3个不同比例大小的anchor box,所以anchors维度为[3,2],ignore_thres为阈值,设定为0.5。

区分是否含有object的anchor box

def build_targets(pred_boxes, pred_cls, target, anchors, ignore_thres):
    ByteTensor = torch.cuda.ByteTensor if pred_boxes.is_cuda else torch.ByteTensor
    FloatTensor = torch.cuda.FloatTensor if pred_boxes.is_cuda else torch.FloatTensor

    nB = pred_boxes.size(0)  # 样本数量
    nA = pred_boxes.size(1)  # 通道数
    nC = pred_cls.size(-1)  # 预测类别数量
    nG = pred_boxes.size(2)  # 单元格尺寸(13x13,26x26,52x52)

    # Output tensors
    obj_mask = ByteTensor(nB, nA, nG, nG).fill_(0)
    noobj_mask = ByteTensor(nB, nA, nG, nG).fill_(1)
    class_mask = FloatTensor(nB, nA, nG, nG).fill_(0)
    iou_scores = FloatTensor(nB, nA, nG, nG).fill_(0)
    tx = FloatTensor(nB, nA, nG, nG).fill_(0)
    ty = FloatTensor(nB, nA, nG, nG).fill_(0)
    tw = FloatTensor(nB, nA, nG, nG).fill_(0)
    th = FloatTensor(nB, nA, nG, nG).fill_(0)
    tcls = FloatTensor(nB, nA, nG, nG, nC).fill_(0)

    # Convert to position relative to box
    target_boxes = target[:, 2:6] * nG
    gxy = target_boxes[:, :2]
    gwh = target_boxes[:, 2:]
    # Get anchors with best iou
    ious = torch.stack([bbox_wh_iou(anchor, gwh) for anchor in anchors])
    best_ious, best_n = ious.max(0)
    # Separate target values
    b, target_labels = target[:, :2].long().t()
    gx, gy = gxy.t()
    gw, gh = gwh.t()
    gi, gj = gxy.long().t()
    # Set masks
    obj_mask[b, best_n, gj, gi] = 1
    noobj_mask[b, best_n, gj, gi] = 0

    # Set noobj mask to zero where iou exceeds ignore threshold
    for i, anchor_ious in enumerate(ious.t()):
        noobj_mask[b[i], anchor_ious > ignore_thres, gj[i], gi[i]] = 0

    # Coordinates
    tx[b, best_n, gj, gi] = gx - gx.floor()
    ty[b, best_n, gj, gi] = gy - gy.floor()
    # Width and height
    tw[b, best_n, gj, gi] = torch.log(gw / anchors[best_n][:, 0] + 1e-16)
    th[b, best_n, gj, gi] = torch.log(gh / anchors[best_n][:, 1] + 1e-16)
    # One-hot encoding of label
    tcls[b, best_n, gj, gi, target_labels] = 1
    # Compute label correctness and iou at best anchor
    class_mask[b, best_n, gj, gi] = (pred_cls[b, best_n, gj, gi].argmax(-1) == target_labels).float()
    iou_scores[b, best_n, gj, gi] = bbox_iou(pred_boxes[b, best_n, gj, gi], target_boxes, x1y1x2y2=False)

    tconf = obj_mask.float()
    return iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf

一、将真实框target的大小映射到13x13的feature map上,获得对应的坐标和大小(Convert to position relative to box)。


target在feature map上的位置和大小

二、由于对每层featre map都有3个不同比例大小的anchor box,所以我们要选择与target形状最接近的anchor box,评判指标就是IOU值。下图中最接近的target就是anchor 2了(Get anchors with best iou)。


选择最佳的anchor box

三、找到包含target的单元格,因为target的中心点坐标xy通常不是整数,所以通过取整找到对应的单元格,例如[6.2502, 6.4846],对应的单元格为[6,6],且每个单元格有3个anchor box,将第二步中找到的最佳anchor box设置为obj_mask,表示其含有目标值。将下图中的obj_mask[0,1,6,6]设置为1(Set masks)。
找到对应的单元格并找到最佳的anchor box

四、对应单元格最佳anchor box设置坐标和宽高的训练目标值。对于坐标值的计算,代码中直接将下图中bx-cx作为tx的回归目标值。( Coordinates,Width and height)。


转化公式

五、对应单元格最佳anchor box的label进行one-hot编码,将其cls类别标为1。
六、计算该位置预测的类别pred_cls是否与真实类别target_labels相同,以及预测框pred_box与真实框target_boxes的IOU值。(Compute label correctness and iou at best anchor),该结果主要用于之后的模型评估阶段,与loss计算没有关系。

计算总损失

在13x13大小的feature map上执行完build_targets方法,我们可以从13x13x3个anchor box中找出包含object且与target box最匹配(两者IOU值最高)的anchor box,其余的anchor box为no object。
对于含有object的anchor box计算其与target的坐标xy,宽高kw的损失值,损失函数为均方误差。

loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
loss_h = self.mse_loss(h[obj_mask], th[obj_mask])

使用交叉熵函数计算其与target的含有目标的置信度的损失值。对于真实情况tconf只有0,1两种值,因为真实情况只有存在目标和不存在目标两种情况,而pred_conf是一个存在目标的概率值。

loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])

同样使用交叉熵函数计算与target各个类别置信度的损失值。对于target来说,其class为one-hot编码,因为真实情况下,object只属于80种类别中的一种。

loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])

对于不含object的anchor box只需计算与target的含有目标的置信度的损失值。

loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])

对于目标置信度的损失,含object的anchor box与不含object的anchor box所占权重不同。代码中将object_scale设为1,noobject_scale设为100。通过不断下降loss,使那些不含object而认为自己含有object的anchor(表现为pred_conf预测值很大)数量大大减少。

loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj

单层featrue map的总损失为将上述各损失值相加。

total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

最终网络的总损失是将各层的feature map相加得到loss,最后通过优化函数不断减少loss来训练网络,得到最终的模型参数。
参考博客:https://blog.csdn.net/qq_34795071/article/details/92803741

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,496评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,407评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,632评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,180评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,198评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,165评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,052评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,910评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,324评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,542评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,711评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,424评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,017评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,668评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,823评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,722评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,611评论 2 353