nnU中的初始化部分参数测试demo


import torch
from nnunet.network_architecture.initialization import InitWeights_He
from nnunet.network_architecture.u2net_at import u2net_at
from torch import nn


conv_op = nn.Conv3d
dropout_op = nn.Dropout3d
norm_op = nn.InstanceNorm3d


norm_op_kwargs = {'eps': 1e-5, 'affine': True}
dropout_op_kwargs = {'p': 0, 'inplace': True}
net_nonlin = nn.LeakyReLU
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
network = u2net_at(1, 32, 3,
                        4,
                        1, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                        dropout_op_kwargs,
                        net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
                        4, 4, False, True, True)

old_net = torch.load('/home/lab347/jyh/other_item/nnUNet/dataset/nnUNet_trained_models/'
                     'nnUNet/3d_lowres/Task040_KiTS/nnUNetTrainerV2__nnUNetPlansv2.1/fold_3_u2net/'
                     'model_best.model', map_location=torch.device('cpu'))
i = 0
for k,v in old_net['state_dict'].items():
    i += 1
    if i < 10 and i > 6:
        print(k)
        print(v)

i = 0
print('-----********************old_weights*************************-----')
for k,v in network.state_dict().items():
    i += 1
    if i < 10 and i > 6:
        print(k)
        print(v)

network.load_state_dict(old_net['state_dict'],strict=False)
i = 0
print('-----********************new_weights************************-----')
for k,v in network.state_dict().items():
    i += 1
    if i < 10 and i > 6:
        print(k)
        print(v)

print('----------')
经测试,相同部分的参数已初始化为old_net中的参数,注意字段strict的值
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容