PyTorch中定义模型时,有时候会遇到self.register_buffer('name', Tensor)的操作,该方法的作用是定义一组参数,该组参数的特别之处在于:模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
为了更好地理解这句话,按照惯例,我们通过一个例子实验来解释:
首先,定义一个模型并实例化:
import torch
import torch.nn as nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# (1)常见定义模型时的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(1, 1, 3, bias=False)),
('fc', nn.Linear(1, 2, bias=False))
]))
# (2)使用register_buffer()定义一组参数
self.register_buffer('param_buf', torch.randn(1, 2))
# (3)使用形式类似的register_parameter()定义一组参数
self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))
# (4)按照类的属性形式定义一组变量
self.param_attr = torch.randn(1, 2)
def forward(self, x):
return x
net = Model()
上例中,我们通过继承nn.Module类定义了一个模型,在模型参数的定义中,我们分别以(1)常见的nn.Module类形式、(2)self.register_buffer()形式、(3)self.register_parameter()形式,以及(4)python类的属性形式定义了4组参数。
(1)哪些参数可以在模型训练时被更新?
这可以通过net.parameters()查看,因为定义优化器时是这样的:optimizer = SGD(net.parameters(), lr=0.1)。为了方便查看,我们使用 net.named_parameters():
In [8]: list(net.named_parameters())
Out[8]:
[('param_reg',
Parameter containing:
tensor([[-0.0617, -0.8984]], requires_grad=True)),
('param_nn.conv.weight',
Parameter containing:
tensor([[[[-0.3183, -0.0426, -0.2984],
[-0.1451, 0.2686, 0.0556],
[-0.3155, 0.0451, 0.0702]]]], requires_grad=True)),
('param_nn.fc.weight',
Parameter containing:
tensor([[-0.4647],
[ 0.7753]], requires_grad=True))]
可以看到,我们定义的4组参数中,只有(1)和(3)定义的参数可以被更新,而self.register_buffer()和以python类的属性形式定义的参数都不能被更新。也就是说,modules和parameters可以被更新,而buffers和普通类属性不行。
那既然这两种形式定义的参数都不能被更新,二者可以互相替代吗?答案是不可以,原因看下一节:
(2)这其中哪些才算是模型的参数呢?
模型的所有参数都装在 state_dict 中,因为保存模型参数时直接保存 net.state_dict()。我们看一下其中究竟是哪些参数:
In [9]: net.state_dict()
Out[9]:
OrderedDict([('param_reg', tensor([[-0.0617, -0.8984]])),
('param_buf', tensor([[-1.0517, 0.7663]])),
('param_nn.conv.weight',
tensor([[[[-0.3183, -0.0426, -0.2984],
[-0.1451, 0.2686, 0.0556],
[-0.3155, 0.0451, 0.0702]]]])),
('param_nn.fc.weight',
tensor([[-0.4647],
[ 0.7753]]))])
可以看到,通过 nn.Module 类、self.register_buffer() 以及 self.register_parameter() 定义的参数都在 state-dict 中,只有用python类的属性形式定义的参数不包含其中。也就是说,保存模型时,buffers,modules和parameters都可以被保存,但普通属性不行。
(3)self.register_buffer() 的使用方法
在用self.register_buffer('name', tensor) 定义模型参数时,其有两个形参需要传入。第一个是字符串,表示这组参数的名字;第二个就是tensor 形式的参数。
在模型定义中调用这个参数时(比如改变这组参数的值),可以使用self.name 获取。本文例中,就可用self.param_buf 引用。这和类属性的引用方法是一样的。
在实例化模型后,获取这组参数的值时,可以用 net.buffers() 方法获取,该方法返回一个生成器(可迭代变量):
In [10]: net.buffers()
Out[10]: <generator object Module.buffers at 0x00000289CA0032E0>
In [11]: list(net.buffers())
Out[11]: [tensor([[-1.0517, 0.7663]])]
# 也可以用named_buffers() 方法同时获取名字
In [12]: list(net.named_buffers())
Out[12]: [('param_buf', tensor([[-1.0517, 0.7663]]))]
(4)modules, parameters 和 buffers
实际上,PyTorch 定义的模型用OrderedDict() 的方式记录这三种类型,分别保存在self._modules, self._parameters 和 self._buffers 三个私有属性中。调试模式时就可以看到每个模型都有这几个私有属性:
由于是私有属性,我们无法在实例化的变量上调用这些属性,可以在模型定义中调用它们:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 常见定义模型时的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(1, 1, 3, bias=False)),
('fc', nn.Linear(1, 2, bias=False))
]))
# 使用register_buffer()定义一组参数
self.register_buffer('param_buf', torch.randn(1, 2))
# 使用形式类似的register_parameter()定义一组参数
self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))
# 按照类的属性形式定义一组变量
self.param_attr = torch.randn(1, 2)
print('self._modules: ', self._modules)
print('self._parameters: ', self._modules)
print('self._buffers: ', self._modules)
def forward(self, x):
return x
模型实例化时,调用了 init() 方法,我们就可以看到调用输出结果:
In [21]: net = Model()
self._modules: OrderedDict([('param_nn', Sequential(
(conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), bias=False)
(fc): Linear(in_features=1, out_features=2, bias=False)
))])
self._parameters: OrderedDict([('param_reg', Parameter containing:
tensor([[-0.5666, -0.2624]], requires_grad=True))])
self._buffers: OrderedDict([('param_buf', tensor([[-0.4005, -0.8199]]))])
在模型的实例化变量上调用时,三者有着相似的方法:
net.modules()
net.named_modules()
net.parameters()
net.named_parameters()
net.buffers()
net.named_buffers()
细心的读着可能会发现,self._parameters 和 net.parameters() 的返回值并不相同。这里self._parameters 只记录了使用 self.register_parameter() 定义的参数,而net.parameters() 返回所有可学习参数,包括self._modules 中的参数和self._parameters 参数的并集。
实际上,由nn.Module类定义的参数和self.register_parameter() 定义的参数性质是一样的,都是nn.Parameter 类型。