label smooth

标签平滑:
Label Smoothing(标签平滑)是一个经典的正则化方法,机器学习的样本中通常会存在少量错误标签,这些错误标签会影响到预测的效果。标签平滑采用如下思路解决这个问题:在训练时即假设标签可能存在错误,避免“过分”相信训练样本的标签。当目标函数为交叉熵时,这一思想有非常简单的实现:

# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CELoss(nn.Module):
    ''' Cross Entropy Loss with label smoothing '''
    def __init__(self, label_smooth=None, class_num=2):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num
 
    def forward(self, pred, target):
        ''' 
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler [N]
        '''
        eps = 1e-12
        
        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)   # softmax + log
            target = F.one_hot(target, self.class_num)  # 转换成one-hot
            
            # label smoothing
            # 实现 1
            # target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num    
            # 实现 2
            # implement 2
            target = torch.clamp(target.float(), min=self.label_smooth/(self.class_num-1), max=1.0-self.label_smooth)
            loss = -1*torch.sum(target*logprobs, 1)
        else:
            # standard cross entropy loss
            loss = -1.*pred.gather(1, target.unsqueeze(-1)) + torch.log(torch.exp(pred+eps).sum(dim=1))
        return loss.mean()

在训练过程中调用:

from CELoss import CELoss
loss2 = CELoss(label_smooth=0.05, class_num=2)  # 标签平滑
with torch.no_grad():
    for texts, labels in data_iter:
        outputs = model(texts) # [batch_size, num_class=2]
        loss = F.cross_entropy(outputs, labels)
        # loss=loss2(outputs, labels)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 花书上关于网络优化的笔记记录于https://www.jianshu.com/p/06bb6d6a5227 花书上...
    单调不减阅读 1,497评论 0 0
  • 概念: 机器学习的样本中通常会存在少量错误标签,这些错误标签会影响到预测的效果。标签平滑采用如下思路解决这个问题:...
    三方斜阳阅读 2,400评论 1 3
  • 本文系转载! 原文链接:https://juejin.im/post/6844903520089407502 机器...
    DeepNLPLearner阅读 1,045评论 0 0
  • 1 为什么要对特征做归一化 特征归一化是将所有特征都统一到一个大致相同的数值区间内,通常为[0,1]。常用的特征归...
    顾子豪阅读 6,632评论 2 22
  • 1 为什么要对特征做归一化 特征归一化是将所有特征都统一到一个大致相同的数值区间内,通常为[0,1]。常用的特征归...
    顾子豪阅读 1,426评论 0 1