Pytorch模型中的parameter与buffer

模型保存

在 Pytorch 中一种模型保存和加载的方式如下:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

可以看到模型保存的是 model.state_dict() 的返回对象。 model.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数,例如:

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.lin = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.lin(x)

module = MyModule(4, 2)
print(module.state_dict())

输出结果:

模型中的参数就是线性层的 weight 和 bias.

Parameter 和 buffer

模型中需要保存下来的参数包括两种:

  • 一种是反向传播需要被optimizer更新的,称之为 parameter
  • 一种是反向传播不需要被optimizer更新,称之为 buffer

第一种参数我们可以通过 model.parameters() 返回;第二种参数我们可以通过 model.buffers() 返回。因为我们的模型保存的是 state_dict 返回的 OrderDict,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict。

那么现在的问题是这两种参数如何创建呢,创建好了如何保存到OrderDict呢?

  • 第一种参数有两种方式:
    1.我们可以直接将模型的成员变量(self.xxx) 通过nn.Parameter() 创建,会自动注册到parameters中,可以通过model.parameters()返回,并且这样创建的参数会自动保存到OrderDict中去;
    2.通过nn.Parameter()创建普通Parameter对象,不作为模型的成员变量,然后将Parameter对象通过register_parameter()进行注册,可以通model.parameters()返回,注册后的参数也会自动保存到OrderDict中去;
  • 第二种参数我们需要创建tensor, 然后将tensor通过register_buffer()进行注册,可以通model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。

分析一个例子:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        buffer = torch.randn(2, 3)  # tensor
        self.register_buffer('my_buffer', buffer)
        self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量

    def forward(self, x):
        # 可以通过 self.param 和 self.my_buffer 访问
        pass
model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
for buffer in model.buffers():
    print(buffer)
print("----------------")
print(model.state_dict())

输出结果:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        buffer = torch.randn(2, 3)  # tensor
        param = nn.Parameter(torch.randn(3, 3))  # 普通 Parameter 对象
        self.register_buffer('my_buffer', buffer)
        self.register_parameter("param", param)

    def forward(self, x):
        # 可以通过 self.param 和 self.my_buffer 访问
        pass
model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
for buffer in model.buffers():
    print(buffer)
print("----------------")
print(model.state_dict())

输出:

register_buffer的函数原型:

register_buffer(name, tensor) name: string tensor: Tensor

register_parameter的函数原型:

register_parameter(name, param) name: string param: Parameter

为什么不直接将不需要进行参数修改的变量作为模型类的成员变量就好了,还要进行注册?
有两个原因:
1.不进行注册,参数不能保存到 OrderDict,也就无法进行保存
2.模型进行参数在CPU和GPU移动时, 执行model.to(device),注册后的参数也会自动进行设备移动

看一个例子:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
        self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
        self.my_param = nn.Parameter(torch.randn(1))

    def forward(self, x):
            return x

model = MyModel()
print(model.state_dict())
model.cuda()
print(model.my_tensor)
print(model.my_buffer)

输出结果:

可以看到模型类的成员变量不在OrderDict中,不能进行保存;模型在进行设备移动时,模型类的成员变量没有进行移动。

总结

1.模型中需要进行更新的参数注册为Parameter,不需要进行更新的参数注册为buffer
2.模型保存的参数是 model.state_dict() 返回的 OrderDict
3.模型进行设备移动时,模型中注册的参数(Parameter和buffer)会同时进行移动

转载于https://zhuanlan.zhihu.com/p/89442276

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 226,488评论 6 524
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 97,466评论 3 411
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 174,084评论 0 371
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 62,024评论 1 305
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 70,882评论 6 405
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 54,395评论 1 318
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 42,539评论 3 433
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 41,670评论 0 282
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 48,194评论 1 329
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 40,173评论 3 352
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 42,302评论 1 362
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 37,872评论 5 354
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 43,581评论 3 342
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 33,984评论 0 25
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 35,179评论 1 278
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 50,888评论 3 385
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 47,306评论 2 369

推荐阅读更多精彩内容