pytorch入门_网络架构/训练/保存提取

1. 用nn.Module搭建网络架构

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5) # 1 input image channel, 6 output channels, 5x5 square convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120) # an affine operation: y = Wx + b
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv2(x)), 2) # If the size is a square you can only specify a single number
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self, x):
        size = x.size()[1:] # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
net

'''神经网络的输出结果是这样的
Net (
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear (400 -> 120)
  (fc2): Linear (120 -> 84)
  (fc3): Linear (84 -> 10)
)
'''

  上面这段是LeNet的架构, 需要说明的是

1.1 重要的函数有两个, 一个是_init_(), 一个是forward()

  其他的可以根据需要再加, _init_()是用来定义网络中需要的模块, 比如基本的conv, fc的模块的大小都在这里面定义好, forward()是用来定义前向的架构, 就是顺序地堆叠模块就可以了, 这个和keras没有什么区别, 我觉得Pytorch的这套架构比较有整体性, 就是把一个模型封装在一个类里面, 包含module和architecture都用这个类体现出来了


1.2 在_init_()里开头的那一句

super(Net, self).__init__()

  这句的意思是_init_()函数继承自父类nn.Module
  super()函数是用于调用父类的一个方法, super(Net, self)首先找到Net的父类(就是nn.Module), 然后把类Net的对象转换为类nn.Module的对象, 具体的可以看一下
http://www.runoob.com/python/python-func-super.html


1.3 conv和fc模块的入参

self.conv1 = nn.Conv2d(1, 6, 5) # 1 input image channel, 6 output channels, 5x5 square convolution kernel
self.fc1   = nn.Linear(16*5*5, 120) # an affine operation: y = Wx + b

  conv的入参有三个, 分别是输入channel, 输出channel, 以及kernel的大小(5就代表5x5)
  fc的入参只有两个, 分别就是输入的结点数, 输出的结点数


1.4 torch.nn.functional as F

  这个是通常写法, 我们用到的激活函数, pool函数都在这个F里面(损失函数还有conv/fc等层在nn里面), 这个要注意一下, 那BN/dropout是在nn还是在F里面?


2. 如何开始训练

1中是把网络定义好了, 那离能够训练这个网络还有几步呢?我们需要想一下训练一个神经网络需要一些什么基本的步骤

  • 输入值格式(就是patch的格式)
  • Optimizer
  • loss fucntion
  • 梯度更新过程

2.1 输入值格式

  我们的应用都是图片输入, 那pytorch有没有keras那种直接一个文件夹输入作为训练集或测试集(ImageDataGenerator())的那种方法呢?是有的, 就是Dataloader()

import torchvision
import torchvision.transforms as transforms


# torchvision数据集的输出是在[0, 1]范围内的PILImage图片。
# 我们此处使用归一化的方法将其转化为Tensor,数据范围为[-1, 1]

transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, 
                                          shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
'''注:这一部分需要下载部分数据集 因此速度可能会有一些慢 同时你会看到这样的输出

Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting tar file
Done!
Files already downloaded and verified
'''

  这段代码就是用DataLoader()来生成patch的方法, 这个和keras的ImageDataGenerator()就很像了, 生成的trainloader是一个包含tensor的迭代器, 在训练的过程中就可以直接从迭代器里去除tensor, 然后装入Variable中作为神经网络的输入, 这就是输入值格式
  当然, 这只是一种简单的从数据集中得到batch的方法, 要是我们的数据是从某个文件夹来还得经过点pre-process的话(比如SR)会更复杂一些, 我们会在SR benchmark笔记里面再详细说

2.2 Optimizer

import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr = 0.01)

# in your training loop:
optimizer.zero_grad() # zero the gradient buffers
optimizer.step()

  Optimizer的定义和Keras差不多, 可以用参数定义lr, momentum, weight_decay等, 这个去查文档就好, optimizer.zero_grad()的意思是初始化所有的梯度buffer为0, 这个是在每一次计算梯度更新之前做的, optimizer.step()就是根据你定义的optimizer执行梯度更新


2.3 loss function

import torch.optim as optim
# make your loss function
criterion = nn.MSELoss()

# in your training loop:
output = net(input)
loss = criterion(output, target)
loss.backward()

  criterion定义了loss function, 然后在训练的loop中loss就根据criterion不停去算, loss是一个Variable, 它具有backward属性, 就是在训练过程中可以直接用.backward来计算loss对每个weight的梯度值


2.4 训练过程

  有了输入, loss function和Optimizer, 我们就可以开始进行训练了, 过程就是从迭代器trainloader把input tensor送入网络, 然后算output, 根据loss function算loss, 从loss再反过去算梯度, 根据Optimizer去更新权重

for epoch in range(2): # loop over the dataset multiple times
    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()        
        optimizer.step()
        
        # print statistics
        running_loss += loss.data[0]
        if i % 2000 == 1999: # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

2.5 神经网络的快速搭建方法

  除了上面提到的搭建神经网络的方法之外, pytorch还提供了另一种更快速的搭建方法, 有点类似于Keras的Sequentiao模型
http://keras-cn.readthedocs.io/en/latest/getting_started/sequential_model/

net = torch.nn.Sequential(
  torch.nn.Linear(2,10),
  torch.nn.ReLU,
  torch.nn.Linear(10,2),
)
# fast implementation of FC 2->10->2

3. 保存和提取训练结果

  保存和加载网络模型有两种方法, 一种是把模型和参数一起save(相当于keras的save()), 还有一种就是只save参数(相当于keras的save_weight())

# 保存和加载整个模型  
torch.save(model_object, 'model.pth')  
model = torch.load('model.pth')  
 
# 仅保存和加载模型参数  
torch.save(model_object.state_dict(), 'params.pth')  
model_object.load_state_dict(torch.load('params.pth')) 

  当网络比较大的时候, 用第一种方法会花比较多的时间, 同时所占的存储空间也比较大
  还有一个问题就是存储模型的时候, 有的时候会存成.pkl格式, 应该是没有本质区别, 都是Pickle格式, 后续在用C来读取网络的时候, 也从Pickle的C实现来考虑直接解析Pytorch的模型
https://www.zhihu.com/question/274533811

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,332评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,508评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,812评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,607评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,728评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,919评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,071评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,802评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,256评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,576评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,712评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,389评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,032评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,026评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,473评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,606评论 2 350

推荐阅读更多精彩内容

  • Android 自定义View的各种姿势1 Activity的显示之ViewRootImpl详解 Activity...
    passiontim阅读 171,832评论 25 707
  • 原文链接:https://yq.aliyun.com/articles/178374 0. 简介 在过去,我写的主...
    dopami阅读 5,650评论 1 3
  • 116 日记星球的第八期明天开营,本想不再参加复训的。大家常说一个人走不动,一群人走的远。一个人有时会有懒的心里,...
    馨之芬芳阅读 126评论 0 0
  • 2018年5月8日 星期二 天气 晴 今天是写亲子日记的第四天,说心里话自己不擅长表达和写作,...
    陈煜越爸爸阅读 314评论 0 4
  • 我们不合适,我爱吃土豆,你不爱。 我们不合适,我睡觉打呼,你不打。 我们不合适,我喜欢女孩,你喜欢男孩。 我们不合...
    Spidercc阅读 109评论 0 1