基于tensorflow+RNN的MNIST数据集手写数字分类

2018年9月25日笔记

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。
RNN是recurrent neural network的简称,中文叫做循环神经网络。
MNIST是Mixed National Institue of Standards and Technology database的简称,中文叫做美国国家标准与技术研究所数据库
此文在上一篇文章《基于tensorflow+DNN的MNIST数据集手写数字分类预测》的基础上修改模型为循环神经网络模型,模型准确率从98%提升到98.5%,错误率减少了25%
《基于tensorflow+DNN的MNIST数据集手写数字分类预测》文章链接:https://www.jianshu.com/p/9a4ae5655ca6

0.编程环境

操作系统:Win10
tensorflow版本:1.6
tensorboard版本:1.6
python版本:3.6

1.致谢声明

本文是作者学习《周莫烦tensorflow视频教程》的成果,感激前辈;
视频链接:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/

2.配置环境

使用循环神经网络模型要求有较高的机器配置,如果使用CPU版tensorflow会花费大量时间。
读者在有nvidia显卡的情况下,安装GPU版tensorflow会提高计算速度50倍。
安装教程链接:https://blog.csdn.net/qq_36556893/article/details/79433298
如果没有nvidia显卡,但有visa信用卡,请阅读我的另一篇文章《在谷歌云服务器上搭建深度学习平台》,链接:https://www.jianshu.com/p/893d622d1b5a

3.下载并解压数据集

MNIST数据集下载链接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密码: wa9p
下载压缩文件MNIST_data.rar完成后,选择解压到当前文件夹不要选择解压到MNIST_data。
文件夹结构如下图所示:

image.png

4.完整代码

此章给读者能够直接运行的完整代码,使读者有编程结果的感性认识。
如果下面一段代码运行成功,则说明安装tensorflow环境成功。
想要了解代码的具体实现细节,请阅读后面的章节。
完整代码中定义函数RNN使代码简洁,但在后面章节中为了易于读者理解,本文作者在第6章搭建神经网络将此部分函数改写为只针对于该题的顺序执行代码。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

def RNN(X_holder):
    reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
    cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    last_cell = cell_list[-1]
    Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
    biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
    predict_Y = tf.matmul(last_cell, Weights) + biases
    return predict_Y
predict_Y = RNN(X_holder)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.train.next_batch(3000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print("step:%d test accuracy:%.4f" %(step, test_accuracy))

上面一段代码的运行结果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
step:100 test accuracy:0.8483
step:200 test accuracy:0.8987
step:300 test accuracy:0.9230
step:400 test accuracy:0.9437
step:500 test accuracy:0.9457
step:600 test accuracy:0.9513
step:700 test accuracy:0.9687
step:800 test accuracy:0.9660
step:900 test accuracy:0.9710
step:1000 test accuracy:0.9740

5.数据准备

第1行代码导入库warnings;
第2行代码表示不打印警告信息;
第3行代码导入库tensorflow,取别名tf;
第4行代码从tensorflow.examples.tutorials.mnist库中导入input_data方法;
第6行代码表示重置tensorflow图
第7行代码加载数据库MNIST赋值给变量mnist;
第8-13行代码定义超参数学习率learning_rate、批量大小batch_size、步数n_steps、输入层大小n_inputs、隐藏层大小n_hidden_units、输出层大小n_classes。
第14、15行代码中placeholder中文叫做占位符,将每次训练的特征矩阵X和预测目标值Y赋值给变量X_holder和Y_holder。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

6.搭建神经网络

本文作者将此章中使用tensorflow库的所有方法的API链接总结成下表,访问需要vpn。

方法 链接
tf.reshape https://www.tensorflow.org/api_docs/python/tf/manip/reshape
tf.nn.rnn_cell.LSTMCell https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell
tf.nn.dynamic_rnn https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn
tf.transpose https://www.tensorflow.org/api_docs/python/tf/transpose
tf.unstack https://www.tensorflow.org/api_docs/python/tf/unstack
tf.Variable https://www.tensorflow.org/api_docs/python/tf/Variable
tf.truncated_normal https://www.tensorflow.org/api_docs/python/tf/truncated_normal
tf.matmul https://www.tensorflow.org/api_docs/python/tf/matmul
tf.reduce_mean https://www.tensorflow.org/api_docs/python/tf/reduce_mean
tf.nn.softmax_cross_entropy_with_logits https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits
tf.train.AdamOptimizer https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer

第1行代码reshape中文叫做重塑形状,将输入数据X_holder重塑形状为模型需要的;
第2行代码调用tf.nn.rnn_cell.LSTMCell方法实例化LSTM细胞对象;
第3行代码调用tf.nn.dynamic_rnn方法实例化rnn模型对象;
第4、5行代码取得rnn模型中最后一个细胞的数值;
第6、7行代码定义在训练过程会更新的权重Weights、偏置biases;
第8行代码表示xW+b的计算结果赋值给变量predict_Y,即预测值;
第9行代码表示交叉熵作为损失函数loss;
第10行代码表示AdamOptimizer作为优化器optimizer;
第11行代码定义训练过程,即使用优化器optimizer最小化损失函数loss。

reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden_units)
outputs, state = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
last_cell = cell_list[-1]
Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
predict_Y = tf.matmul(last_cell, Weights) + biases
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

7.参数初始化

对于神经网络模型,重要是其中的W、b这两个参数。
开始神经网络模型训练之前,这两个变量需要初始化。
第1行代码调用tf.global_variables_initializer实例化tensorflow中的Operation对象。


image.png

第2行代码调用tf.Session方法实例化会话对象;
第3行代码调用tf.Session对象的run方法做变量初始化。

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

8.模型训练

第1行代码tf.argmax方法中的第2个参数为1,即求出矩阵中每1行中最大数的索引;
如果argmax方法中的第1个参数为0,即求出矩阵中每1列最大数的索引;
tf.equal方法可以比较两个向量的在每个元素上是否相同,返回结果为向量,向量中元素的数据类型为布尔bool;
第2行代码tf.cast方法可以强制转换向量中元素的数据类型,tf.reduce_mean可以求出向量中元素的均值;
第3行代码表示迭代训练1000次;
第4行代码表示从mnist数据的训练集中选取batch_size数量的样本;
第5行代码每运行1次,即模型训练1次;
第6-10行代码表示从mnist数据的测试集中选取10000个样本计算模型预测准确率。

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.test.next_batch(10000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print("step:%d test accuracy:%.4f" %(step, test_accuracy))

上面一段代码的运行结果如下:

step:100 test accuracy:0.8479
step:200 test accuracy:0.8986
step:300 test accuracy:0.9370
step:400 test accuracy:0.9421
step:500 test accuracy:0.9522
step:600 test accuracy:0.9581
step:700 test accuracy:0.9607
step:800 test accuracy:0.9650
step:900 test accuracy:0.9661
step:1000 test accuracy:0.9685

文章篇幅所限,只打印查看1000次训练的结果,训练5000次即可达到98.5%的准确率。

9.总结

1.本文是作者写的第9篇关于tensorflow编程的博客;
2.在mnist案例中,rnn模型最高可达到98.5%的准确率,cnn模型最高可达到99.2%的准确率,因为本文中的rnn模型只考虑的图像矩阵中每1行的关系,rnn模型可以提取空间特征。
3.理解第6章搭建神经网络的过程中,虽然代码只有11行,本文作者花费了2天将近10多个小时;
4.周莫烦前辈的视频当中关于此章的代码已经过时,并且github中的代码的模型准确率不超过90%,没有超过单隐藏层的DNN网络98%的准确率,这违背了RNN优于DNN的出发点。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,657评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,662评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,143评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,732评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,837评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,036评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,126评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,868评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,315评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,641评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,773评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,859评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,584评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,676评论 2 351

推荐阅读更多精彩内容