pytorch中的损失函数

1. 多标签分类损失函数

pytorch中能计算多标签分类任务loss的方法有好几个。
binary_cross_entropy和binary_cross_entropy_with_logits都是来自torch.nn.functional的函数,BCELoss和BCEWithLogitsLoss都来自torch.nn,它们的区别:

函数名 解释
binary_cross_entropy Function that measures the Binary Cross Entropy between the target and the output
binary_cross_entropy_with_logits Function that measures Binary Cross Entropy between target and output logits
BCELoss Function that measures the Binary Cross Entropy between the target and the output
BCEWithLogitsLoss Function that measures Binary Cross Entropy between target and output logits

区别只在于这个logits,损失函数(类)名字中带了with_logits,这里的logits指的是该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间。
nn.functional.xxx是函数接口,而nn.Xxx是nn.functional.xxx的类封装,并且nn.Xxx都继承于一个共同祖先nn.Module。

In [257]: import torch
In [258]: import torch.nn as nn
In [259]: import torch.nn.functional as F

In [260]: true = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
In [261]: pred = torch.rand((2,3))

In [262]: true
Out[262]:
tensor([[1., 0., 1.],
        [1., 0., 0.]])

In [263]: pred
Out[263]:
tensor([[0.0391, 0.7691, 0.1190],
        [0.8846, 0.1628, 0.2641]])

In [264]: F.binary_cross_entropy(torch.sigmoid(pred), true)
Out[264]: tensor(0.7361)

In [265]: F.binary_cross_entropy_with_logits(pred, true)
Out[265]: tensor(0.7361)


In [267]: lf2 = nn.BCELoss()
In [268]: lf2(torch.sigmoid(pred), true)
Out[268]: tensor(0.7361)

In [269]: lf = nn.BCEWithLogitsLoss()
In [270]: lf(pred, true)
Out[270]: tensor(0.7361)

# -(ylog(p)+(1-y)log(1-p))
In [268]: torch.sum(-(true*torch.log(torch.sigmoid(pred))+(1-true)*torch.log(1-torch.sigmoid(pred))))/6  
Out[268]: tensor(0.7361)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容