PyTorch教程-3:PyTorch中神经网络的构建与训练基础

笔者PyTorch的全部简单教程请访问:https://www.jianshu.com/nb/48831659

PyTorch教程-3:PyTorch中神经网络的构建与训练基础

基本原理

在PyTorch中定义一个网络模型,需要让自定义的网络类继承自 torch.nn.Module,并且比较重要的是需要重写其 forward 方法,也就是对网络结构的前向传播做出定义,即在forward方法中,需要定义一个输入变量 input 是如何经过哪些运算得到输出结果的。这样,当一个网络作用于输入变量后,就能得到输出的值(output = MyNet(input))。然后通过计算损失(loss),也就是网络的预测值与真实值之间的差距,再将这个损失反向传播,loss.backward() 就可以计算得到loss对网络中所有参数的反向传播后的梯度值,这里的backward就是依赖于forward定义的运算规则而自动计算的。最后在利用梯度值来更新网络的参数从而完成一步训练。

大体来说,训练一个网路通常需要经理如下的步骤:

  • 定义网络结构以及其中要学习的参数
  • 从数据库获取输入值
  • 将输入值输入网络得到输出值
  • 计算输出值与标签之间的loss
  • 将loss做反向传播求得loss之于所有参数的导数
  • 更新参数,比如SGD的更新方式:weight_new = weights_old − learning_rate × gradient

本文使用一个最简单的LeNet为例,该网络的输入是一个 32×32 大小的单通道灰度图,输出为 10 个分类的值(1×10的向量),具有两个卷积层(与池化层),三个全连接层(与激发函数)。

1.png
input -> (convolution -> acrivate function ->pooling) * 2 -> (fully-connection -> activate function) * 2 -> fully-connection -> output

定义网络

定义网路的类要继承自 torch.nn.Module,并且必须至少重写 forward 方法来定义网络结构,只要定义了forward函数,autogradbackward方法可以自动完成。torch.nn 模块中定义了很多定义网络的常用层、函数,而 torch.nn.functional 模块中则定义了很多网络中常用的函数,这里给出了一个定义LeNet的例子,其中用到了常用的卷积层、池化层、全连接层、激活函数等:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        # convolution layers
        self.conv1 = nn.Conv2d(1,6,3)
        self.conv2 = nn.Conv2d(6,16,3)

        # fully-connection layers
        self.fc1 = nn.Linear(16*6*6,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        # max pooling over convolution layers
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)

        # fully-connected layers followed by activation functions
        x = x.view(-1,16*6*6)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # final fully-connected without activation functon
        x = self.fc3(x)

        return x

net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

定义完成的网络,可以使用 parameters() 来获得其所有的参数:

parameters = list(net.parameters())
print(len(parameters))

10

向网络中输入值

定义完一个网络结果,接下来我们需要向网络中输入值,从而获得输出结果。
在PyTorch中,torch.nn仅支持mini-batches的类型,所以无法单独输入任何一个input,哪怕是一个输入也要包装成单个sample的batch,即batch-size设置为1。比如说上述网络的第一层是二维卷积从 nn.Conv2d,其实他接受的是一个四维tensor样本数×通道数×高×宽

对于一个单独的输入样本,可以通过使用 torch.Tensor.unsqueeze(dim) 或者 torch.unsqueeze(Tensor, dim) 实现。

  • torch.Tensor.unsqueeze(dim):为原来的tensor增加一个维度,返回一个新的tensor。其接受一个整数做参数,用于标示要增加的维度,比如0表示在第一维增加一维
  • torch.unsqueeze(Tensor, dim):将Tensor增加一个维度并返回新的tensor,第二个参数dim同上

下边是一个例子可以验证上述的方法,随机生成了一个 1×32×32 大小的tensor作为网络的输入,但是需要先提前将其包装成1个大小的batch(等同于直接生成一个随机的 1×1×32×32 大小的tensor):

x = torch.rand(1,32,32)
print(x.size())
y = x.unsqueeze(0)
print(y.size())
z = torch.unsqueeze(x,0)
print(z.size())

torch.Size([1, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])

将一个单元素的batch喂给我们的网络并获取输出的例子:

x=torch.rand(1,1,32,32)
out = net(x)
print(out)

tensor([[-0.1213,  0.0420, -0.0926,  0.0741,  0.0615, -0.1131,  0.0136, -0.0526,
         -0.0172,  0.0244]], grad_fn=<AddmmBackward>)

计算损失(Loss)

网络的训练需要基于loss,也就是网络预测值与标签真实值之间的差距。nn.Module中同样定义了很多损失函数(loss function),可以直接使用,比如这里使用的平方平均值误差(mean-squared error)MSELoss。

已知我们获得的对于输入x的网络预测值为out,然后生成一个随机的label值(目标值)target(这里和输入值需要保持一致,因此1×10表示1是batch-size,10才是单个标签的大小),计算两个值的损失:

x=torch.rand(1,1,32,32)
out = net(x)
target = torch.rand(1,10)

loss_function = nn.MSELoss()
loss = loss_function(out,target)

print(loss)

tensor(0.3597, grad_fn=<MseLossBackward>)

反向传播

有了loss之后,我们就要通过反向传播计算loss对于每一个参数的导数,很简单,使用 loss.backward() 即可,因为loss就是对于所有参数进行了一定的计算后得到的一个单标量的tensor,且在计算过程中追踪记录了所有的操作。在进行反向传播前,不要忽略了使用 net.zero_grad()所有参数的梯度缓存置0

net.zero_grad()

loss.backward()
print(net.conv1.bias.grad)

tensor([-0.0007,  0.0007, -0.0005,  0.0068,  0.0026,  0.0000])

更新参数

得到了每个参数的梯度,最后就是要更新这些参数,比如在随机梯度下降(Stochastic Gradient Descent,SGD)中的更新方法是:

weight_new = weights_old − learning_rate × gradient

直接写代码完成上述操作即:

lr = 0.01
for p in net.parameters():
    p.data.sub_(p.grad.data * lr)

当然,更好更快捷的方法就是使用PyTorch提供的包与已有的函数:使用 torch.optim 来完成,其中实现了很多常用的更新参数的方法,比如SGD,Adam,RMSProp等。使用optim中方法实例的step方法来进行一步参数更新

import torch.optim as optim

optimizer = optim.SGD(net.parameters(),lr =0.01)

optimizer.zero_grad()
out = net(x)
loss = loss_function(out,target)
net.zero_grad()
loss.backward()
optimizer.step()

重要参考索引

我们涉及到的三个重要的torch包/模块,它们其中提供了大量的对于神经网络的方法,它们完整的参考列表如下(强烈建议过一遍):

模块名 主要内容 参考链接
torch.nn 给出了大量神经网络的层、函数、损失计算方法、其他工具等,是最强大最重要的包 https://pytorch.org/docs/stable/nn.html
torch.nn.functional 给出了大量神经网络的层、函数、损失计算方法、其他工具等的函数实现 https://pytorch.org/docs/stable/nn.functional.html
torch.optim 实现了很多神经网络的参数更新方法(优化器) https://pytorch.org/docs/stable/optim.html
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,254评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,875评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,682评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,896评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,015评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,152评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,208评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,962评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,388评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,700评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,867评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,551评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,186评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,901评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,142评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,689评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,757评论 2 351

推荐阅读更多精彩内容