BCEWithLogitsLoss参数weight

1. weight:
  • a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

    就是给出weight参数后,会将其shape和input的shape相匹配。回忆公式:
  • 默认情况,也就是weight=None时,上述公式中的Wn=1;当weight!=None时,也就意味着我们需要为每一个样本赋予权重Wi,这样weight的shape和input一致就很好理解了。
    首先看pytorch中weight参数作用后的结果,weight就是为每一个样本加权:
import torch
import torch.nn as nn
input = torch.tensor([[-0.4089,-1.2471,0.5907],
                      [-0.4897,-0.8267,-0.7349],
                      [0.5241,-0.1246,-0.4751]])
m=nn.Sigmoid()
S_input=m(input)

target=torch.FloatTensor([[0,1,1],[0,0,1],[1,0,1]])

w = [0.1, 0.9] # 标签0和标签1的权重
weight = torch.zeros(target.shape)  # 权重矩阵
for i in range(target.shape[0]):
    for j in range(target.shape[1]):
        weight[i][j] = w[int(target[i][j])]
print(weight)

BCEWithLogitsLoss=nn.BCEWithLogitsLoss(weight=weight)
loss = BCEWithLogitsLoss(input,target)
print(loss)
loss = 0.0
for i in range(S_input.shape[0]):
    for j in range(S_input.shape[1]):
        loss += -weight[i][j] * (target[i][j] * torch.log(S_input[i][j]) + (1 - target[i][j]) * torch.log(1 - S_input[i][j]))
print(loss/(S_input.shape[0]*S_input.shape[1])) # 默认取均值

tensor([[0.1000, 0.9000, 0.9000],
        [0.1000, 0.1000, 0.9000],
        [0.9000, 0.1000, 0.9000]])
tensor(0.4711)
tensor(0.4711)
  • pytorch官方的代码和自己实现的计算出的损失一致,再次说明BCEWithLogitsLoss的weight权重会分别对应的作用在每一个样本上。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容