[PyTorch] register_hook

  • 每一个tensor都有register_hook方法,每次当关于这个参数的gradient被计算出来以后都会调用这个方法,因此可以用于debug等等,下面是对一部分梯度进行mask
    def _emb_hook(self, grad):
        return grad * Variable(self.grad_mask.unsqueeze(1)).type_as(grad)

    def set_grad_mask(self, mask):
        self.grad_mask = torch.from_numpy(mask)
        self.embedding.weight.register_hook(self._emb_hook)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容