回顾神经网络分类任务的整体流程

# -*- coding: utf-8 -*-
# @Time    : 2020/3/9 23:33
# @Author  : zhoujianwen
# @Email   : zhou_jianwen@qq.com
# @File    : mnist_train.py
# @Describe: 回顾神经网络分类任务的整体流程


import torch
from torch import nn  # 神经网络库
from torch.nn import functional as F  # 常用函数
from torch import optim #  优化工具包

import torchvision  # 视觉工具包
from matplotlib import pyplot as plt  # 数据可示化工具包

from utils import plot_image, plot_curve, one_hot

# 解决Pycharm导入模块时提示“Unresolved reference”
# 在pycharm中设置source路径
# file–>setting–>project:server–>project structure-->选择python(工程名)-->点击Sources图标-->Apply即可
# 将放package的文件夹设置为source,这样import的模块类等,就是通过这些source文件夹作为根路径来查找,也就是
# 在这些source文件夹中查找import的东西。

batch_size = 512

# step1. load dataset
# Normalize 零—均值规范化也叫标准差标准化,mean:0.1307,std:0.3081,其转化公式s = (x - mean)/std,
# 特征标准化不会改变特征取值分布,只是为了保证参数变量的取值范围具有相似的尺度,以帮助梯度下降算法收敛更快。
# shuffle 将数据集随机打乱

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())

打印结果
torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(2.8215)

plot_image(x, y, 'image sample')

打印结果


plot_image

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        # xw+b , 其中256,64的数值都是由经验决定的,28*28输入的维度,10是一个分类值0-9
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1+b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3+b3
        x = self.fc3(x)
        return x



net = Net()
# [w1, b1, w2, b2, w3, b3],optimizer是一个优化器,更新参数值
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)


train_loss = []

for epoch in range(60):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [512]
        # [b, 1, 28, 28] => [b, 784],其中b是batchsize,28*28 => 784 可以看作是x_i的样本数据
        x = x.view(x.size(0), 28*28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)  # 获得代价函数的初始值

        optimizer.zero_grad()  # 在BP之前首先将梯度清零,以保证每次更新的负梯度值是最新的。
        loss.backward()  # 计算出梯度信息
        # w' = w - lr*grad
        optimizer.step()  # 更新参数信息

        train_loss.append(loss.item())  # 保存当前参数信息

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())  # 每训练完一个mini-batch就显示当前训练模型的参数状态

plot_curve(train_loss)  # 模型训练完毕,显示代价函数曲线收敛的走势
# we get optimal [w1, b1, w2, b2, w3, b3]  # 模型训练完之后会得到这一组最优参数解,使得loss值全局最小。

打印结果


train_loss

这里的loss值不是用来衡量模型的性能指标,只是用来辅助我们更好地训练模型,衡量模型的性能指标有很多种方法,最终
衡量模型的指标是它的准确度。
下面使用测试集对模型进行准确度测试。

total_correct = 0
for x,y in test_loader:
    x  = x.view(x.size(0), 28*28)
    out = net(x) # 输入测试样本数据x_i,预测出概率模型
    ''' 
    out: [b, 10] => pred: [b]  , 比如输出标签对应的预测概率为[0.1,0.9,0.01,......,0.08],∑P(y|x) = 1
    argmax获得预测概率最大元素所在的索引号,max=0.9,argmax(out)=[0,1,0,......,0],
    从而获得one-hot的预测编码
    若预测概率是out = [0.01,0.02,0.03,0.705,...,0.09],则 argmax(out) = [0,0,0,3,0,0,0,0,0,0],
    '''
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

打印结果
test acc: 0.9684

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

打印预测结果


plot_image
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,921评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,635评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,393评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,836评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,833评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,685评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,043评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,694评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,671评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,670评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,779评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,424评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,027评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,984评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,214评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,108评论 2 351
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,517评论 2 343

推荐阅读更多精彩内容