【记录】复刻 pytorch nn.CrossEntropyLoss()

def myCrossEntropyLoss(output, label):
    count = label.size(0)

    loss = 0.0
    for x, l in zip(output, label):
        loss += -1 * x[l] + torch.log(torch.exp(x).sum())
    
    return loss/count

output = torch.randn(10, 5, requires_grad = True) #假设是网络的最后一层,5分类
label = torch.empty(10, dtype=torch.long).random_(5) # 0 - 4, 任意选取一个分类
print(output.shape, label.shape)

print()

loss = myCrossEntropyLoss(output, label)
print('my loss = {:.5f}'.format(loss))

nnCrossEntropyLoss = nn.CrossEntropyLoss()
nnCrossEntropyLossWithIngore = nn.CrossEntropyLoss(ignore_index=0)
loss = nnCrossEntropyLoss(output, label)
loss_with_ignore = nnCrossEntropyLossWithIngore(output, label)
print('torch loss = {:.5f}'.format(loss.data))
print('torch loss with ignore = {:.5f}'.format(loss_with_ignore.data))
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容