Keras 自定义loss函数 focal loss + triplet loss

上一节中已经阐述清楚了,keras.Model的输入输出与loss的关系。

一、自定义loss损失函数

https://spaces.ac.cn/archives/4493/comment-page-1#comments
非常简单,其实和官方写的方法一样。比如MSE:

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

model.compile(optimizer=optim, loss=[mean_squared_error])

注意的是,损失函数def mean_squared_error(y_true, y_pred)中的两个参数是固定的,由Keras自动注入。第一个参数来自于model.fit(x=[],y=[])中的y中的第n个,代表的是真实标签。第二个参数来自于推理后model.outputs相应位置的输出。
同时,model.compile()方法中loss传入的是方法体名称,非方法的return。

二、自定义keras损失函数:focal loss

https://github.com/umbertogriffo/focal-loss-keras/blob/master/losses.py

为了传入超参数,使用了python的wrapper模式构建函数,函数实际返回的是内部函数的名称,符合上述定义。

三、自定义keras损失函数:triplet loss

https://stackoverflow.com/questions/53996020/keras-model-with-tf-contrib-losses-metric-learning-triplet-semihard-loss-asserti
https://github.com/rsalesc/TCC/blob/master/scpd/tf/keras/common.py

由于triplet loss的输入比较特殊,是label(非one-hot格式)与嵌入层向量,因此,对应的,我们在keras的数据输入阶段,提供的第二个label就得是非one-hot格式。同时,model构造中得定义嵌入层,并使用L2正则化,且作为model的一个output以方便loss中调用。

实例中,定义模型时,我们分开定义嵌入层的logits与激活函数,以提取出来嵌入层的值。

    conv3 = layers.Conv2D(5, 1)(maxpool)
    embed = Dense(128, activation=None, name="embedding", kernel_regularizer=regularizers.l2(0.01))(conv3)
    dense4 = layers.Activation(activation=keras.activations.relu)(embed)
    norm_x = Lambda(lambda x: K.l2_normalize(x, axis=1))(embed)
    dense5 = layers.Dense(10, activation='softmax')(dense4)
    model = keras.Model(inputs=[a], outputs=[dense5, norm_x])

当然,在输入数据的生成器中,也必须每次:
yield img,[one_hot_label, label]以对应。

之后即可构造自定义的triplet loss func:

def semi_hard(labels, embeddings):
    labels = K.squeeze(labels, axis=1)
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(labels, embeddings, margin=1.0)

最后在compile中调用即可:

model.compile(optimizer=optim, loss=[classify_loss, semi_hard], loss_weights=[1,0.1])

四、tensorflow中的triplet loss

网易云课堂-吴恩达深度学习的triplet loss章节
https://blog.csdn.net/weixin_40400177/article/details/105213578

https://blog.csdn.net/qq_36387683/article/details/83583099
https://zhuanlan.zhihu.com/p/121763855

Easy Triplets 显然不应加入训练,因为它的损失为0,加在loss里面会拉低loss的平均值。Hard Triplets 和 Semi-Hard Triplets 的选择则见仁见智,针对不同的任务需求,可以只选择Semi-Hard Triplets或者Hard Triplets,也可以两者混用。

如图中所示,其实最难分类的是
semi-hard triplets:d(a,p) < d(a,n) < d(a,p) + margin
我们试图找出这样的图片对来加以训练。

可以使用离线学习,每次训练先找到难分类的图片对,然后喂入网络,但是这样很麻烦,且网络结构同样不好设计。因此使用在线挖掘,即每次在一个batch即B个特征向量中,去挖掘出(a,p)和最难分类的(a,n)来计算loss并反向传播。

官方API:

def triplet_semihard_loss(labels, embeddings, margin=1.0):
  """Computes the triplet loss with semi-hard negative mining.

  The loss encourages the positive distances (between a pair of embeddings with
  the same labels) to be smaller than the minimum negative distance among
  which are at least greater than the positive distance plus the margin constant
  (called semi-hard negative) in the mini-batch. If no such negative exists,
  uses the largest negative distance instead.
  See: https://arxiv.org/abs/1503.03832.

  Args:
    labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
      multiclass integer labels.
    embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
      be l2 normalized.
    margin: Float, margin term in the loss definition.

  Returns:
    triplet_loss: tf.float32 scalar.
  """

官方解释的很清楚了,就是想让处于semi-hard区域的最小的d(a,n)尽量去远离>d(a,p)+margin,而由于该(a,n)处于semi-hard区域因此该d(a,n)必须至少>d(a,p)。若找不到这样的(a,n),则表明可能(a,n)比起(a,p)更小,因此使用最大的(a,n)代替。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,039评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,426评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,417评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,868评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,892评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,692评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,416评论 3 419
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,326评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,782评论 1 316
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,957评论 3 337
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,102评论 1 350
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,790评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,442评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,996评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,113评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,332评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,044评论 2 355