Date: 2020/05/19
Author: CW
前言:
EAST 的损失函数由三部分构成,对应预测输出的三个map——score map、loc map 以及 angle map,即分类损失,位置(点到文本框边界上下左右的距离)损失以及角度损失。
分类损失
对于分类损失,最自然想到的就是交叉熵。在这里,由于在文本框外的点会占多数,即负样本比例较大,因此,可采用类别平衡的交叉熵损失。
在代码实现中,以上有个特别需要注意的地方,就是模型输出的预测结果 pred_score 是从sigmoid 出来的,那么其取值范围为[0, 1],所以用在交叉熵时有可能使得 log 函数输出正/负无穷,因此需要加上一个极小值,这里使用的是 np.finfo(np.float32).eps。
另外,由于这里的分类是对一个个像素点进行区分(是否在文本框内),那么就可看作是语义分割问题,因此,使用 Dice Loss 作为分类损失也是一种可行的方案。
对于以上两种 loss 的选择,根据我的训练结果来看,使用类别平衡交叉熵计算得到的loss值相对较小,可能需要调整合适的loss权重才能更好地让模型学会分类,否则这种方案下训练出的模型容易出现大量误检(召回率还不错但准确率低);而使用 dice loss 的话,loss 值通常在零点几的数量级,不需要加大权重,模型也能比较容易学好分类,训练出来的模型误检率较低。也可以将两种 loss 结合在一起,其中类别平衡的交叉熵损失权重要相对大一些。
几何损失
顾名思义,这部分损失指的是预测框与真实文本框之间在几何层面上计算的损失,通过 d1~d4以及 angle 来计算。
这里比较有意思的是,对于(i=1,2,3,4) ,并不是分别计算预测的与 对应标签的的差来作为损失,而是根据计算出框的面积,然后将预测框与 gt 框之间的 IoU 用于损失计算,IoU 越大说明和 gt 越接近,因此 loss 应该越小,同时由于 IoU 取值范围在 [0, 1],因此可将其输入 log 函数并乘以-1作为 loss。
而对于 angle,使用余弦函数,余弦函数的输入为预测 angle 与 对应标签的 angle 之差。使用余弦函数的好处是,它是偶函数,无需对角度差值取绝对值。这样的话,两个角度相差越小,余弦函数的输出则越大,因此用1减去余弦函数的输出便可作为这部分的loss。
这里需要注意的是交集的计算,与通常计算两个 bbox 的交集稍有不同,这里是根据d来计算的,要取预测与 gt 对应 d 的最小值才是交集,而不像通常的两个 bbox 在计算交集的 x_min 与y_min 时是分别取两者的最大值。
综合损失
最终模型的损失综合了分类与几何损失,可以根据实际情况分别对分类 loss、IoU loss、angle loss 设置不同的权重,最后加起来作为总的损失。
通常在代码中实现模型的损失计算时,都会将其实现为一个 torch.nn.Module 的子类,损失计算则通过重载 forward 方法也即前向反馈过程来实现。
在这里我们需要考虑一种情况,就是一个 batch 中可能并没有 gt,那么此时就不计算损失,直接返回损失为0。
另外注意上图红框部分,这里将几何 loss 与 gt_score 相乘,由于gt_score 中非0即1,因此说明这里仅对正样本计算几何损失。
最后:
越发地觉得,gt 的生成与 loss 的设计往往是很有技巧性的,它们直接影响到模型会学习成什么样子,gt 生成让模型了解到学习的目标,而 loss 设计则是将模型学习的目标转化为在数学上的表达形式,使得模型有途径通过迭代学习不断逼近目标。
在阅读与手写了众多算法模型代码后,吾以为,细心观察生活很重要,只有你足够了解生活,才能从其中的需求中出发,然后基于生活中某些事物的工作方式,抽象出一套方法论,接着用代码去实践,最终通过实验验证,这样之后才有可能创造出一个好的模型算法。