PyTorch torch.optim 传入两个网络参数

CLASS torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

以Adam优化器为例,其params定义如下:

  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

所以我们传入的是一个迭代器,可以通过tertools.chain将两个网络参数连接起来。

import itertools
......
optimizer = torch.optim.Adam(itretools.chain(net1.parameters(), net2.parameters()), 0.001, weight_decay = 1e-5)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容