Focal Loss 原理及实践

1 关于Focal Loss

Focal Loss 是一个在交叉熵(CE)基础上改进的损失函数,来自ICCV2017的Best student paper——Focal Loss for Dense Object Detection。论文下载链接为:https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf。Focal Loss的提出源自图像领域中目标检测任务中样本数量不平衡性的问题,并且这里所谓的不平衡性跟平常理解的是有所区别的,它还强调了样本的难易性。尽管Focal Loss 始于目标检测场景,其实它可以应用到很多其他任务场景,只要符合它的问题背景,就可以试试,会有意想不到的效果。

2 Focal Loss 原理

在引入Focal Loss公式前,我们以源paper中目标检测的任务来说:目标检测器通常会产生高达100k的候选目标,只有极少数是正样本,正负样本数量非常不平衡

在计算分类的时候常用的损失——交叉熵(CE)的公式如下:


其中y取值{1,-1}代表正负样本,p为模型预测的label概率,通常p>0.5就判断为正样本,否则为负样本。论文中为了方便展示,重新定义了p_t:

这样CE函数就可以表达为:CE(p, y) = CE(p_t) = − log(p_t).

在CE基础上,为了解决正负样本不平衡性,有人提出一种带权重的CE函数

其中当:y=1, \alpha_t=\alpha ; y=-1, \alpha_t=1-\alpha。 参数\alpha 为控制正负样本的权重,取值范围[0,1]。 尽管这是一种很简单的解决正负样本不平衡的方案,但它还没真正达到paper中作者想解决的问题:因为正负样本中也有难易之分,认为模型应该更聚焦在难样本的学习上。如下图,按正负,难易可将样本分为四个维度,其实上面带权重的CE函数,只是解决了正负问题,并没有解决难易问题

在这里可能有人疑问,怎么来衡量一个样本的难易程度,更何况真实数据也没有这个标记。其实,这里的样本难易是用模型来判断的,就正样本集合来说,如果一个样本预测的p=0.9,一个样本预测的p=0.6,明显前一个样本更容易学习,或者说特征更明显,是易样本。这样也就是说,预测的概率越接近1或0的样本,就越是容易学习的样本,相反,越是集中0.5左右的样本,就是难样本。在sigomid函数上,可以按下图的方式展示样本的难易之分。

既然问题已梳理清楚,怎么让模型对难易样本也有区分性的学习,也是说聚焦程度不同。模型应该花更多精力在难样本的学习上,而减少精力在易样本的学习,之前的CE函数,以及带权重的CE函数,都是将难样本、易样本等同看待的。这样就引出Focal Loss 的表达形式:


其中\gamma为调节因子,取值为[0,5],当\gamma=0,就等同于CE函数;\gamma值越大,表示模型在难易样本上聚焦的更厉害。下图是不同参数下表现形式。

结合上图与公式,可以看出,当p_t趋近1时,权重(1-p_t)^{\gamma}趋近0,对总损失贡献几乎没有影响,意味模型较少对这类样本的学习;比如, 在正样本集合中,\gamma=2,当一样本p_t=0.6, 当一样本p_t=0.7,二者相对来说,前者是难样本,后者是易样本,反映在Focal Losss上,前者的对总损失贡献权重为0.16,后者0.09,明显难样本贡献权重更大,模型也就会更聚焦对其学习。同理,负样本中一样。

但是上面的Focal Loss公式只是体现了难易样本的区分,没有区分正负。这样就引出了完整版的Focal Loss表达形式:


这样Focal Loss既能调整正负样本的权重,又能控制难易分类样本的权重。paper中通过实验验证,默认\gamma=2\alpha_t=0.25(y=1)。在这里\alpha_t取值上可能会有疑问,理论上正样本权重更大些,取0.75,而paper实验结果给的是0.25。这里结合其他人的解释,说下我的理解:主要原因是\gamma=2,而大部分负样本的p<0.1,导致负样本的贡献权重还小于正样本贡献的权重,本意是想调高正样本的贡献权重,但这样就有点调的过大了,所以\alpha_t=0.25(y=1)就有点反过来提高下负样本的权重。所以在最终版中,不能理解\alpha_t就是完全来调节正负样本的权重的,而是要结合\alpha_t(1-p_t)^{\gamma}一起来看。

3 Focal Loss 实践

基于上面的介绍,我们对Focal Loss进行一下实验验证。这里选择MNIST数据集进行实验:只识别数字3,这样将数据集的label转变为[0,1],1代表是数字3,0为其他数字,这样就构建一个不平衡的样本数据集。模型最后一层选择sigmod作为激活函数进行回归预测,然后选择CE与FL两种损失函数,看看训练情况如何。下面为对应的代码。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import tensorflow.keras.backend as K

# load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') /255
x_test = x_test.reshape(10000, 784).astype('float32') /255

y_train=np.array([1 if d==2 else 0 for d in y_train])
y_test=np.array([1 if d==2 else 0 for d in y_test])
#定义focal loss
def focal_loss(gamma=2., alpha=.25):
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon()+pt_1))-K.sum((1-alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
    return focal_loss_fixed
#build model
inputs = keras.Input(shape=(784,), name='mnist_input')
h1 = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(1, activation='sigmoid')(h1)
model = tf.keras.Model(inputs, outputs)
#以平方差损失函数来编译模型进行训练
model.compile(optimizer=keras.optimizers.RMSprop(),
             loss=keras.losses.BinaryCrossentropy(),
             metrics=['accuracy'])
#以Focal Loss损失函数来编译模型进行训练
model.compile(optimizer=keras.optimizers.RMSprop(),
             loss=[focal_loss(alpha=.25, gamma=2)],
             metrics=['accuracy'])
#training
history = model.fit(x_train, y_train, batch_size=64, epochs=5,
         validation_data=(x_test, y_test))

训练结果如下:


在CE下训练结果
在FL下训练结果

从结果可以看出,虽然在该数据集上二者提升效果并不大,但Focal Loss在每轮上都优于CE的训练效果,所以还是能体现Focal Loss的优势,如果在其他更不平衡的数据集上,应该效果更好。不管在CV,还是NLP领域,该损失函数值得大家去尝试。在AAAI2019会议上提出一种基于Focal loss的改进版GHM(Gradient Harmonized Single-stage Detector),有兴趣的也可以去读读。

更多文章可关注笔者公众号:自然语言处理算法与实践

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,546评论 6 507
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,224评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,911评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,737评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,753评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,598评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,338评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,249评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,696评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,888评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,013评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,731评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,348评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,929评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,048评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,203评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,960评论 2 355