Verifying tf.keras.losses.BinaryCrossentropy() and CategoricalCrossentropy

How does tf.keras.losses.BinaryCrossentropy() compute the loss internally? Worked examples
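As a summary of what the examples below compute (my reading, assuming the defaults `from_logits=False` and the default reduction): for a batch of $N$ samples with $D$ outputs each, every element contributes one binary cross-entropy term, `tf.keras.losses.binary_crossentropy` averages these over the last axis to give one loss per sample, and `BinaryCrossentropy()` then averages over the batch. Here $y_{ij}$ is `y_true[i][j]` and $p_{ij}$ is `y_pred[i][j]`; Keras additionally clips $p_{ij}$ away from 0 and 1 by a small epsilon.

$$\ell_{ij} = -\left[\,y_{ij}\log p_{ij} + (1-y_{ij})\log(1-p_{ij})\,\right]$$

$$L = \frac{1}{N}\sum_{i=1}^{N}\left(\frac{1}{D}\sum_{j=1}^{D}\ell_{ij}\right)$$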

Example 1:

import tensorflow as tf
import math
y_true = [[0., 1.]]
y_pred = [[0.8, 0.2]]
# Using 'auto'/'sum_over_batch_size' reduction type.
bce = tf.keras.losses.BinaryCrossentropy()
bce(y_true, y_pred).numpy()

Output: 1.6094375

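Hand calculation: each element contributes -[y*log(p) + (1-y)*log(1-p)], and the loss is the mean over all elements: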

# element 1: y = 0, p = 0.8
a = -0*math.log(0.8)-(1-0)*math.log(1-0.8)
# element 2: y = 1, p = 0.2
b = -1*math.log(0.2)-(1-1)*math.log(1-0.2)
print((a+b)/2)  # mean over the 2 elements

Output: 1.6094379124341005
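The last digits differ (1.6094375 vs 1.6094379). A likely explanation is that Keras computes in float32 and, for numerical stability, clips predictions to [eps, 1 - eps] and, in the tf.keras 2.x backend as far as I know, also adds eps = 1e-7 inside each log. The pure-Python sketch below mimics those steps in float64 under that assumption; `keras_like_bce` is a hypothetical helper, not a Keras API, and newer Keras versions may differ.

```python
import math

EPS = 1e-7  # assumed value of tf.keras.backend.epsilon()

def keras_like_bce(y_true, y_pred, eps=EPS):
    """Hypothetical helper: mean BCE over all elements, mimicking tf.keras 2.x."""
    sample_losses = []
    for t_row, p_row in zip(y_true, y_pred):
        row = []
        for t, p in zip(t_row, p_row):
            p = min(max(p, eps), 1.0 - eps)  # clip to [eps, 1 - eps]
            # assumption: eps is also added inside each log
            bce = t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
            row.append(-bce)
        sample_losses.append(sum(row) / len(row))  # mean over the last axis
    return sum(sample_losses) / len(sample_losses)  # 'sum_over_batch_size'

exact = (-math.log(1 - 0.8) - math.log(0.2)) / 2  # the hand calculation above
approx = keras_like_bce([[0., 1.]], [[0.8, 0.2]])
print(exact, approx)  # approx is about 5e-7 smaller than exact
```

The gap of roughly 5e-7 comes from log(0.2 + 1e-7) - log(0.2) = log(1 + 5e-7), which is why the Keras output lands closer to 1.6094375 than to the exact value.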

Example 2:

y_true = [[0., 1.], [0., 0.]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
# Using 'auto'/'sum_over_batch_size' reduction type.
bce = tf.keras.losses.BinaryCrossentropy()
bce(y_true, y_pred).numpy() 

Output: 0.81492424

y_true = [[0., 1.], [0., 0.]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
assert loss.shape == (2,)
loss.numpy().mean()

Output: 0.81492424
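The two snippets differ only in when the batch mean is taken: `tf.keras.losses.binary_crossentropy` averages over the last axis and returns one loss per sample (shape (2,)), while the `BinaryCrossentropy` object goes on to average those over the batch. A pure-Python sketch of the per-sample step, using a hypothetical helper `per_sample_bce` without Keras's clipping, so the last digits can differ slightly from TensorFlow's float32 output:

```python
import math

def per_sample_bce(y_true, y_pred):
    """Hypothetical helper: BCE averaged over the last axis, one value per sample."""
    losses = []
    for t_row, p_row in zip(y_true, y_pred):
        terms = [-(t * math.log(p) + (1 - t) * math.log(1 - p))
                 for t, p in zip(t_row, p_row)]
        losses.append(sum(terms) / len(terms))  # mean over the last axis
    return losses

y_true = [[0., 1.], [0., 0.]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
per_sample = per_sample_bce(y_true, y_pred)     # two values, like loss.numpy()
batch_mean = sum(per_sample) / len(per_sample)  # like loss.numpy().mean()
print(per_sample, batch_mean)
```

Because every sample has the same number of elements, the mean of the per-sample means equals the mean over all four elements used in the hand calculation.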

Hand calculation (same formula, averaged over all four elements):

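a = -0*math.log(0.6)-(1-0)*math.log(1-0.6)  # sample 1, element 1: y = 0, p = 0.6
b = -1*math.log(0.4)-(1-1)*math.log(1-0.4)  # sample 1, element 2: y = 1, p = 0.4
c = -0*math.log(0.4)-(1-0)*math.log(1-0.4)  # sample 2, element 1: y = 0, p = 0.4
d = -0*math.log(0.6)-(1-0)*math.log(1-0.6)  # sample 2, element 2: y = 0, p = 0.6
print(a,b,c,d)
print((a+b+c+d)/4)  # mean over all 4 elements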

Output: 0.814924454847114

How does tf.keras.losses.categorical_crossentropy() compute the loss internally? Worked examples
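The corresponding formula, again assuming `from_logits=False` and the default reduction: for sample $i$ with $C$ classes the loss sums over the classes, and `CategoricalCrossentropy()` averages over the $N$ samples. With a one-hot label whose true class is $k_i$, only one term of the sum survives. For Example 1 below this gives $L = \tfrac{1}{2}\left(-\log 0.95 - \log 0.1\right) \approx 1.1769$.

$$\ell_i = -\sum_{c=1}^{C} y_{ic}\log p_{ic} = -\log p_{i k_i}$$

$$L = \frac{1}{N}\sum_{i=1}^{N}\ell_i$$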


Example 1:

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.01, 0.95, 0.04], [0.1, 0.8, 0.1]]
# Using 'auto'/'sum_over_batch_size' reduction type.
cce = tf.keras.losses.CategoricalCrossentropy()
cce(y_true, y_pred).numpy()

Output: 1.1769392

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
assert loss.shape == (2,)
loss.numpy()

Output: array([0.05129331, 2.3025851 ], dtype=float32)

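CategoricalCrossentropy and categorical_crossentropy differ slightly in usage and output: the CategoricalCrossentropy object returns a single value averaged over the batch, whereas the categorical_crossentropy function returns one loss per sample. The computation is the same, so averaging the per-sample losses reproduces the object's result, as the next snippet shows. Note that the first row of y_pred here, [0.05, 0.95, 0], differs from the [0.01, 0.95, 0.04] used in Example 1; because the labels are one-hot, only the probability of the true class (0.95) affects the loss, so the results still match.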

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
assert loss.shape == (2,)
loss.numpy().mean()

Output: 1.1769392
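With one-hot labels every term with y_c = 0 is zero, so each sample's loss reduces to -log of the probability assigned to its true class. A small plain-Python check with the two first rows used in this section; since math.log(0) raises ValueError, this sketch (hypothetical helper `one_hot_cce`) skips zero-label terms under the convention 0*log(0) = 0, whereas Keras clips predictions to a small epsilon instead.

```python
import math

def one_hot_cce(y_true_row, y_pred_row):
    """Hypothetical helper: -sum(y * log p) for one sample, skipping y == 0 terms."""
    return -sum(t * math.log(p) for t, p in zip(y_true_row, y_pred_row) if t != 0)

row_a = one_hot_cce([0, 1, 0], [0.01, 0.95, 0.04])
row_b = one_hot_cce([0, 1, 0], [0.05, 0.95, 0])  # the 0 is never passed to log
print(row_a, row_b, -math.log(0.95))
```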

Hand calculation: for each sample, sum -y*log(p) over the classes, then average over the two samples:

import math 

# sample 1: true class is index 1
a1 = -0*math.log(0.01)
b1 = -1*math.log(0.95)
c1 = -0*math.log(0.04)

# sample 2: true class is index 2
a2 = -0*math.log(0.1)
b2 = -0*math.log(0.8)
c2 = -1*math.log(0.1)

print((a1+b1+c1+a2+b2+c2)/2)  # mean over the 2 samples

Output: 1.176939193690798
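The hand calculation leaves out two steps that, as far as I know, tf.keras 2.x applies when from_logits=False: each row of y_pred is first rescaled to sum to 1, then clipped to [eps, 1 - eps] before the log. A pure-Python sketch of that pipeline under this assumption (hypothetical helper `keras_like_cce`, eps assumed to be 1e-7); the rescaling means an unnormalized row such as [0.1, 1.9, 0] is treated like [0.05, 0.95, 0].

```python
import math

EPS = 1e-7  # assumed value of tf.keras.backend.epsilon()

def keras_like_cce(y_true, y_pred, eps=EPS):
    """Hypothetical helper: mean categorical cross-entropy over the batch."""
    losses = []
    for t_row, p_row in zip(y_true, y_pred):
        total = sum(p_row)
        probs = [p / total for p in p_row]                    # rescale row to sum to 1
        probs = [min(max(p, eps), 1.0 - eps) for p in probs]  # clip, so log(0) never occurs
        losses.append(-sum(t * math.log(p) for t, p in zip(t_row, probs)))
    return sum(losses) / len(losses)  # 'sum_over_batch_size'

y_true = [[0, 1, 0], [0, 0, 1]]
loss_norm = keras_like_cce(y_true, [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
loss_unnorm = keras_like_cce(y_true, [[0.1, 1.9, 0], [0.1, 0.8, 0.1]])  # row 1 sums to 2
print(loss_norm, loss_unnorm)
```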