本篇主要是总结一下我们常用的计算loss的方法和使用技巧。
1、tf.nn.sigmoid_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits(
_sentinel=None,
labels=None,
logits=None,
name=None
)
说明:labels和logits必须有相同的type和shape,该方法可以用于多目标问题,如判断一张图片中是否包含人、狗、树等,即对应的label包含多个1。但是output不是一个数,而是一个batch中每个样本的loss,所以一般配合tf.reduce_mean(loss)使用。
#coding=utf8
import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6]],tf.float32)
y3 = tf.constant([[1,0,0],[0,1,1]],tf.float32)
loss3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y3,logits=a))
with tf.Session() as sess:
print('loss3',sess.run(loss3))
2、tf.nn.softmax_cross_entropy_with_logits
softmax_cross_entropy_with_logits(
_sentinel=None,
labels=None,
logits=None,
dim=-1,
name=None
)
说明:labels和logits有相同的shape,适用于单目标问题,如判断一张图片是猫、狗还是人,即label中只有一个位置对应的是1,其余全为0。
#coding=utf8
import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6]],tf.float32)
y1 = tf.constant([[0,0,1],[0,1,0]],tf.int32)
loss1 = tf.nn.softmax_cross_entropy_with_logits(labels=y1,logits=a)
with tf.Session() as sess:
print('loss1',sess.run(loss1))
3、tf.nn.sparse_softmax_cross_entropy_with_logits
sparse_softmax_cross_entropy_with_logits(
_sentinel=None,
labels=None,
logits=None,
name=None
)
说明:labels的shape为[batch_size],且labels的值不能超过n_classes,logits的shape为[batch_size,n_classes],它和tf.nn.softmax_cross_entropy_with_logits的实现原理基本一样,只是后者在额外的做了one_hot编码,前者把这一过程写在了内部,有更优的方式。
#coding=utf8
import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6]],tf.float32)
y2 = tf.constant([0,1],tf.int32)
loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y2,logits=a)
with tf.Session() as sess:
print('loss2',sess.run(loss2))
4、tf.nn.weighted_cross_entropy_with_logits
weighted_cross_entropy_with_logits(
targets,
logits,
pos_weight,
name=None
)
说明:这个loss我还没有用过,是sigmoid_cross_entropy_with_logits的拓展版,输入参数和实现和后者差不多,可以多支持一个pos_weight参数,目的是可以增加或者减小正样本在算Cross Entropy时的Loss。
除上面的一些计算loss的方法外,还有一些专门处理loss的方法,自己还没有这么用过,具体功能和用法后面再补起来,先列出来。
tf.losses.add_loss ##添加一个计算出来的loss到总的loss集合中
tf.losses.get_regularization_loss ##计算正则化loss
tf.losses.get_total_loss ##计算所有的loss,以包含上面的regular损失
参考文章:
https://weibo.com/ttarticle/p/show?id=2309404047468714166594
http://blog.csdn.net/QW_sunny/article/details/72885403