参考链接:
https://zhuanlan.zhihu.com/p/75054200
https://github.com/qubvel/segmentation_models
一、目的:为什么需要用hook
debug model是非常麻烦的事情:会遇到梯度爆炸,张量尺寸不对应等问题。
最基本的解决办法是在forward()
里丢print和加入断点。
hooks可以解决以上debug的问题。hook与每一层相关联,并且每一层被使用的时候hook都被调用。hook能够冻结特定模块中前向传播或后向传播(forward or backward pass)的执行,并且处理前后向传播相关的输入和输出。
二、hook简介
hook是一个可调用的对象(callable object),它预定义了函数声明(即函数参数,返回值,调用方式等)(with a predefined signature, which can be registerd to any nn.Module
object.)
当触发方法(trigger method)作用于module(module)上(即:forward()
或backward()
)module本身和对应的输入和可能的输出都会传到hook上,并在运算进行到下一个module前执行该hook。
一些sidenotes:
- hook 有两类,对Variabel的hook和对nn. Module的hook。类型取决于hook的注册对象。
在Pytorch中,我们可以把module的hook注册为以下几种类型: - forward prehook (在前向传播之前执行)
- forward hook (在前向传播之后执行)
- backward hook (在后向传播之后执行)
三、具体例子
假设我们需要观察ResNet24每个卷积层的输出,这个任务可以用hook去解决。
先定义model
import torch
from torchvision.models import resnet34
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = resnet34(pretrained=True)
model = model.to(device
然后我们需要建立hooks去存放输出。对我们的任务来说,一个基本的可调用的对象(callable oject)就够了。
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs = []
SaveOutput
这个实例(instance)会记录前向传播的输出张量并存在列表里。(类/class是模板,实例/instance是根据类创建的对象)
Note: 加了双下划线即私有化方法。私有化方法后,我们只能在类的内部使用该方法,不能被外部调用。
一个forward hook可以用register_forward_hook(hook)
方法被注册。其他类型的hook,我们可以用register_backward_hook
和register_forward_pre_hook
。这些方法的返回值是hook指针(hook handle),可以被用于从module里移除这些hooks。
下面我们对每一个卷积层都注册一个hook。
save_output = SaveOutput()
hook_handles = []
for layer in model.modules():
if isinstance(layer, torch.nn.modules.conv.Conv2d):
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
如果现在我们运行len(save_output.outputs)
会返回0,因为hook还没被调用。
四、测试效果
我们可以用下图进行测试:
from PIL import Image
from torchvision import transforms as T
image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)
out = model(X)
钩子会在每一个卷积层执行每一个前向传播执行之后被调用,因此,每一层的输出就被保存下来了。
查看结果如下:
>>> len(save_output.outputs)
36
通过查看列表里的这些张量,我们可以可视化这个网络看到了什么。
下图是第一层的输出。即,对输入图像和number_of_out_channel个卷积核进行卷积运算,所得到的第一层每个channel的输出。
如果我们去看这个网络更深层的内容,网络所学到的特征会越来越high level。举个例子,下图中有一个过滤器(filter)看起来就是专门用来检测眼睛的。
可视化中间卷积层的输出的代码如下(只输出16个filter的卷积结果):
import matplotlib.pyplot as plt
def module_output_to_numpy(tensor):
return tensor.detach().to('cpu').numpy()
images = module_output_to_numpy(save_output.outputs[0])
#这里的0代表读取output里第一个卷积层的输出
with plt.style.context("seaborn-white"):
plt.figure(figsize=(20, 20), frameon=False)
for idx in range(16):
plt.subplot(4, 4, idx+1)
plt.imshow(images[0, idx])
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[]);
五、总结
hook除了能存中间层的输出之外,还有更多的功能。比如网络剪枝(network pruning)中往往要使用到hook。
其他:torch.unsqueeze(input, dim, out=None)
以下参考该链接
作用:扩展维度
返回一个新的张量,对输入的既定位置插入维度 1
参数:
tensor (Tensor) – 输入张量
dim (int) – 插入维度的索引
out (Tensor, optional) – 结果张量
举个例子:
x = torch.Tensor([1, 2, 3, 4])
print(x) # tensor([1., 2., 3., 4.])
print(x.size()) # torch.Size([4])
print(x.dim()) # 1
print(x.numpy()) # [1. 2. 3. 4.]
print(torch.unsqueeze(x, 0)) # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, 0).size()) # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim()) # 2
其他:torch.squeeze(input, dim=None, out=None)
作用:降维
torch.squeeze(input, dim=None, out=None)
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。
为何只去掉 1 呢?
多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。
六、变量的hook
参考链接
对于中间变量z
,hook 的使用方式为:z.register_hook(hook_fn)
,其中 hook_fn
为一个用户自定义的函数,其签名为:
hook_fn(grad) -> Tensor or None
它的输入为变量 z 的梯度,输出为一个 Tensor 或者是 None (None 一般用于直接打印梯度)。
反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn。如果 hook_fn的返回值是 None,那么梯度将不改变,继续向前传播,如果 hook_fn的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。
下面的示例代码中hook_fn
不改变梯度值,仅仅是打印梯度:
import torch
x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x+y
# ===================
def hook_fn(grad):
print(grad)
z.register_hook(hook_fn)
# ===================
o = w.matmul(z)
print('=====Start backprop=====')
o.backward()
print('=====End backprop=====')
print('x.grad:', x.grad)
print('y.grad:', y.grad)
print('w.grad:', w.grad)
print('z.grad:', z.grad)
运行结果如下:
=====Start backprop=====
tensor([1., 2., 3., 4.])
=====End backprop=====
x.grad: tensor([1., 2., 3., 4.])
y.grad: tensor([1., 2., 3., 4.])
w.grad: tensor([ 4., 6., 8., 10.])
z.grad: None
函数hook_fn(grad)
中的grad就是通过注册hook和运行代码进行反向传播时得到的gradient的值,该值会输入进hook_fn
函数。
通过上述实验得出结论,z
绑定了hook_fn
后,梯度反向传播时将会打印出o
对 z
的偏导