一维线性回归的PyTorch实现

这段代码是我学习PyTorch的时候参照教材写的,记录一下学习过程。

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt

# 创建训练数据
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],
                    [9.779],[6.182],[7.59],[2.167],[7.042],
                    [10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],
                    [3.366],[2.596],[2.53],[1.221],[2.827],
                    [3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
# 画出来看看
# axs = plt.plot(x_train,y_train,'ob')
# plt.show()


x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

# 定义线性回归模型
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.liner = nn.Linear(1,1)  #输入和输出都是1

    def forward(self,x):
        out = self.liner(x)
        return out

# 模型初始化
if torch.cuda.is_available():
    model = LinearRegression().cuda()
else:
    model = LinearRegression()

# 定义损失函数和优化函数
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.001)

# 开始训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        target = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        target = Variable(y_train)

    # forward
    out = model(inputs)
    loss = criterion(out,target)
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1)%20 ==0:
        print('Epoch[{}/{}],loss:{:.6f}'
              .format(epoch+1,num_epochs,loss.item()))
print(model.liner.weight.item(),model.liner.bias.item())
# 预测结果
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()
# 画出回归结果
plt.plot(x_train.numpy(),y_train.numpy(),'ob')
plt.plot(x_train.numpy(),predict)
plt.show()
图1 预测结果
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 夜莺2517阅读 127,761评论 1 9
  • 版本:ios 1.2.1 亮点: 1.app角标可以实时更新天气温度或选择空气质量,建议处女座就不要选了,不然老想...
    我就是沉沉阅读 6,961评论 1 6
  • 我是黑夜里大雨纷飞的人啊 1 “又到一年六月,有人笑有人哭,有人欢乐有人忧愁,有人惊喜有人失落,有的觉得收获满满有...
    陌忘宇阅读 8,605评论 28 53
  • 兔子虽然是枚小硕 但学校的硕士四人寝不够 就被分到了博士楼里 两人一间 在学校的最西边 靠山 兔子的室友身体不好 ...
    待业的兔子阅读 2,647评论 2 9