pytorch 的 Module.modules() vs Module.children()

这俩货都会返回模型的子模块,但其行为大不相同,经常会被搞混。本文将区分二者的功能。连带着区分其各自对应的 named_modulesnamed_children

Module.modules()

如果你想迭代地返回模型所有的模块,那么应该使用 .modules(),例如:

from torch import nn

net = nn.Sequential(nn.Linear(2,2),
                    nn.ReLU(),
                    nn.Sequential(nn.Sigmoid(), nn.ReLU()))

print(list(net.modules()))

其输出结果为(为了方便分析,把结果拉开并且加上了注释):

[Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU()
  (2): Sequential(
    (0): Sigmoid()
    (1): ReLU()
  )
), # ----------------------------------------------- 整体网络结构

Linear(in_features=2, out_features=2, bias=True),  # 第一个子模块

ReLU(),  # ----------------------------------------- 第二个子模块               


Sequential(
  (0): Sigmoid()
  (1): ReLU()
),  # ----------------------------------------------- 第三个子模块

# 到这里子模块也遍历完了,开始遍历子子(孙)模块,也就是第三个子模块内部的模块

Sigmoid(),  # ---------------------------------- 第一个孙模块
ReLU()  # -------------------------------------- 第二个孙模块
]

在初始化网络权重时,有时需要根据不同层来采用不同初始化策略,此时就要用到 named_modules 。因为只有它才可以递归地遍历到网络的最小构成模块。部分pytorch 的代码采用 named_children 迭代,其实是错误的,但往往错误的 weight initialization 方式,仍然可以达到很好的效果,由此可见初始化实在是一门玄学。

Module.children()

该方法只输出模块的直接子模块,不会递归地输出最小模块单元。

from torch import nn

net = nn.Sequential(nn.Linear(2,2),
                    nn.ReLU(),
                    nn.Sequential(nn.Sigmoid(), nn.ReLU()))

print(list(net.children()))

运行结果如下:

[Linear(in_features=2, out_features=2, bias=True),  # 第一个子模块

ReLU(), # --------------------------------------------第二个子模块

Sequential(
  (0): Sigmoid()
  (1): ReLU()
) # --------------------------------------------------第三个子模块
]

可以看到只有直接子模块被打印出来,子模块的子模块并没有被遍历。这个方法主要用于做模型的拼接、初始化,等等。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容