基类定义
pytorch损失类也是模块的派生,损失类的基类是_Loss,定义如下
class _Loss(Module):
def __init__(self, size_average=None, reduce=None, reduction='elementwise_mean'):
super(_Loss, self).__init__()
if size_average is not None or reduce is not None:
self.reduction = _Reduction.legacy_get_string(size_average, reduce)
else:
self.reduction = reduction
看这个类,有两点我们知道:
- 损失类是模块
- 不改变forward函数,但是具备执行功能
还有其他模块的性质
子类介绍
从_Loss派生的类有
| 名称 | 说明 | 公式 |
|---|---|---|
| _WeightedLoss | ||
| L1Loss | X与Y的 shape相同 | ![]() |
| PoissonNLLLoss | 适合多目标分类![]() |
![]() |
| KLDivLoss | 适用于连续分布的距离计算 | |
| MSELoss | 均方差 | ![]() |
| BCEWithLogitsLoss | 多目标不需要经过sigmoid | ![]() |
| HingeEmbeddingLoss | Y中的元素只能为1或-1 | ![]() |
| MultiLabelMarginLoss | 适用于多目标分类 | ![]() |
| SmoothL1Loss | ![]() |
|
| SoftMarginLoss | ![]() |
|
| CosineEmbeddingLoss | ![]() |
|
| MarginRankingLoss | ![]() |
|
| TripletMarginLoss | ![]() |
从_WeightedLoss继续派生的函数有
| 名称 | 说明 | |
|---|---|---|
| NLLLoss | ![]() |
|
| BCELoss | ![]() |
|
| CrossEntropyLoss | ![]() |
|
| MultiLabelSoftMarginLoss | ![]() |
|
| MultiMarginLoss | ![]() |
















