Pytorch冻结网络部分参数

Pytorch框架的网络结构中,所有的module都是torch.nn.Module的子类,Module中可以包含其他的Module以树状结构进行嵌套。

  • 查看神经网络的各个模块
model = ResNetAttention(depth=101, pretrained=1, num_classes=16,
                      dropout=0, grayscale=8)
方法一:model._modules.items()
返回网络中所有模块(该模块包含子模块)的iterators,例如:
for name, module in model._modules.items():
    print(name)
    print(module)
--------------------------------------------------------------
# name
base 
# module
ResNet( 
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
... ...)
classifier # name
Conv2d(2048, 17, kernel_size=(1, 1), stride=(1, 1)) # module
classifiers_activation
Sequential(
  (0): Conv2d(2048, 17, kernel_size=(1, 1), stride=(1, 1))
... ...
  (15): Conv2d(2048, 17, kernel_size=(1, 1), stride=(1, 1))
)
activation # name
AggregatedActivation() #module
attentionmap
AttentionMap()
otsu_medthod
OtsuMethod()
obj_classifier
Linear(in_features=2048, out_features=16, bias=True)

方法二:model.children()
返回model中所有直接子模块的一个iterator,例如:
for module in model.children():
    print(module)
----------------------------------------------------------------
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
... ...
Conv2d(2048, 17, kernel_size=(1, 1), stride=(1, 1))
Sequential(
  (0): Conv2d(2048, 17, kernel_size=(1, 1), stride=(1, 1))
... ...)
AggregatedActivation()
AttentionMap()
OtsuMethod()
Linear(in_features=2048, out_features=16, bias=True)

# NOTE 不推荐用model.modules(), 他会递归的返回所有子模块
  • 查看网络参数
    一般冻结部分网络参数,是为了做transfer learning,所以我们会有一个已经训练好的model,可以先加载model并查看参数
# load pretrained model on cpu
# checkpoint = load_checkpoint('./logs11/checkpoint.pth.tar') ---- load on gpu
checkpoint = torch.load( './logs11/checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
start_epoch = checkpoint['epoch']
best_recall1 = checkpoint['best_recall1']
print("=> start epoch {} best top1 recall {:.1%}"
      .format(start_epoch, best_recall1))
# check parameters
for name, module in model._modules.items():
    for p in module.parameters():
        # print(p)
        print(p.size())
---------------------------------------------------------------
=> start epoch 100 best top1 recall 98.2%
torch.Size([64, 3, 7, 7])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 1, 1])
... ...
torch.Size([17, 2048, 1, 1])
torch.Size([17])
torch.Size([16, 2048])
torch.Size([16])

  • 冻结网路部分参数
    例如,transfer learning中,我想冻结base_model的参数,只更新其他层的参数
param_optim = []
# layers = []

for name, module in model._modules.items():
    if name != "base":
        # layers.append(name)
        for p in module.parameters():
            param_optim.append(p)
    else:
        for p in module.parameters():
            p.requires_grad = False
# print(param_optim)
# print(layers)
-------------------------------------------------------------
可以把需要更新layers打印出来看
['classifier', 'classifiers_activation', 'activation', 'attentionmap', 'otsu_medthod', 'obj_classifier']
其中有一些layers是没有参数的,不过不影响,param_optim中包含了所有需要更新的参数;其他参数
可以requires_grad = False,就不用计算它们的梯度了

  • 更改optimizer的参数
# optimizer
param_groups = [{'params': param_optim, 'lr_mult': 0.1}]
if args.optimizer == 'sgd':
    optimizer = torch.optim.SGD(param_groups, lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
elif args.optimizer == 'adam':
    optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
else:
    raise ValueError("Cannot recognize optimizer type:", args.optimizer)

以上冻结参数的方法在train.py里面进行修改即可。
还有一个小trick 可以在网络里面修改

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
 
        for p in self.parameters():
            p.requires_grad=False
 
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
可以在中间插入requires_grad=False,插入行前面参数就是False
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,335评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,895评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,766评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,918评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,042评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,169评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,219评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,976评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,393评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,711评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,876评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,562评论 4 336
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,193评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,903评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,142评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,699评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,764评论 2 351