Pytorch自定义torch.autograd.Function

转载来的:https://zhuanlan.zhihu.com/p/27783097

# -*- coding:utf8 -*-

import torch
from torch.autograd import Variable

class MyReLU(torch.autograd.Function):

    def forward(self, input_):
        # 在forward中,需要定义MyReLU这个运算的forward计算过程
        # 同时可以保存任何在后向传播中需要使用的变量值
        self.save_for_backward(input_)         # 将输入保存起来,在backward时使用
        output = input_.clamp(min=0)           # relu就是截断负数,让所有负数等于0
        return output

    def backward(self, grad_output):
        # 根据BP算法的推导(链式法则),dloss / dx = (dloss / doutput) * (doutput / dx)
        # dloss / doutput就是输入的参数grad_output、
        # 因此只需求relu的导数,在乘以grad_outpu
        input_, = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_ < 0] = 0               # 上诉计算的结果就是左式。即ReLU在反向传播中可以看做一个通道选择函数,所有未达到阈值(激活值<0)的单元的梯度都为0
        return grad_input

# Wrap一个ReLU函数
# 可以直接把刚才自定义的ReLU类封装成一个函数,方便直接调用
def relu(input_):
    # MyReLU()是创建一个MyReLU对象,
    # Function类利用了Python __call__操作,使得可以直接使用对象调用__call__制定的方法
    # __call__指定的方法是forward,因此下面这句MyReLU()(input_)相当于
    # return MyReLU().forward(input_)
    return MyReLU()(input_)

input_ = Variable(torch.linspace(-3, 3, steps=5))
print input_
print relu(input_)
# input_ = Variable(torch.randn(1))
# relu = MyReLU()
# output_ = relu(input_)
#
# # 这个relu对象,就是output_.creator,即这个relu对象将output与input连接起来,形成一个计算图
# print relu
# print output_.creator
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

  • 这一周,比想象中还要快一些。每天在充实的工作中,在做看起来不太难的事情,可能我不太机灵,总要慢一些。 从前,我有每...
    黄帅_360d阅读 331评论 2 1
  • 熙辰丶著 散文集.1 我本是一个不具魅力的天秤女子,身边的人都说我这不好那不好,与生俱来与美丽无缘。 身边的亲人说...
    浪子熙辰阅读 685评论 0 1
  • 2017.08.06日去了一趟珠海, 第一次独自领着小孩子出远门,去到一个陌生的地方。 原本是很好的计划,...
    NancyDY阅读 573评论 0 0
  • 六月,夏季。 这是我毕业的季节,高中的一切,都随着最后一声铃响,滞留在过去,那些人事,全都消散了。得知成绩时,我是...
    杜烟淮阅读 383评论 0 0

友情链接更多精彩内容