numpy实现一个神经网络识别手写数字数据集

前言

机器学习在安全行业的应用非常广,这就要求我们在深耕自己细分领域的同时还应该广泛涉猎;对机器学习相关基础知识,数据分析基本概念有所掌握。提起机器学习,大家想到的可能就是各大知名框架,大家对框架的选择也都各有所爱,得益于框架的良好封装,自己快速搭建一个机器学习网络并不是什么难事。使用框架进行编程的第一句都免不了import某个框架,使用过程变成了无聊的“调参侠”(仅对低级使用者而言,专业的算法工程师还是有很强的业务能力),这就导致我们常常会忽视算法的本质,本次使用numpy实现一个ANN来帮助理解机器学习相关概念,一个框架所有的部分都可以使用numpy实现。真正完成一次实验会对机器学习的本质有更加深入的认识。

背景

本次使用最简单的MNIST进行实验,使用numpy实现一个可以识别手写数字集的人工神经网络。

Mnist:大多数示例使用手写数字的MNIST数据集[1]。该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。为简单起见,每个图像都被平展并转换为784(28 * 28)个特征的一维numpy数组。

首先使用简单的keras框架构建一个网络进行手写数字的识别,网上相关资料很多,不再赘述,这里给出一份代码。

import numpy as np
import keras
import pandas as pd
from keras import layers
from matplotlib import pyplot as plt
from keras.datasets import mnist as mn

%matplotlib inline

# 读取训练数据和测试数据
(train_img, train_lab), (test_img, test_lab) = mn.load_data()
model = keras.Sequential()
model.add(layers.Flatten()) # (60000, 28, 28) => (60000, 28*28)
model.add(layers.Dense(64, activation='tanh'))
model.add(layers.Dense(10, activation='softmax'))

# 编译模型
model.compile(
    optimizer="adam",
    # 注意因为label是顺序编码的,这里用这个
    loss='sparse_categorical_crossentropy',
    metrics = ['accuracy']
)

# 模型结构
model.summary()

# 使用history保存每个epoch结束的loss,accuracy等信息
history = model.fit(train_img, train_lab, epochs=10, batch_size=500, validation_data=(test_img, test_lab), verbose=2) # 每批500张图片

# 保存模型
model.save('keras_mnist.h5')

可视化训练过程,使用history保存的信息画出折线图。

plt.plot(history.history['val_accuracy'], c='g', label='validation acc')
plt.plot(history.history['accuracy'], c='b', label='train acc')
plt.legend()
plt.show()

train_acc.png
plt.plot(history.history['val_loss'], c='g', label='validation loss')
plt.plot(history.history['loss'], c='b', label='train loss')
plt.legend()
plt.show()
train_loss.png

使用模型进行预测

# 加载训练的模型
from keras.models import load_model
model = load_model("model_name.h5")

result = model.predict(test_img)
def show_test(index):
    plt.imshow(test_img[index],cmap='gray')
    print("label : {}".format(test_lab[index]))
    print("predict : {}".format(result[index].argmax()))
    
index = np.random.randint(1, len(test_img))
show_test(index)
keras_predict.png

上面部分代码就是使用keras实现的过程,非常简单。接下来进入主题,可以对比一下手工实现和使用框架实现的区别。

首先需要明确的输入输出的维度,输入维度很简单,像素是28*28的,我们把每行的数据拼接起来,一张图片的维度就是28*28=784维的向量,输出的是0-9的10维向量,

输入: 784  
输出: 10  

由此我们构建一个最简单的只有一个隐藏层的神经网络。神经元之间采用最简单的线性连接。下面的公式就是一个最简单的神经网络,后续的全部工作就是实现这两个公式。

data是输入的图片,维度是[1,784];output是输出的预测结果维度是[1,10];A和B都是激活函数;h是隐藏层;
\vec{h} = A(data + b_0)

\vec{output} = B(\vec{h}w_1 + b_1)

根据输入输出确定其他参数的维度,只有参数维度不出问题才能确保下面的流程正确进行

b0和data同形,b0维度也是[1,784],

所以h也是[1,784]维

b1和output同形,b1维度也是[1,10]

根据矩阵乘法w1是[784,10]

将上面分析的结果带入原公式中,确认分析没毛病。

data:784
output:10
b_0:784
h:784
w_1:[784, 10]
b_1:10

[1,784] = [1,784] + [1,784]
[1,10] = [1,784][784,10] + [1,10]

确认参数维度之后我们开始初始化公式中的参数。

导入必要的包

import math
import copy
import numpy as np
import matplotlib.pyplot as plt
# 定义参数的维度
dimensions=[28*28,10]
activation=[tanh,softmax]
distribution=[
    {'b':[0,0]},
    {'b':[0,0],'w':[-1,1]},
]
#实现初始化参数
def init_parameters_b(layer):
    dist=distribution[layer]['b']
    return np.random.rand(dimensions[layer])*(dist[1]-dist[0])+dist[0]
def init_parameters_w(layer):
    dist=distribution[layer]['w']
    return np.random.rand(dimensions[layer-1],dimensions[layer])*(dist[1]-dist[0])+dist[0]
def init_parameters():
    parameter=[]
    for i in range(len(distribution)):
        layer_parameter={}
        for j in distribution[i].keys():
            if j=='b':
                layer_parameter['b']=init_parameters_b(i)
                continue
            if j=='w':
                layer_parameter['w']=init_parameters_w(i)
                continue
        parameter.append(layer_parameter)
    return parameter
parameters=init_parameters()

这样parameters就是我们初始化成功的参数,验证参数生成是否正确

#测试参数生成
import tensorflow as tf
print(tf.shape(parameters[0]['b']))
print(tf.shape(parameters[1]['b']))
print(tf.shape(parameters[1]['w']))
parameters

输出如下,可以看到输出和我们的预期一致。训练神经网络模型的过程实际上就是找到合适的参数的过程,等训练结束之后可以比较一下parameters的变化。

tf.Tensor([784], shape=(1,), dtype=int32)
tf.Tensor([10], shape=(1,), dtype=int32)
tf.Tensor([784  10], shape=(2,), dtype=int32)

[{'b': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.])},
 {'b': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
  'w': array([[-0.11063313,  0.01986579, -0.54817987, ...,  0.06447563,
           0.55723463, -0.36988999],
         [ 0.93507145, -0.13798472, -0.68732584, ...,  0.04820882,
          -0.33476673, -0.69842804],
         [-0.22759392, -0.61509861, -0.93002526, ..., -0.45658224,
          -0.4769593 , -0.68456901],
         ...,
         [ 0.28200054, -0.92222148, -0.16388762, ...,  0.95929227,
           0.60109395,  0.84298182],
         [ 0.22296947, -0.4467861 , -0.65828542, ..., -0.10003993,
           0.29943921, -0.46707877],
         [ 0.01722276,  0.04571887, -0.87339843, ..., -0.03931738,
          -0.36247935,  0.61174093]])}]

下一步我们实现公式中的两个激活函数A和B

这里A函数使用tanh()做激活函数,这里也可以采用其他函数,因为本次需要手动计算梯度等操作,所以选择一个比较简单的函数方便计算。B函数softmax()用作分类

#定义需要的两个激活函数
def tanh(x):
    return np.tanh(x)
def softmax(x):
    exp=np.exp(x-x.max())
    return exp/exp.sum()

这里exp=np.exp(x-x.max())的操作是为了防止指数爆炸引起上溢,如果计算np.exp(1000)会报RuntimeWarning: overflow encountered in exp的错误,将每一位数都减去这组数中最大的那一位,是不会对结果产生影响的(高中知识)

计算一组测试数据softmax(np.array([1,2,3])),和softmax(np.array([-2,-1,0]))运行结果都如下

array([0.09003057, 0.24472847, 0.66524096])

接下来定义预测函数,就是output = B(h w_1 + b_1)这部分,给定一个784的输入,给出一个10维的输出。下面的简单代码就构建完成了一个神经网络,接下来我们要做的就是利用梯度下降更新parameters,让这个神经网络表现出更好的分类效果。

#定义预测函数
def predict(img,parameters):
    l0_in=img+parameters[0]['b']
    l0_out=activation[0](l0_in)
    l1_in=np.dot(l0_out,parameters[1]['w'])+parameters[1]['b']
    l1_out=activation[1](l1_in)
    return l1_out

运行预测函数输出如下:

# 测试预测函数
# predict(np.random.rand(784),parameters)

array([1.31779086e-06, 2.43796927e-07, 2.34876198e-10, 4.46492815e-02,
       5.90359130e-02, 8.78048316e-02, 1.67109396e-02, 2.06152860e-10,
       7.91797467e-01, 5.26289356e-09])

读取数据并展示数据

# 读取数据集
from pathlib import Path
import struct
dataset_path=Path('./MNIST')
train_img_path=dataset_path/'train-images.idx3-ubyte'
train_lab_path=dataset_path/'train-labels.idx1-ubyte'
test_img_path=dataset_path/'t10k-images.idx3-ubyte'
test_lab_path=dataset_path/'t10k-labels.idx1-ubyte'

# 5w训练集,1w验证集,1w测试数据集
train_num=50000
valid_num=10000
test_num=10000

with open(train_img_path,'rb') as f:
    struct.unpack('>4i',f.read(16))
    tmp_img=np.fromfile(f,dtype=np.uint8).reshape(-1,28*28)/255
    train_img=tmp_img[:train_num]
    valid_img=tmp_img[train_num:]
    
with open(test_img_path,'rb') as f:
    struct.unpack('>4i',f.read(16))
    test_img=np.fromfile(f,dtype=np.uint8).reshape(-1,28*28)/255

with open(train_lab_path,'rb') as f:
    struct.unpack('>2i',f.read(8))
    tmp_lab=np.fromfile(f,dtype=np.uint8)
    train_lab=tmp_lab[:train_num]
    valid_lab=tmp_lab[train_num:]
    
with open(test_lab_path,'rb') as f:
    struct.unpack('>2i',f.read(8))
    test_lab=np.fromfile(f,dtype=np.uint8)

def show_train(index):
    plt.imshow(train_img[index].reshape(28,28),cmap='gray')
    print('label : {}'.format(train_lab[index]))
def show_valid(index):
    plt.imshow(valid_img[index].reshape(28,28),cmap='gray')
    print('label : {}'.format(valid_lab[index]))
def show_test(index):
    plt.imshow(test_img[index].reshape(28,28),cmap='gray')
    print('label : {}'.format(test_lab[index]))

读取数据和测试数据的部分就完成了,测试数据集中的数据是不会在训练集中出现的,这就好像把高考原题给你去训练,有可能你是仅仅背会了这个题目而不是真正的理解了如何去分析解题,所以高考基本不会出现原题,这样才能考察出你对知识的理解掌握程度。

# 测试show_train
show_train(np.random.randint(train_num))
test.png

接下来就是更新参数的部分,参数我们都是随机生成的,自然神经网络给出的预测结果也是随机的,我们如何去评价预测的结果和真实结果的差异呢,这种评价机制就是损失函数。理想状态下神经网络给出的预测应该是这样的:比如这个图片真实的标签是1,神经网络预测这个图片是1的概率为100%,为其他数字的概率都是0。然而显示中神经网络达不到这样的效果,更多时候都是“人工智障”,我们将所有预测中概率最大的那个作为神经网络的预测结果。

# 定义损失函数
onehot=np.identity(10)
def sqr_loss(img, lab, parameters):
    y_pred=predict(img,parameters)
    y=onehot[lab]
    diff=y-y_pred
    return np.dot(diff,diff)

定义完损失函数之后,我们只需要让损失函数持续的减少,就能向理想化的预测结果不断靠近,究竟如何进行参数的更新,这里就要使用梯度,大学高数告诉我们梯度方向是上升最快的方向,负梯度方向为下降最快的方向,通过往梯度(gradient)下降的方向调整参数,逐步减小损失函数loss function的值,从而得到训练好的模型。

#定义两个激活函数的导数
def d_softmax(data):
    sm=softmax(data)
    return np.diag(sm)-np.outer(sm,sm)

def d_tanh(data):
    return 1/(np.cosh(data))**2

differential={softmax:d_softmax,tanh:d_tanh}

def grad_parameters(img,lab,parameters):
    l0_in=img+parameters[0]['b']
    l0_out=activation[0](l0_in)
    l1_in=np.dot(l0_out,parameters[1]['w'])+parameters[1]['b']
    l1_out=activation[1](l1_in)
    
    diff=onehot[lab]-l1_out
    act1=np.dot(differential[activation[1]](l1_in),diff)
    
    grad_b1=-2*act1
    grad_w1=-2*np.outer(l0_out,act1)
    grad_b0=-2*differential[activation[0]](l0_in)*np.dot(parameters[1]['w'],act1)
    
    return {'w1':grad_w1,'b1':grad_b1,'b0':grad_b0}

输入图片和标签和初始化的参数后,就会计算梯度

# 测试梯度计算
# grad_parameters(train_img[2],train_lab[2],init_parameters())

{'w1': array([[-0., -0., -0., ..., -0., -0., -0.],
        [-0., -0., -0., ..., -0., -0., -0.],
        [-0., -0., -0., ..., -0., -0., -0.],
        ...,
        [-0., -0., -0., ..., -0., -0., -0.],
        [-0., -0., -0., ..., -0., -0., -0.],
        [-0., -0., -0., ..., -0., -0., -0.]]),
 'b1': array([-8.97529566e-04, -1.57305145e-04, -1.13258721e-05,  1.51856573e-02,
        -3.34573808e-03, -8.80455733e-04, -6.32826250e-03, -3.27152101e-03,
        -2.34009828e-04, -5.95096104e-05]),
 'b0': array([-8.35839128e-03, -5.34125889e-03, -1.37208082e-02,  5.61553159e-03,
         1.95620702e-02,  4.82436064e-03, -1.81148739e-02,  7.06783791e-03,
         6.77894581e-03, -6.20645121e-05, -1.03271620e-02,  8.64674933e-03,
         1.20581038e-02, -1.66106477e-02,  1.42724550e-03, -2.10642636e-03,
         5.17487744e-03,  6.77834056e-03,  5.23595222e-03,  9.60313306e-03,
         5.52857690e-03,  1.52551626e-02,  1.26690528e-02, -9.09700995e-03,
        -7.05989943e-03, -1.50500067e-03,  1.54893607e-02,  9.42818548e-03,
        -6.14107702e-03, -5.62287700e-03, -1.77016958e-02, -1.08695394e-02,
        -4.73702194e-03,  9.69312031e-04,  3.38897299e-05,  9.68205152e-03,
         1.40352488e-02,  1.28429731e-02,  8.39480940e-04,  1.05712592e-02,
        -1.73710289e-02, -9.64789484e-03,  7.18972793e-03,  1.03710320e-02,
        -3.68373317e-03,  1.56242949e-03,  4.60410099e-03, -1.41165218e-02,
        -1.05017685e-02, -1.69197853e-02,  1.14012162e-02, -3.17723825e-03,
         5.86251639e-03, -4.55389491e-03, -6.27766397e-04, -2.45242403e-03,
         3.32728074e-03,  9.66014811e-03,  8.54219077e-03, -2.63156769e-03,
         7.90602741e-03,  5.91128736e-04, -4.63103323e-03,  7.27767514e-03,
         2.20731150e-03,  1.65127510e-02,  5.15687997e-03, -4.42550832e-04,
         1.57667838e-03,  4.46071508e-03,  5.84473192e-03,  9.07912645e-03,
         2.09107272e-02,  2.36699460e-02,  1.28038267e-02, -1.19544611e-02,
         1.16261562e-02, -1.62204299e-02,  3.32573066e-04,  1.43708779e-02,
        -1.15140572e-02, -1.62038995e-02, -6.64056214e-03, -2.68635750e-03,
        -4.15420403e-03,  1.43767043e-02, -3.52162151e-03,  7.34874136e-03,
        -1.50134966e-02, -1.85408493e-03, -5.97253728e-03, -7.70481837e-03,
         1.85042403e-03,  1.69247570e-03,  4.48173562e-03,  3.94761096e-03,
        -6.09264268e-03,  1.84609092e-02,  9.15935735e-03, -1.60288369e-02,
         2.47151419e-03, -1.75132583e-03,  2.15554612e-02, -1.08568207e-02,
        -5.38923467e-03,  1.63112882e-02,  1.36774969e-03, -1.01017574e-03,
        -9.67172656e-03, -2.05714170e-02,  1.31786779e-02, -1.20454389e-02,
         9.62068604e-03,  1.20974865e-02,  1.74756972e-02,  3.63575259e-03,
        -1.33278364e-02, -8.15359725e-03,  3.03751796e-03, -8.62243956e-03,
        -3.93034010e-03, -7.80119830e-03,  4.28420827e-03,  3.99100767e-03,
        -6.96652322e-03,  9.98542524e-03,  1.32837396e-02, -1.64502311e-02,
        -2.45694320e-03,  2.86176292e-03, -4.70804440e-03,  7.45219664e-03,
        -7.93939346e-03,  9.73363380e-04, -4.65506787e-03,  5.09916398e-03,
        -3.79960523e-03,  2.49196922e-02, -3.19170301e-03, -8.78287071e-04,
         8.60141093e-03, -5.18163659e-03, -7.67823469e-03,  9.12309830e-03,
         5.69992325e-03, -9.61483330e-03,  5.34989829e-04,  1.34362523e-03,
         1.48173676e-02,  1.01570384e-02, -4.22456962e-03, -6.73901076e-03,
         1.12639146e-02, -4.97857054e-04,  7.50129527e-03, -1.62759133e-04,
         2.66076235e-03, -8.20197284e-03,  9.08791485e-03, -1.17750589e-02,
         4.10305929e-03,  4.44696391e-03,  7.20534143e-03,  1.27302774e-02,
         2.09109442e-02, -1.17840327e-02,  1.68535551e-02, -5.45905767e-03,
         1.51291239e-02, -8.07673020e-03,  9.48311248e-03, -9.34733696e-05,
         1.01434332e-02, -3.11975379e-03,  5.63583825e-03,  5.75590337e-03,
         1.14765621e-02, -1.65725199e-02, -1.81781627e-02, -4.29306537e-03,
         8.29402330e-03,  6.48925555e-03,  9.36958945e-03, -4.97385263e-03,
         3.27683831e-04,  7.87043527e-03, -1.19547598e-02, -1.34038752e-02,
         6.52915908e-03,  1.44435525e-03,  8.83738335e-03,  7.74306472e-03,
        -1.00514819e-02,  5.32062237e-03,  8.15320109e-03,  1.27634258e-02,
        -1.06853142e-02, -2.37312411e-03, -1.46285969e-02,  1.33049237e-02,
         1.30542203e-02, -2.52723104e-03,  8.27540521e-03, -1.02999232e-03,
        -1.44753332e-02,  5.52333176e-03,  6.69350221e-03, -1.58433410e-02,
         7.49662163e-03, -7.04779382e-03, -1.22842597e-02, -1.84830282e-03,
         1.39855218e-02,  6.31352913e-03, -1.48297463e-02,  3.38934718e-03,
        -6.62877874e-03,  5.75719349e-03,  1.44137908e-02,  1.95541441e-02,
         8.10794967e-03, -4.58825637e-03, -6.31258088e-03,  8.76559737e-03,
        -1.49603323e-02, -9.71335649e-03, -3.96255520e-03, -2.12043313e-03,
        -3.12815557e-03,  5.99221757e-03,  5.43202946e-03, -9.03813632e-04,
        -9.52207459e-03, -1.38406399e-02,  1.63812992e-02,  1.43130646e-02,
        -6.26658585e-03, -1.84001540e-02,  8.79091617e-03,  9.12318712e-03,
        -3.84151764e-03,  7.49930864e-03, -2.33399185e-03, -1.34922235e-02,
        -1.20013433e-03, -2.45301294e-03,  1.87235962e-02,  6.47738169e-03,
         9.04561565e-03, -1.78941246e-03, -2.89415340e-03,  2.10333201e-03,
        -6.02399020e-04, -4.35288849e-03,  1.29300391e-02, -3.85006484e-03,
         7.31630507e-04, -6.38921249e-03, -5.93935905e-05, -1.51483706e-02,
        -1.28440162e-03, -1.43829370e-02, -3.33054898e-03, -2.97546136e-03,
        -6.58845675e-03, -6.52283492e-03, -7.70284688e-03, -1.36712024e-02,
        -8.54573417e-03, -9.90562153e-03, -6.22076017e-03, -9.82411806e-03,
         1.45459409e-03,  7.68754545e-03,  1.32780017e-02,  9.00071756e-04,
        -4.74501190e-03,  4.00242229e-03, -3.53622701e-03,  7.31449850e-03,
        -9.75946834e-03,  7.65593601e-03,  6.07010229e-03, -1.72909057e-03,
        -1.69049170e-03, -2.10693899e-03, -7.90723651e-03, -7.45393514e-03,
         4.24886363e-03,  7.32627994e-03,  5.10599485e-03, -1.76507662e-02,
        -5.49388863e-03,  1.46263929e-02,  1.18535559e-02,  1.26915167e-03,
         6.46282523e-04, -1.42360909e-02,  7.37329855e-03, -6.26288989e-05,
         2.43055218e-03,  8.92408413e-03,  1.24414497e-02,  6.55988247e-03,
         9.37842213e-03,  7.08007913e-04, -7.35278334e-03,  5.28320647e-03,
         8.88771246e-03,  1.39800623e-02,  1.26165278e-02,  4.43142767e-03,
         3.47184004e-03, -1.04267387e-02,  1.03648557e-02,  5.17463675e-03,
         1.11589419e-02,  1.46939647e-02,  1.61697670e-02, -1.12996099e-02,
         6.94902080e-03,  1.48351031e-02, -2.95375616e-04, -1.86238982e-03,
        -1.02235549e-02,  5.61529211e-03,  5.47667190e-03, -1.27005573e-03,
         5.36429217e-03, -1.51237814e-02, -1.86295636e-02,  7.33730686e-03,
        -1.85398513e-02,  1.20012698e-02,  1.38622315e-02,  1.48138021e-02,
         4.54187744e-03,  8.56565364e-03, -4.50984784e-03, -4.16407036e-03,
        -3.78523864e-04,  5.17276764e-03, -1.07159807e-03, -5.34147242e-03,
        -4.64290015e-04, -2.55771720e-02,  2.87195933e-03, -5.38110333e-03,
         1.17607976e-02, -5.48219436e-03, -8.98229736e-03,  1.72540309e-02,
         3.34464763e-03, -1.02802491e-02,  6.04949772e-03, -4.29102799e-04,
         3.03261484e-03,  2.72251956e-03,  1.06288529e-02, -9.31323588e-04,
         3.63547515e-03,  1.11205136e-02,  8.68301853e-03, -1.52172758e-02,
         7.28648699e-03, -3.12303012e-03, -1.79116254e-02,  6.46336746e-04,
        -6.21691734e-03,  1.98643930e-02, -1.00060605e-02, -1.62491068e-03,
         9.47750036e-03,  2.76323114e-03, -6.58888517e-03, -2.66808567e-03,
         8.18821416e-03, -3.26934196e-03,  6.36322084e-03, -9.80971793e-03,
         1.43097690e-02, -3.24592005e-03, -2.93351502e-03,  4.66444163e-03,
        -1.04971944e-02,  5.32184795e-04,  1.22446871e-02,  3.68558230e-03,
         2.03293447e-03, -1.14694715e-02,  1.11611296e-02,  6.78077506e-03,
         9.15919965e-03, -6.48321204e-03,  9.14721124e-03,  9.81069558e-03,
         5.46416759e-04,  2.83696708e-03, -6.18660808e-03,  5.23040873e-03,
        -9.43232269e-03,  5.01350486e-04, -7.80010012e-03, -6.11710888e-03,
         5.13209665e-03, -4.63020314e-03, -1.83953692e-03, -4.59957165e-03,
        -1.81975961e-03,  1.13762265e-02,  5.89661920e-03,  3.79557880e-03,
        -7.71640983e-03, -1.23142688e-02,  5.73868261e-03,  1.34087970e-02,
         4.29147159e-03,  9.57904998e-03,  1.07044304e-02,  1.99221102e-02,
        -2.61706503e-03,  1.45577671e-02,  2.77056800e-03,  4.78140880e-03,
         1.20878142e-03,  3.42858245e-03, -1.69856130e-03,  7.83897376e-04,
        -6.53504210e-04,  4.18284416e-03, -9.40530680e-05, -1.60026261e-04,
        -3.15430401e-03, -4.17154650e-03,  1.54810025e-03, -7.50441948e-03,
         9.89481461e-05, -1.41143581e-02,  2.40869300e-03, -1.51957209e-03,
         1.16327294e-02, -9.55667123e-03,  1.44912348e-02, -8.35596990e-03,
         7.95354718e-03,  5.12660372e-03,  1.46623119e-02,  4.65114541e-03,
         1.74226145e-03, -5.67281482e-03, -5.50606280e-03,  4.15548027e-03,
         4.67841941e-03,  5.75094804e-03,  2.76135746e-03, -5.62818508e-03,
        -7.30893481e-04, -6.00692150e-03,  2.36115112e-03,  1.62798089e-02,
         4.02249298e-04,  1.47418653e-02,  7.20601078e-03,  1.14021295e-02,
         1.44496676e-02, -1.32249960e-02,  2.15247121e-03,  4.09362355e-03,
         1.45143285e-02,  1.31540261e-02,  1.55092259e-02, -4.18210240e-03,
         2.64852642e-03, -4.82464320e-03, -4.67698373e-03,  3.50518964e-03,
        -1.36248164e-02, -7.11345454e-03,  4.60594167e-03,  2.63389543e-03,
         1.34032875e-03, -1.29473325e-03, -3.35849685e-03,  1.67917512e-03,
         1.10303519e-02, -3.11467239e-05,  6.49283788e-03, -7.81435880e-03,
         1.76452284e-02,  1.04238370e-02, -4.69215195e-03,  5.81979544e-04,
        -1.98356723e-02, -5.70613891e-03,  8.49987248e-04,  1.68373729e-02,
         1.64191887e-02,  1.04608380e-02, -7.32840721e-03,  2.85896382e-03,
        -2.01301614e-04, -1.02907026e-02, -1.04965167e-02, -5.52796873e-04,
        -1.20299484e-04, -1.09925685e-03,  5.85160937e-03,  2.06569223e-02,
        -1.16742886e-02, -8.26894289e-03,  1.44519510e-02,  8.68505776e-03,
         1.19268832e-02,  5.26283060e-03,  2.20560264e-02,  1.00600326e-02,
        -1.88134070e-03, -1.69761724e-02, -6.26756005e-03, -3.38911679e-04,
         5.13680542e-03, -1.95889890e-03, -2.17098912e-03, -3.15865206e-03,
         1.47845344e-04,  7.96503637e-03, -6.92480846e-03,  3.70510666e-03,
         7.66695055e-03,  6.70358506e-03, -6.96085158e-03,  1.53629095e-02,
         1.86723844e-02, -7.97768542e-03,  1.13408760e-03,  1.64453412e-02,
         7.86555740e-03,  8.11357602e-03, -3.97513230e-03, -5.54770721e-03,
         1.86414221e-02,  1.63247283e-02, -7.57573261e-04,  3.04783372e-03,
         6.02811064e-03,  2.72770615e-04,  3.29787225e-03,  1.78250017e-02,
         8.11287070e-03,  1.23290614e-02,  3.27188272e-03, -8.57967903e-03,
         1.60769195e-02,  7.16771498e-03,  1.16647724e-02, -6.28664396e-03,
        -1.42586547e-02,  5.47605028e-03,  1.14969641e-03,  5.84922870e-04,
        -3.94756512e-03,  1.31725191e-02,  2.90039058e-03, -1.25577523e-02,
        -9.78917120e-03,  8.50944167e-03, -1.33852891e-02, -1.10724119e-03,
        -6.13591574e-03, -1.12463163e-02, -1.19716727e-02,  6.45789325e-03,
        -1.26230884e-02, -9.51967853e-03, -2.75361966e-03, -9.92746131e-03,
         1.45935455e-02,  4.65504099e-03,  3.31550071e-03, -1.04595145e-02,
         5.93363274e-03, -3.81389867e-04,  9.52299641e-03,  1.44994018e-02,
        -4.47748660e-03,  2.85882044e-03,  5.52815415e-03,  1.88440260e-02,
         1.69711364e-02, -7.93964903e-03, -9.42645940e-03,  1.36986903e-03,
        -2.12047186e-02, -9.85594315e-03, -1.25158026e-02,  9.31486036e-03,
        -1.52051804e-02,  1.69511449e-02, -1.91137863e-03, -1.35500385e-02,
         2.55165368e-03, -2.51141096e-02,  9.62349215e-03, -1.12377955e-02,
         9.98116223e-03, -1.07760556e-02, -1.69691111e-03,  5.61699226e-03,
        -1.22563609e-02,  5.54497848e-03, -4.46853532e-03, -1.39720663e-02,
         9.35346776e-03,  1.40185253e-02,  4.24760780e-03,  1.36261429e-02,
         1.79470501e-03,  1.41545587e-02,  6.32130317e-03, -1.15250463e-02,
        -8.54878402e-03, -1.64811988e-02,  1.90066621e-02,  1.54530384e-02,
        -5.55854794e-03,  1.26907388e-02,  5.39264936e-03, -1.06719618e-02,
        -1.22603262e-03, -5.33742688e-03, -1.26367239e-02,  1.92253332e-02,
        -2.95910850e-04,  8.77304966e-03, -5.68655143e-04,  1.38931266e-02,
         1.11144006e-02,  3.54186002e-04, -1.15603241e-03, -1.39468960e-02,
         1.29550223e-02,  1.13014075e-02,  1.31172510e-02,  2.47826652e-03,
        -9.42405051e-03, -2.58612186e-03, -5.29341447e-03, -1.04336783e-02,
         3.66736635e-03, -1.49207100e-02,  6.48964966e-03,  1.18624084e-03,
        -1.19865255e-02, -2.30032940e-02, -1.21192025e-02,  5.45466316e-04,
         1.08917380e-02, -5.32271281e-03, -1.62339706e-02,  5.28767036e-03,
         1.84310691e-03,  3.62564351e-03, -2.19418169e-04, -3.90339853e-03,
        -1.18366759e-02, -2.01445414e-02,  3.76155440e-03, -2.00397348e-03,
        -1.93908225e-04,  8.26558290e-03,  2.37206952e-03, -1.41144863e-02,
         2.40565301e-03, -8.41398971e-03,  4.25295067e-03,  6.57405357e-04,
        -1.38303645e-02,  8.42248375e-03, -3.40463762e-03, -5.49103642e-03,
        -8.01241591e-03, -2.25115870e-03, -6.17287000e-03, -1.46935803e-04,
        -1.63164622e-02,  1.60228803e-02,  1.76595427e-02, -1.12082168e-02,
         9.94637585e-04, -4.64384207e-03,  7.33769878e-03,  4.14695297e-03,
        -6.10681023e-03,  1.80445759e-02, -1.96333718e-03, -1.05458184e-02,
        -3.70002702e-03,  1.28312342e-02,  6.20129990e-03, -8.84648393e-03,
         5.92371651e-03,  2.55170845e-03,  1.52498118e-02,  1.04301101e-02,
        -1.03409919e-02, -1.58708793e-03, -3.73517551e-03, -1.30893179e-03,
        -7.22263076e-03,  1.13466126e-02, -7.20440767e-03, -1.12978988e-02,
         6.42180544e-03,  1.81445550e-02, -1.46089354e-02,  1.57596778e-02,
         1.01798202e-02,  6.24328365e-03,  7.05300350e-03, -8.92576614e-03,
        -5.20019107e-03, -2.01859488e-02, -5.75730127e-03, -7.83280726e-03,
         5.41992041e-03,  3.54268513e-03, -5.70186345e-03,  7.87650572e-03,
        -1.77062060e-02,  1.29850863e-03, -3.31465590e-03,  1.91965326e-02,
         1.79732193e-02, -6.32151857e-03,  7.32256338e-03, -2.07217350e-03,
        -1.50998518e-02, -1.82906671e-03, -5.47495891e-03, -2.03257958e-03,
        -4.06861838e-03, -7.15148229e-03, -3.64476684e-03,  6.81016866e-03,
        -1.38613009e-03,  5.10535689e-03,  1.06649154e-02, -4.20675519e-03,
         5.27032997e-03, -1.32378630e-02,  1.11323960e-02,  8.35729477e-03,
        -1.06519393e-02, -8.21342240e-03, -1.10285836e-02,  9.85257533e-03,
        -4.20384563e-03,  1.13903373e-02,  9.16708088e-03, -1.17642999e-03,
         6.08050951e-03, -9.38422557e-03,  6.79306611e-03, -2.19940829e-03,
        -1.08774752e-02, -9.61702798e-03,  8.93264036e-03, -5.04352703e-03,
        -1.01352665e-02, -1.50189916e-02,  4.42095519e-04, -1.59010711e-02,
        -1.20179188e-02,  1.75890783e-03, -1.23881319e-02,  1.18678946e-02,
         1.59798761e-03, -7.76694562e-03, -4.38607793e-03, -1.30714383e-02,
        -1.08108158e-02,  8.90569743e-03,  3.69362361e-04,  9.15744120e-03])}

定义完梯度以后就可以开始实现梯度下降算法

batch_size=100
def train_batch(current_batch,parameters):
    grad_accu=grad_parameters(train_img[current_batch*batch_size+0],train_lab[current_batch*batch_size+0],parameters)
    for img_i in range(1,batch_size):
        grad_tmp=grad_parameters(train_img[current_batch*batch_size+img_i],train_lab[current_batch*batch_size+img_i],parameters)
        for key in grad_accu.keys():
            grad_accu[key]+=grad_tmp[key]
    for key in grad_accu.keys():
        grad_accu[key]/=batch_size
    return grad_accu

def combine_parameters(parameters,grad,learn_rate):
    parameter_tmp=copy.deepcopy(parameters)
    parameter_tmp[0]['b']-=learn_rate*grad['b0']
    parameter_tmp[1]['b']-=learn_rate*grad['b1']
    parameter_tmp[1]['w']-=learn_rate*grad['w1']
    return parameter_tmp

下面定义一些评估指标,方便后期可视化训练过程,就是把每一个epoch结束后的loss和accuracy等信息保存下来,后期可以通过分析这些信息来调整超参数,优化模型。

def valid_loss(parameters):
    loss_accu=0
    for img_i in range(valid_num):
        loss_accu+=sqr_loss(valid_img[img_i],valid_lab[img_i],parameters)
    return loss_accu/(valid_num/10000)
def valid_accuracy(parameters):
    correct=[predict(valid_img[img_i],parameters).argmax()==valid_lab[img_i] for img_i in range(valid_num)]
    return correct.count(True)/len(correct)
def train_loss(parameters):
    loss_accu=0
    for img_i in range(train_num):
        loss_accu+=sqr_loss(train_img[img_i],train_lab[img_i],parameters)
    return loss_accu/(train_num/10000)
def train_accuracy(parameters):
    correct=[predict(train_img[img_i],parameters).argmax()==train_lab[img_i] for img_i in range(train_num)]
    return correct.count(True)/len(correct)
    
parameters=init_parameters()
current_epoch=0
train_loss_list=[]
valid_loss_list=[]
train_accu_list=[]
valid_accu_list=[]

准备工作全部完成,现在可以开始训练模型

learn_rate=10**-0.6
# learn_rate=1
epoch_num=15
for epoch_ in range(epoch_num):
    for i in range(train_num//batch_size):
        grad_tmp=train_batch(i,parameters)
        parameters=combine_parameters(parameters,grad_tmp,learn_rate)
    current_epoch+=1
    train_loss_list.append(train_loss(parameters))
    train_accu_list.append(train_accuracy(parameters))
    valid_loss_list.append(valid_loss(parameters))
    valid_accu_list.append(valid_accuracy(parameters))

训练15轮,正确率达到90以上,接下来使用训练好的模型识别图片,随机实验多次以后发现效果还是相当不错的。

test_index = np.random.randint(1000)
show_test(test_index)
predict_result = predict(test_img[test_index], parameters)
print("predict:{}".format(predict_result.argmax()))
predict.png

接下来看看保存的训练过程

# 可视化acc
plt.plot(valid_accu_list, c='g', label='validation acc')
plt.plot(train_accu_list, c='b', label='train acc')
plt.legend()
plt.savefig('train_acc.png')
train_acc_1.png
# 可视化loss
plt.plot(valid_loss_list, c='g', label='validation_loss')
plt.plot(train_loss_list, c='b', label='train_loss')
plt.legend()
plt.savefig('train_loss.png')
train_loss_1.png

我们还可以保存模型下一次使用,不过这里保存的模型仅仅是参数,而不是保存模型结构和参数,可以使用Python中的pickle

# 保存参数
import pickle
model_prameters_name = 'Mnist_model.pkl'
f = open(model_prameters_name, 'wb')
pickle.dump(parameters, f)
f.close()
f = open(model_prameters_name, 'rb')
param = pickle.load(f)
print(param)
f.close

至此就在不使用框架的情况下完成了一个最简单的人工神经网络,相比使用keras实现的版本该有的部分基本都有了。框架对相关函数初始化的方式都做了一定的优化,可以加快训练速度,这都是需要很高的数学能力才可以完成。这里我们也参考相关参数优化的论文对这个模型进行一下优化。这里可以参考此论文Understanding the difficulty of training deep feedforward neural networks优化参数初始化部分,使用此种方式初始化参数后,模型可以更快收敛。

glorot10a.png
dimensions=[28*28,10]
activation=[tanh,softmax]
distribution=[
    {'b':[0,0]},
    {'b':[0,0],'w':[-math.sqrt(6/(dimensions[0]+dimensions[1])),math.sqrt(6/(dimensions[0]+dimensions[1]))]},
]

相关的notebook已经上传至Github

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,294评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,780评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,001评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,593评论 1 289
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,687评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,679评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,667评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,426评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,872评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,180评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,346评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,019评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,658评论 3 323
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,268评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,495评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,275评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,207评论 2 352

推荐阅读更多精彩内容