1. weight:
- PyTorch documentation: "a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch."
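The documentation's "one weight per batch element" use case can be sketched as follows. The logits, targets and the per-sample weights `[1.0, 0.5, 2.0]` are hypothetical values chosen for illustration. The weight is shaped `(nbatch, 1)` so that it broadcasts across the label dimension. A 1-D tensor of size nbatch would instead be aligned with the last (label) dimension under standard broadcasting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical data: 3 samples (batch elements) with 3 labels each
torch.manual_seed(0)
logits = torch.randn(3, 3)
target = torch.randint(0, 2, (3, 3)).float()

# One weight per sample, shaped (nbatch, 1) so it broadcasts over the labels
sample_weight = torch.tensor([[1.0], [0.5], [2.0]])

official = nn.BCEWithLogitsLoss(weight=sample_weight)(logits, target)

# Manual check: unweighted mean loss of each sample, scaled by that sample's weight.
# Every row has the same number of elements, so averaging the scaled row means
# equals averaging all weighted elements (the default reduction='mean').
per_sample = torch.stack(
    [F.binary_cross_entropy_with_logits(logits[i], target[i]) for i in range(3)]
)
manual = (sample_weight.squeeze(1) * per_sample).mean()

print(official, manual)
```

Both printed values should agree, which shows that each row's loss is simply rescaled by its own weight.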
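In other words, once weight is given, it is lined up against the shape of input, and each of its entries scales the loss of the corresponding element. Recall the formula from the PyTorch documentation. With the default reduction='mean', the final loss is the mean of the l_n:
$$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \qquad l_n = -w_n\left[y_n \log \sigma(x_n) + (1 - y_n)\log\bigl(1 - \sigma(x_n)\bigr)\right]$$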
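- By default, i.e. when weight=None, w_n = 1 in the formula above. When weight is not None, every element is assigned its own weight w_n, so it is natural for weight to have the same shape as input. Strictly speaking, it only needs to be broadcastable to the shape of input.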
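The formula can be written directly in vectorized form. The sketch below assumes randomly generated, hypothetical `x`, `y` and `w` of the same shape. It uses the identity 1 - σ(x) = σ(-x), so log(1 - σ(x)) becomes `F.logsigmoid(-x)`, which avoids taking the log of a number close to zero. The result is then compared with `F.binary_cross_entropy_with_logits`.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits x, binary targets y and per-element weights w (same shape)
torch.manual_seed(0)
x = torch.randn(4, 3)
y = torch.randint(0, 2, (4, 3)).float()
w = torch.rand(4, 3)

# l_n = -w_n [ y_n log sigma(x_n) + (1 - y_n) log(1 - sigma(x_n)) ]
# log(1 - sigma(x)) is computed as logsigmoid(-x), the numerically stable form
l = -w * (y * F.logsigmoid(x) + (1 - y) * F.logsigmoid(-x))
manual = l.mean()  # matches the default reduction='mean'

official = F.binary_cross_entropy_with_logits(x, y, weight=w)
print(manual, official)
```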
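First, let's look at what the weight parameter does in PyTorch. In this example, weight assigns a weight to every element of input: 0.1 where the label is 0 and 0.9 where the label is 1.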
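```python
import torch
import torch.nn as nn

# Raw logits (3 samples x 3 labels); BCEWithLogitsLoss applies the sigmoid itself
input = torch.tensor([[-0.4089, -1.2471, 0.5907],
                      [-0.4897, -0.8267, -0.7349],
                      [0.5241, -0.1246, -0.4751]])
m = nn.Sigmoid()
S_input = m(input)  # probabilities, used only by the manual computation below
target = torch.tensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]], dtype=torch.float32)

w = [0.1, 0.9]  # weights for label 0 and label 1
weight = torch.zeros(target.shape)  # weight matrix, same shape as target
for i in range(target.shape[0]):
    for j in range(target.shape[1]):
        weight[i][j] = w[int(target[i][j])]
print(weight)

# Built-in implementation (reduction='mean' by default)
criterion = nn.BCEWithLogitsLoss(weight=weight)
loss = criterion(input, target)
print(loss)

# Manual implementation: weighted binary cross-entropy for each element, summed
loss = 0.0
for i in range(S_input.shape[0]):
    for j in range(S_input.shape[1]):
        loss += -weight[i][j] * (target[i][j] * torch.log(S_input[i][j])
                                 + (1 - target[i][j]) * torch.log(1 - S_input[i][j]))
print(loss / (S_input.shape[0] * S_input.shape[1]))  # divide by element count: default mean
```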
Output:
```
tensor([[0.1000, 0.9000, 0.9000],
        [0.1000, 0.1000, 0.9000],
        [0.9000, 0.1000, 0.9000]])
tensor(0.4711)
tensor(0.4711)
```
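The double loop that builds the weight matrix can also be replaced by a single indexing operation. The sketch below is one possible idiomatic alternative, not the original author's code. The name `w_table` is a hypothetical lookup tensor holding the weights for label 0 and label 1. The integer-cast targets then select the matching entry for every element. It reuses the same logits and targets as above, so it should print the same weight matrix and loss.

```python
import torch
import torch.nn as nn

# Same logits and targets as in the example above
logits = torch.tensor([[-0.4089, -1.2471, 0.5907],
                       [-0.4897, -0.8267, -0.7349],
                       [0.5241, -0.1246, -0.4751]])
target = torch.tensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]], dtype=torch.float32)

# Lookup table: w_table[0] is the weight for label 0, w_table[1] for label 1
w_table = torch.tensor([0.1, 0.9])
# Indexing with a (3, 3) LongTensor yields a (3, 3) weight matrix
weight = w_table[target.long()]
print(weight)

loss = nn.BCEWithLogitsLoss(weight=weight)(logits, target)
print(loss)
```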
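- The loss from PyTorch's built-in BCEWithLogitsLoss matches our own implementation, which again confirms that each entry of weight is applied to the loss of the corresponding element.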
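The element-wise behaviour can also be checked without any manual formula by switching to `reduction="none"`, which keeps one loss value per element. In the sketch below, the random logits, targets and weights are hypothetical. The weighted per-element loss should equal the unweighted one multiplied by the matching weight, and its mean should recover the default `reduction='mean'` result.

```python
import torch
import torch.nn as nn

# Hypothetical logits, targets and per-element weights
torch.manual_seed(0)
logits = torch.randn(3, 3)
target = torch.randint(0, 2, (3, 3)).float()
weight = torch.rand(3, 3)

# reduction="none" returns a (3, 3) loss matrix instead of a scalar
weighted = nn.BCEWithLogitsLoss(weight=weight, reduction="none")(logits, target)
unweighted = nn.BCEWithLogitsLoss(reduction="none")(logits, target)

# Each weighted entry is the unweighted entry scaled by its own weight
print(torch.allclose(weighted, weight * unweighted))
# Averaging the weighted matrix gives the default reduction='mean' loss
print(torch.allclose(weighted.mean(), nn.BCEWithLogitsLoss(weight=weight)(logits, target)))
```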