1. 什么是损失函数(Loss Function)
损失函数在训练AI模型的过程中,用于计算模型预测值与真实值之间的误差(Error),又称误差函数(Error function)
2. PyTorch中内建的损失函数
在torch.nn中内建了很多常用的损失函数,依据用途,可以分为三类:
- 用于回归问题(Regression loss):回归损失主要关注连续值,例如: L1范数损失(L1Loss), 均方误差损失(MSELoss)等。
- 用于分类问题(Classification loss):分类损失函数处理离散值,例如,交叉熵损失(CrossEntropyLoss), 二进制交叉熵损失(BCELoss)等。
- 用于排名问题(Ranking loss): 排名损失用于预测输入样本之间的相对距离,例如,MarginRankingLoss,TripletMarginLoss等。
3. 使用PyTorch编写自定义损失函数
3.1 何时自己编写损失函数
当任务不再是单一的回归或者分类,PyTorch的内建损失函数无法满足任务需求时,需要开发者自己手动编写损失函数,例如:YOLOv5的total_loss为
total_loss = class loss + bbox loss+ object loss
就需要手动编写。
3.2 用PyTorch编写损失函数模板
方案1 -- 编写一个继承nn.Module的Loss函数类,并在forward()方法中实现loss计算:
class MyLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return torch.mean(torch.pow((x - y), 2))
方案2 -- 编写一个Loss函数类,并在_ call _方法中实现loss计算:
class MyLoss():
def __init__(self):
def forward(self, x, y):
return torch.mean(torch.pow((x - y), 2))
以上两种Loss函数类的使用方式都是:
loss_fn = MyLoss() # 实例化Loss函数
loss = loss_fn(preds, targets) # 计算loss值,即预测值与真实值之间的误差
...
loss.backward() #
在实际开发工作中,以上两种实现方案都非常常见,运行效率几乎一致,无性能优劣之分;由于loss函数并没有需要学习的参数,所以使用哪一种主要看开发者习惯的编程风格(coding style)。
参考资料: