基于pytorch是三种损失函数可视化理解

x_hat = np.linspace(-1, 1, 101).astype('float32')
x_hat = torch.from_numpy(x_hat.reshape(-1, 1))
x = np.zeros(101).astype('float32')
x = torch.from_numpy(x.reshape(-1, 1))

y1 = torch.nn.L1Loss(reduction='none')(x_hat, x)
y2 = torch.nn.MSELoss(reduction='none')(x_hat, x)
y3 = torch.nn.SmoothL1Loss(reduction='none', beta=0.5)(x_hat, x)

plt.plot(x_hat, y1)
plt.plot(x_hat, y2)
plt.plot(x_hat, y3)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容