Pytorch中的分类问题损失函数

前言:pytorch中有几个非常容易搞混淆的函数,它们是softmax和log_softmax,CrossEntropyLoss()和NLLLoss(),为了更加彻底的弄清楚,本文将分为两篇文章来说明,第一篇说明原理,第二篇说明用法。

一、二分类损失函数

1.1 从一个简单的实例说起

对于一个二分类问题,比如我们有一个样本,有两个不同的模型对他进行分类,那么它们的输出都应该是一个二维向量,比如:

模型一的输出为:pred_y1=[0.8,0.2]

模型二的输出为:pred_y2=[0.6,0.4]

需要注意的是,这里的数值已经经过了sigmoid激活函数(为什么要这么说,这对于后面理解pytorch的几个函数是有帮助的),所以0.8+0.2=1,

比如样本的真实标签是

true_y=[1,0]

现在我们来求这两个模型对于这一个类别的分类损失,怎么求?先给出二分类损失函数,表达式如下

这里的y表示的是真实地标签,y上面有一个符号的那个y表示模型的输出标签,我们带入进去计算得到:

cost1 = -[1log0.8+(1-1)log0.2] = 0.22314

cost2 = -[1log0.6+(1-1)log0.4] = 0.51083

我们可以看出,第一个模型的损失更小,自然我们觉得第一个模型更好,而且直观来看,第一个模型觉得正确的概率是0.8,自然比0.6要好。

二、多分类损失函数

2.1 关键概念——什么是交叉熵

对于连续性的概率分布而言,首先来看信息论中交叉熵的定义:

交叉熵是用来描述两个分布的距离的,神经网络训练的目的就是使 g(x) 逼近 p(x)。

对于离散型的概率分布而言,

注意:交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的。给定两个概率分布p和q,通过q来表示p的交叉熵为上式,交叉熵刻画的是两个概率分布之间的距离,或可以说它刻画的是通过概率分布q来表达概率分布p的困难程度,p代表正确答案,q代表的是预测值,交叉熵越小,两个概率的分布约接近。

2.2 多分类损失的简单例子

举个例子,假设有一个3分类问题,某个样例的正确答案是(1,0,0),这个模型经过softmax回归之后的预测答案是(0.5,0.4,0.1),

这里的数值已经经过了softmax激活函数(为什么要这么说,这对于后面理解pytorch的几个函数是有帮助的)

那么预测和正确答案之间的交叉熵为:

如果另一个模型的预测是(0.8,0.1,0.1),那么这个预测值和真实值之间的交叉熵是:

显然我们看到第一个模型的损失为0.3,二第二个模型的损失为0.1,第二个模型的损失更小,第二个预测要优于第一个。这里的(1,0,0)就是正确答案p,(0.5,0.4,0.1)和(0.8,0.1,0.1)就是预测值q,显然用(0.8,0.1,0.1)表达(1,0,0)的损失更小一些,准确度更高一些。

下面给一个多分类交叉熵损失函数的一般表达式

总结:不管是二分类,还是多分类问题,其实在计算损失函数的过程都经历了三个步骤:

(1)激活函数。通过激活函数sigmoid或者是softmax将输出值缩放到[0,1]之间,

(2)求对数。计算缩放之后的向量的对数值,即所谓的logy的值,求对数之后的值在[-infinite,0]之间

(3)累加求和。根据损失函数的定义,将标签和输出值逐元素相乘再求和,最后再添加一个负号求相反数,得到一个正数损失。

不管什么样的实现方式,都会经历这三个步骤,不同的是,可能有的函数会将其中的一个或者是几个步骤封装在一起。

三、二分类损失函数的实现

pytorch中二分类损失函数有两种,它们分别是:

torch.nn.BCELoss() 和 torch.nn.BCEWithLogitsLoss()

BCE这三个字母其实就是binary cross entropy的缩写

他们的区别是:

  • (1)BCELoss:需要先将最后一层经过sigmoid进行缩放然后再通过该函数
  • (2)BCEWithLogitsLoss:BCEWithLogitsLoss就是把Sigmoid-BCELoss合成一步,不再需要在最后经过sigmoid进行缩放,直接对最后得到的logits进行处理。
  • 备注:logits,指的是还没有经过sigmoid和softmax缩放的结果哦!

四、softmax、log_softmax、CrossEntropyLoss 、NLLLoss 四个函数的对比
(1)softmax/sigmoid

这个只对应于上面的第一步骤,即相当于是激活函数操作,将输出缩放到[0,1]之间

(2)log_softmax
在softmax的结果上再做多一次log运算,即相当于是一次性完成第一步和第二步。

(3)nll_loss

这个实际上只对应于上面的第三个步骤,

(4)CrossEntropyLoss

CrossEntropyLoss就是把以上Softmax–Log–NLLLoss合并成一步

总结:上面的四种方法只是完成的功能,在具体的写代码的时候,还需要我们可以有两种方式来实现上面的四个功能,即通过函数的形式和通过类的形式,如下:

import torch.nn.functional as F
 
F.xxx          # 通过函数的形式
torch.nn.xxx   # 通过类的形式

总结:通过上面的分析,我们知道了,求多分类交叉熵损失有三种途径可以实现,分别是:

  • (1)三步实现:softmax+log+nll_loss
  • (2)两步实现:log_softmax+nll_loss
  • (3)一步实现:crossEntropyLoss

数据准备:

import numpy as np
import torch
import torch.nn.functional as F
 
 
# 比如这是一个模型的输出,本案例为一个三类别的分类,共有四组样本,如下:
pred_y = np.array([[ 0.30722019 ,-0.8358033 ,-1.24752918],
                   [ 0.72186664 , 0.58657704 ,-0.25026393],
                   [ 0.16449865 ,-0.44255082 , 0.68046693],
                   [-0.52082402 , 1.71407838 ,-1.36618063]])
pred_y=torch.from_numpy(pred_y)
 
# 真实的标签如下所示
true_y = np.array([[ 1 , 0 , 0],
                   [ 0 , 1 , 0],
                   [ 0 , 1 , 0],
                   [ 0 , 0 , 1]])
true_y=torch.from_numpy(true_y)
target = np.argmax(true_y, axis=1) #(4,) #其实就是获得每一给类别的整数值,这个和tensorflow里面不一样哦 内部会自动转换为one-hot形式
target = torch.LongTensor(target)  #(4,)
 
print(target)            # tensor([0,1,1,2])
print("-----------------------------------------------------------")

4.1 三步实现:softmax+log+nll_loss如下:

# 第一步:使用激活函数softmax进行缩放
pred_y = F.softmax(pred_y,dim=1)
print(pred_y)
print('-----------------------------------------------------------')
 
# 第二步:对每一个缩放之后的值求对数log
pred_y=torch.log(pred_y)
print(pred_y)
print('-----------------------------------------------------------')
 
# 第三步:求交叉熵损失
loss=F.nll_loss(pred_y,target)
print(loss)  # 最终的损失为:tensor(1.5929, dtype=torch.float64)

4.2 两步实现:log_softmax+nll_loss

# 第一步:直接使用log_softmax,相当于softmax+log
pred_y=F.log_softmax(pred_y,dim=1)
print(pred_y)
print('-----------------------------------------------------------')
 
# 第二步:求交叉熵损失
loss=F.nll_loss(pred_y,target)
print(loss) # tensor(1.5929, dtype=torch.float64)

注意:使用log_softmax后也可以再使用crossEntropy,效果是等价的,虽然crossEntropy内置了log_softmax。因为重复使用多次log_softmax,结果是不变的,即和使用一次是一样的。不过一般不这么使用。
4.3 一步实现:crossEntropyLoss

# 第一步:求交叉熵损失一步到位
loss=F.cross_entropy(pred_y,target)
print(loss)  tensor(1.5929, dtype=torch.float64)

4.4 总结,在求交叉熵损失的时候,需要注意的是,不管是使用 nll_loss函数,还是直接使用cross_entropy函数,都需要传递一个target参数,这个参数表示的是真实的类别,对应于一个列表的形式而不是一个二维数组,这个和tensorflow是不一样的哦!(Pytorch分类损失函数内部会自动把列表形式(一维数组形式)的整数索引转换为one-hot表示)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,588评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,456评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,146评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,387评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,481评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,510评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,522评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,296评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,745评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,039评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,202评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,901评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,538评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,165评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,415评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,081评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,085评论 2 352

推荐阅读更多精彩内容