希望以后可以一直注意到这个问题,就是PyTorch的图在进行backward的时候是不保存中间变量的grad的,因此在之后用.grad
去查看梯度来检查梯度传播是无效的。这个问题,也可参见why-cant-i-see-grad-of-an-intermediate-variable中提出的详细例子。
那如何能够查看中间变量的梯度呢?Adam Paszke在这个问题底下又给出了一个简短而有用的例子,具体如下:
grads = {}
def save_grad(name):
def hook(grad):
grads[name] = grad
return hook
x = Variable(torch.randn(1,1), requires_grad=True)
y = 3*x
z = y**2
# In here, save_grad('y') returns a hook (a function) that keeps 'y' as name
y.register_hook(save_grad('y'))
z.register_hook(save_grad('z'))
z.backward()
print(grads['y'])
print(grads['z'])
主要是通过hook机制,使得PyTorch图在进行backward的时候触发保存下中间变量的grad。