import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x,y = Variable(x,requires_grad=False) ,Variable(y,requires_grad=False)
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1),
)
optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1,figsize=(10,3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
# 2 ways to save the net
torch.save(net1,'net.pkl') # save entire net
torch.save(net1.state_dict(),'net_params.pkl') # save only the parameters
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x)
#plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1),
)
# copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
plt.show()
# save net1
save()
# restore entire net
restore_net()
# restore only the net parameters
restore_params()
torch 保存提取
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 介绍如何保存神经网络,这样以后想要用的时候直接提取就可以保存模型的步骤1.训练模型2.保存模型3.导入模型并应用 ...
- 欢迎大家关注我的专题:爬虫修炼之道 上篇 爬虫修炼之道——编写一个爬取多页面的网络爬虫主要讲解了如何使用pytho...