1. 多 GPU 并行计算
class Net(nn.Module):
    """Skeleton network; add your own layers in ``__init__``.

    Args:
        input: Not used by this skeleton. Presumably the input size of the
            network -- TODO confirm when the layers are filled in.
        output: Not used by this skeleton. Presumably the output size of the
            network -- TODO confirm when the layers are filled in.
    """

    def __init__(self, input, output):
        # nn.Module has to be initialised first: it sets up the internal
        # registries that sub-modules and parameters are stored in.
        super().__init__()
        # Define your network here.
# Instantiate the model.
net = Net(input, output)
# Wrap the model for data-parallel execution.
net = nn.DataParallel(net)
# Choose the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Move the (wrapped) model onto the chosen device.
net.to(device)
2.模型的保存
# NOTE(review): this snippet is incomplete as written -- the `if` has no
# condition and the torch.save(...) call is cut off after `net.`, so it does
# not parse. The intended condition and target path cannot be recovered from
# here. Presumably the weights were meant to be stored under a 'state_dict'
# key, since the reload snippet reads checkpoint['state_dict'] -- TODO confirm
# and complete.
if
torch.save(net.)
3.模型的重载
from collections import OrderedDict

# Read the checkpoint stored at `resume` and pull out the saved weights.
checkpoint = torch.load(resume)
state_dict = checkpoint['state_dict']

# A model wrapped in nn.DataParallel lives under its `.module` attribute, so
# the keys of its state_dict start with 'module.'. Drop that prefix so the
# weights can be loaded into an unwrapped model. Keys without the prefix are
# kept unchanged; slicing them blindly would cut off the first 7 characters
# of a valid parameter name.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
4.模型的迁移
# Load a checkpoint onto a chosen device regardless of where it was saved.
# `map_location` selects the target device: the CPU here, or a GPU device
# instead when the model should end up on a GPU.
torch.load('model/path', map_location=torch.device('cpu'))