pytorch nn.Parameter自动注册到module中

Parameter作为Module类的参数,可以自动的添加到Module类的参数列表中,并且可以使用Module.parameters()提供的迭代器获取到
使用parm.data访问这个参数的具体数据

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)
        self.fun_param = nn.Parameter(torch.FloatTensor([1, ]))
net = Net()
for param in net.named_parameters():
    print(param)
    print("*"*50)

其中的pararm中就可以含有自定义的fun_param参数,而不是只有nn.Linear

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容