本文参考pytorch官方文档https://pytorch-cn.readthedocs.io/zh/latest/notes/extending/
扩展torch.autograd
如果想要添加一个新的Operation 到autograd的话,我们的operation需要继承class Function. autograd使用Function计算结果和梯度,同时编码operation的历史。每个新的operation(function)都需要实现三个方法:
__init__(operation) 如果你的operation包含非Variabel参数,那么就将其作为__Init__的参数传入到operation中。例如: AddConstant Function加一个常数,Transpose Function需要指定哪个维度需要交换。如果你的operation中不需要额外的参数,你可以忽略__init__。
forward() - 在里面写执行此operation的代码。可以有任意数量的参数。记住:forward()的参数只能是Variable,函数的返回值既可以是Variable,也可以是Variables的tuple。同时,请参考Function中的doc 有哪些方法只能在forward中调用的。
backward() - 梯度计算公式,参数的个数和forward返回值的个数一样,每个参数代表传回到此operation的梯度, backward()的返回值的个数应该和此operation输入的个数一样,每个返回值对应了输入值的梯度,如果operation的输入不需要梯度,或者不可导,你可以返回None.如果forward()存在可选参数,你可以返回比输入更多的梯度,只是返回的是None.