pytorch基础五(定义自动求导函数)

本人学习pytorch主要参考官方文档莫烦Python中的pytorch视频教程。
后文主要是对pytorch官网的文档的总结。
代码来自pytorch官网

import torch
# 通过继承torch.autograd.Function类,并实现forward 和 backward函数
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        """
        在forward函数中,接收包含输入的Tensor并返回包含输出的Tensor。
        ctx是环境变量,用于提供反向传播是需要的信息。可通过ctx.save_for_backward方法缓存数据。
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        在backward函数中,接收包含了损失梯度的Tensor,
        我们需要根据输入计算损失的梯度。
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

dtype = torch.float
device = torch.device("cpu")
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
for t in range(500):
    relu = MyReLU.apply
    y_pred = relu(x.mm(w1)).mm(w2)
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
    loss.backward()
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w1.grad.zero_()
        w2.grad.zero_()
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • Android 自定义View的各种姿势1 Activity的显示之ViewRootImpl详解 Activity...
    passiontim阅读 174,650评论 25 709
  • 用两张图告诉你,为什么你的 App 会卡顿? - Android - 掘金 Cover 有什么料? 从这篇文章中你...
    hw1212阅读 13,131评论 2 59
  • 1、通过CocoaPods安装项目名称项目信息 AFNetworking网络请求组件 FMDB本地数据库组件 SD...
    阳明AI阅读 16,032评论 3 119
  • zqann阅读 252评论 0 0
  • 今日体验:今天核对了下各店喷漆,把结算单,电脑系统,还有实时到账核对了下,有问题的及时联系前台,发现问题解决问题。
    A郑淑英阅读 233评论 0 0