[PyTorch]专项 输出-模型存储与加载

一、Python模块 & data

%matplotlib inline 
%config InlineBackend.figure_format = 'retina' 
import matplotlib.pyplot as plt 
import torch 
from torch import nn 
from torch import optim 
import torch.nn.functional as F 
from torchvision import datasets, transforms 

#自定义模块
import helper 
import fc_model 

#Define a transform to normalize the data 
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5, ), (0.5, ))]) 
#Download and load the training data 
trainset = datasets.FashioniNIST('~/.pytorch/F_MNIST_data', download=True, 
                                train=True, transform=transform) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) 

#Download and load the test data 
testset = datasets.FashioniNIST ('~/.pytorch/F_MNIST_data', download=True,
                                train=False, transform=transform) 
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True) 

二、建立模型 & 训练

#建立模型 自定义模块fc_model
model = fc_model.Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

#训练模型
fc_model.tranin(model, trainloader, testloader, criterion, optimizer, epochs=2)
  • print(model)

    模型:网络架构

  • print(model.state_dict().keys())

    模型参数:存储在字典的键中

三、存储/加载模型

1. 存储模型(参数)

字典checkpoint:保存记录维度的信息

  1. 网络结构
  • input
  • output
  • hidden layers
  • .state_dict() 参数(weights, bias)
checkpoint = {'input_size': 784,
            'output_size': 10,
            'hidden_layers': [each.out_features for each in model.hidden_layers],
            'state_dict': model.state_dict()}
             
torch.save(checkpoint, 'checkpoint.pth')}
model.hidden_layers

注意:属性in_features, out_features

2. 加载模型(参数)

加载模型的参数必须与存储好的模型一致,否则加载错误

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath) 
    model = fc_model.Network(checkpoint['input_size'],
                            checkpoint['output_size'],
                            checkpoint['hidden_layers']) #.out_features提取维度信息
    model.load_state_dict(checkpoint['state_dict'])
    return model
 
#加载模型
model = load_checkpoint('checkpoint.pth')
print (model) 
加载模型的参数
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容