1. Multi-label classification loss functions
PyTorch offers several ways to compute the loss for a multi-label classification task.
binary_cross_entropy and binary_cross_entropy_with_logits are both functions from torch.nn.functional, while BCELoss and BCEWithLogitsLoss are both classes from torch.nn. Their differences:
Name | Description |
---|---|
binary_cross_entropy | Function that measures the Binary Cross Entropy between the target and the output |
binary_cross_entropy_with_logits | Function that measures Binary Cross Entropy between target and output logits |
BCELoss | Class (criterion) that measures the Binary Cross Entropy between the target and the output |
BCEWithLogitsLoss | Class (criterion) that measures Binary Cross Entropy between target and output logits, combining a Sigmoid layer and BCELoss in a single class |
The only difference is the logits: when a loss's name carries with_logits, it applies the sigmoid internally and therefore expects the raw outputs (logits) of the preceding network as input. There is no need to map those outputs into [0, 1] with sigmoid/softmax yourself before passing them to the loss function.
nn.functional.xxx are the function interfaces, while nn.Xxx are class wrappers around nn.functional.xxx, and every nn.Xxx inherits from the common ancestor nn.Module.
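One practical consequence of the class-vs-function split is that the class form can hold configuration (e.g. pos_weight for imbalanced labels, reduction) as state at construction time, while the functional form takes the same arguments on every call. A minimal sketch, where the toy tensors and the pos_weight value of 3.0 are made-up values for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pred = torch.tensor([[2.0, -1.0], [0.5, 0.3]])   # raw logits, shape (2 samples, 2 labels)
true = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

# Class form: configuration is stored once at construction...
pos_weight = torch.tensor([1.0, 3.0])  # up-weight positives of the 2nd label
loss_cls = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
out_cls = loss_cls(pred, true)

# ...functional form: the same configuration is passed on every call.
out_fn = F.binary_cross_entropy_with_logits(pred, true, pos_weight=pos_weight)

print(out_cls)  # identical to out_fn
```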
In [257]: import torch
In [258]: import torch.nn as nn
In [259]: import torch.nn.functional as F
In [260]: true = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
In [261]: pred = torch.rand((2,3))
In [262]: true
Out[262]:
tensor([[1., 0., 1.],
        [1., 0., 0.]])
In [263]: pred
Out[263]:
tensor([[0.0391, 0.7691, 0.1190],
        [0.8846, 0.1628, 0.2641]])
In [264]: F.binary_cross_entropy(torch.sigmoid(pred), true)
Out[264]: tensor(0.7361)
In [265]: F.binary_cross_entropy_with_logits(pred, true)
Out[265]: tensor(0.7361)
In [267]: lf2 = nn.BCELoss()
In [268]: lf2(torch.sigmoid(pred), true)
Out[268]: tensor(0.7361)
In [269]: lf = nn.BCEWithLogitsLoss()
In [270]: lf(pred, true)
Out[270]: tensor(0.7361)
# manual check of the formula -(y*log(p) + (1-y)*log(1-p)), averaged over all 2x3 = 6 elements
In [268]: torch.sum(-(true*torch.log(torch.sigmoid(pred))+(1-true)*torch.log(1-torch.sigmoid(pred))))/6
Out[268]: tensor(0.7361)
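Beyond convenience, the with_logits variant is also more numerically stable: it fuses the sigmoid and the log via the log-sum-exp trick, whereas the manual formula above can overflow to inf once a logit is extreme. A small sketch (the logit value 100 is an artificial extreme for illustration):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([100.0])  # an extreme raw logit
y = torch.tensor([0.0])

# Naive formula: sigmoid(100) rounds to exactly 1.0 in float32,
# so log(1 - sigmoid(x)) becomes log(0) = -inf and the loss blows up.
naive = -(y * torch.log(torch.sigmoid(x)) + (1 - y) * torch.log(1 - torch.sigmoid(x)))
print(naive)  # tensor([inf])

# The fused version works on the logit directly and stays finite.
stable = F.binary_cross_entropy_with_logits(x, y)
print(stable)  # tensor(100.)
```

This is why, in practice, it is preferable to let the network output raw logits and use BCEWithLogitsLoss, rather than ending the network with a sigmoid and using BCELoss.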