pytorch多GPU并行计算及模型的保存与载入

1.多gpu并行计算

class Net(nn.Module):
    def __init__(input, output):
        pass
        #define your network 
net = Net(input, output) #实例化模型
net = nn.DataParallel(net) #数据并行
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #初始化计算设备
net .to(device)

2.模型的保存

if
torch.save(net.)

3.模型的重载

checkpoint = torch.load(resume)
state_dict =checkpoint['state_dict']

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of dataparallel
    new_state_dict[name]=v

model.load_state_dict(new_state_dict)

4.模型的迁移

# cpu or gpu
torch.load('model/path', map_location='cpu')

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。