- 每一个
tensor
都有register_hook
方法,每次当关于这个参数的gradient
被计算出来以后都会调用这个方法,因此可以用于debug
等等,下面是对一部分梯度进行mask
。
def _emb_hook(self, grad):
return grad * Variable(self.grad_mask.unsqueeze(1)).type_as(grad)
def set_grad_mask(self, mask):
self.grad_mask = torch.from_numpy(mask)
self.embedding.weight.register_hook(self._emb_hook)