一、Python模块 & data
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
#自定义模块
import helper
import fc_model
# Define a transform that converts images to tensors and normalizes the
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Download and load the training split.
# NOTE: the torchvision dataset class is spelled `FashionMNIST`.
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data', download=True,
                                 train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test split.
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data', download=True,
                                train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
二、建立模型 & 训练
# Build the model with the custom fc_model module: 784 inputs, 10 outputs,
# and hidden layers of 512, 256 and 128 units.
model = fc_model.Network(784, 10, [512, 256, 128])

# NLLLoss expects log-probabilities as input; presumably Network ends in
# a log-softmax -- confirm in fc_model.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model for two epochs, validating against the test loader.
# NOTE(review): the original call was spelled `tranin`; `train` is assumed to
# be the function fc_model actually exports -- confirm against fc_model.py.
fc_model.train(model, trainloader, testloader, criterion, optimizer, epochs=2)
-
# Show the network architecture (the module's printed representation).
print(model)
-
# List the names of the entries stored in the model's state dict.
print(model.state_dict().keys())
三、存储/加载模型
1. 存储模型(参数)
字典checkpoint:保存记录维度的信息
- 网络结构
- input
- output
- hidden layers
- .state_dict():模型参数(weights, bias)
# Bundle the architecture sizes together with the learned parameters so the
# network can be rebuilt later from this one file.
checkpoint = {
    'input_size': 784,
    'output_size': 10,
    # Record each hidden layer's width via its out_features attribute.
    'hidden_layers': [each.out_features for each in model.hidden_layers],
    'state_dict': model.state_dict(),
}
# The original line ended with a stray '}' that made it a syntax error.
torch.save(checkpoint, 'checkpoint.pth')
注意:每个隐藏层的 out_features 属性给出该层的输出维度(in_features 为输入维度),checkpoint 中用 out_features 来记录各隐藏层的大小。
2. 加载模型(参数)
加载时构建的模型结构(输入、输出、隐藏层维度)必须与存储时的模型一致,否则 load_state_dict 会因参数不匹配而报错。
def load_checkpoint(filepath):
    """Rebuild a network from a checkpoint file and restore its parameters.

    Args:
        filepath: Path to a checkpoint written with ``torch.save`` that
            contains the keys 'input_size', 'output_size', 'hidden_layers'
            and 'state_dict'.

    Returns:
        An ``fc_model.Network`` built from the stored sizes, with the saved
        state dict loaded into it.
    """
    # Map tensors to the CPU so a checkpoint saved on a GPU also loads on a
    # CPU-only machine; load_state_dict copies the values into the freshly
    # built model's own parameters either way.
    checkpoint = torch.load(filepath, map_location='cpu')

    # 'hidden_layers' holds the out_features of each hidden layer, i.e. the
    # layer widths needed to recreate the same architecture.
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    return model
# Restore the model from the saved checkpoint and display its architecture.
model = load_checkpoint('checkpoint.pth')
print(model)