x_hat = np.linspace(-1, 1, 101).astype('float32')
x_hat = torch.from_numpy(x_hat.reshape(-1, 1))
x = np.zeros(101).astype('float32')
x = torch.from_numpy(x.reshape(-1, 1))
y1 = torch.nn.L1Loss(reduction='none')(x_hat, x)
y2 = torch.nn.MSELoss(reduction='none')(x_hat, x)
y3 = torch.nn.SmoothL1Loss(reduction='none', beta=0.5)(x_hat, x)
plt.plot(x_hat, y1)
plt.plot(x_hat, y2)
plt.plot(x_hat, y3)
基于pytorch是三种损失函数可视化理解
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
推荐阅读更多精彩内容
- 一种基于均值不等式的Listwise损失函数 1 前言 1.1 Learning to Rank 简介 Learn...
- 对于 Python语言来说,比较传统的数据可视化模块是Matplotlib,但它存在不够美观、静态性、不易分享等缺...