Pytorch里hook的介绍

参考链接:

https://zhuanlan.zhihu.com/p/75054200
https://github.com/qubvel/segmentation_models

主要参考链接
文中所提到的完整代码和教程见该链接

一、目的:为什么需要用hook

debug model是非常麻烦的事情:会遇到梯度爆炸,张量尺寸不对应等问题。
最基本的解决办法是在forward()里丢print和加入断点。
hooks可以解决以上debug的问题。hook与每一层相关联,并且每一层被使用的时候hook都被调用。hook能够冻结特定模块中前向传播或后向传播(forward or backward pass)的执行,并且处理前后向传播相关的输入和输出。

二、hook简介

hook是一个可调用的对象(callable object),它预定义了函数声明(即函数参数,返回值,调用方式等)(with a predefined signature, which can be registerd to any nn.Module object.)
当触发方法(trigger method)作用于module(module)上(即:forward()backward())module本身和对应的输入和可能的输出都会传到hook上,并在运算进行到下一个module前执行该hook。
一些sidenotes:

  • hook 有两类,对Variabel的hook和对nn. Module的hook。类型取决于hook的注册对象。
    在Pytorch中,我们可以把module的hook注册为以下几种类型:
  • forward prehook (在前向传播之前执行)
  • forward hook (在前向传播之后执行)
  • backward hook (在后向传播之后执行)

三、具体例子

假设我们需要观察ResNet24每个卷积层的输出,这个任务可以用hook去解决。
先定义model

import torch
from torchvision.models import resnet34

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = resnet34(pretrained=True)
model = model.to(device

然后我们需要建立hooks去存放输出。对我们的任务来说,一个基本的可调用的对象(callable oject)就够了。

class SaveOutput:
    def __init__(self):
        self.outputs = []
        
    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)
        
    def clear(self):
        self.outputs = []

SaveOutput这个实例(instance)会记录前向传播的输出张量并存在列表里。(类/class是模板,实例/instance是根据类创建的对象)
Note: 加了双下划线即私有化方法。私有化方法后,我们只能在类的内部使用该方法,不能被外部调用。
一个forward hook可以用register_forward_hook(hook)方法被注册。其他类型的hook,我们可以用register_backward_hookregister_forward_pre_hook。这些方法的返回值是hook指针(hook handle),可以被用于从module里移除这些hooks。
下面我们对每一个卷积层都注册一个hook。

save_output = SaveOutput()

hook_handles = []

for layer in model.modules():
    if isinstance(layer, torch.nn.modules.conv.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)

如果现在我们运行len(save_output.outputs)会返回0,因为hook还没被调用。

四、测试效果

我们可以用下图进行测试:


图片
from PIL import Image
from torchvision import transforms as T

image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)
out = model(X)

钩子会在每一个卷积层执行每一个前向传播执行之后被调用,因此,每一层的输出就被保存下来了。
查看结果如下:

>>> len(save_output.outputs)
36

通过查看列表里的这些张量,我们可以可视化这个网络看到了什么。
下图是第一层的输出。即,对输入图像和number_of_out_channel个卷积核进行卷积运算,所得到的第一层每个channel的输出。

ResNet34的第一层的输出

如果我们去看这个网络更深层的内容,网络所学到的特征会越来越high level。举个例子,下图中有一个过滤器(filter)看起来就是专门用来检测眼睛的。


ResNet34的第16层的输出

可视化中间卷积层的输出的代码如下(只输出16个filter的卷积结果):

import matplotlib.pyplot as plt

def module_output_to_numpy(tensor):
    return tensor.detach().to('cpu').numpy()    

images = module_output_to_numpy(save_output.outputs[0])
#这里的0代表读取output里第一个卷积层的输出

with plt.style.context("seaborn-white"):
    plt.figure(figsize=(20, 20), frameon=False)
    for idx in range(16):
        plt.subplot(4, 4, idx+1)
        plt.imshow(images[0, idx])
    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[]);

五、总结

hook除了能存中间层的输出之外,还有更多的功能。比如网络剪枝(network pruning)中往往要使用到hook

其他:torch.unsqueeze(input, dim, out=None)

以下参考该链接
作用:扩展维度
返回一个新的张量,对输入的既定位置插入维度 1
参数:
tensor (Tensor) – 输入张量
dim (int) – 插入维度的索引
out (Tensor, optional) – 结果张量
举个例子:

x = torch.Tensor([1, 2, 3, 4])
print(x)  # tensor([1., 2., 3., 4.])
print(x.size())  # torch.Size([4])
print(x.dim())  # 1
print(x.numpy())  # [1. 2. 3. 4.]
print(torch.unsqueeze(x, 0))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, 0).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim())  # 2

其他:torch.squeeze(input, dim=None, out=None)

作用:降维
torch.squeeze(input, dim=None, out=None)
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。
为何只去掉 1 呢?
多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。

六、变量的hook

参考链接
对于中间变量z,hook 的使用方式为:z.register_hook(hook_fn),其中 hook_fn为一个用户自定义的函数,其签名为:
hook_fn(grad) -> Tensor or None
它的输入为变量 z 的梯度,输出为一个 Tensor 或者是 None (None 一般用于直接打印梯度)。
反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn。如果 hook_fn的返回值是 None,那么梯度将不改变,继续向前传播,如果 hook_fn的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。

下面的示例代码中hook_fn 不改变梯度值,仅仅是打印梯度:

import torch

x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x+y
# ===================
def hook_fn(grad):
    print(grad)
z.register_hook(hook_fn)
# ===================
o = w.matmul(z)
print('=====Start backprop=====')
o.backward()
print('=====End backprop=====')
print('x.grad:', x.grad)
print('y.grad:', y.grad)
print('w.grad:', w.grad)
print('z.grad:', z.grad)

运行结果如下:

=====Start backprop=====
tensor([1., 2., 3., 4.])
=====End backprop=====
x.grad: tensor([1., 2., 3., 4.])
y.grad: tensor([1., 2., 3., 4.])
w.grad: tensor([ 4.,  6.,  8., 10.])
z.grad: None

函数hook_fn(grad)中的grad就是通过注册hook和运行代码进行反向传播时得到的gradient的值,该值会输入进hook_fn函数。
通过上述实验得出结论,z绑定了hook_fn后,梯度反向传播时将会打印出oz的偏导

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,717评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,501评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,311评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,417评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,500评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,538评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,557评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,310评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,759评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,065评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,233评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,909评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,548评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,172评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,420评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,103评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,098评论 2 352