TensorFlow实战(四)MNIST手写数字识别进阶——单、多隐层全连接网络

上节手写数字识别入门用的是单个神经元来处理分类问题,准确率达0.8619。这一节做一些改进,以单隐含层全连接网络为例,可使准确率达0.9744
后进一步调整隐含层数测试发现,加入不同层数隐含层达到的准确率,3层>单层>2层。说明神经网络的层数未必越多越好

单个神经元模型

全连接单隐藏层神经网络

导入数据集

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
#下载MNIST数据集到指定目录下 
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)

一、创建模型

1. 定义全连接层函数

def fcn_layer(inputs, #输入数据
             input_dim, #输入神经元数量
             output_dim, #输出神经元数量
             activation=None): #激活函数
    W = tf.Variable(tf.truncated_normal([input_dim,output_dim], stddev=0.1))
                                    #以截断正态分布的随机数初始化W
    b = tf.Variable(tf.zeros([output_dim]))
                                    #以0初始化b
    XWb = tf.matmul(inputs, W) + b #建立表达式:inputs * W + b
    
    if activation is None: #默认不使用激活函数
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs

2. 构建输入层

x = tf.placeholder(tf.float32, [None, 784], name="X")

3. 构建隐藏层

#隐藏层包含256个神经元
h1 = fcn_layer(inputs=x,
              input_dim=784,
              output_dim=256,
              activation=tf.nn.relu)

4. 构建输出层

forward = fcn_layer(inputs=h1,
              input_dim=256,
              output_dim=10,
              activation=None)

pred = tf.nn.softmax(forward)

二、训练模型

1. 定义标签数据占位符

y = tf.placeholder(tf.float32, [None, 10], name = "Y")

2. 定义损失函数

#softmax_cross_entropy_with_logits函数原型:  
tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred, name=None)
  • 函数功能:计算最后一层是softmax层的cross entropy,把softmax计算与cross entropy计算放到一起了,用一个函数来实现,用来提高程序的运行速度。
  • 参数name:该操作的name
  • 参数labels:shape是[batch_size, num_classes],神经网络期望输出。
  • 参数logits:shape是[batch_size, num_classes] ,神经网络最后一层的输入。
#loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
                                            # reduction_indices=1)) #原方法:定义交叉熵损失函数
#改进:使用softmax_cross_entropy_with_logits方法定义交叉熵损失函数
#把softmax和cross entropy放到一个函数里计算,提高运算速度    
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y))

3. 设置训练参数

train_epochs = 40 #训练轮数
batch_size = 50 #单次训练样本数(批次大小)
total_batch = int(mnist.train.num_examples/batch_size) #一轮训练有多少批次
display_step = 1 #显示粒度
learning_rate = 0.01 #学习率

4. 选择优化器

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)

5. 定义准确率

#检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况,相等为1,不等为0,实际要转浮点数
correct_prediction = tf.equal(tf.argmax(y, 1), tf.arg_max(pred, 1))
#准确率,将布尔值转为浮点数,并计算平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

6. 训练模型

#记录训练开始时间
from time import time
startTime = time()

sess = tf.Session()
sess.run(tf.global_variables_initializer())


#开始训练
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size) #读取批次数据
        sess.run(optimizer, feed_dict = {x: xs, y: ys}) #执行批次训练
        
    #total_batch个批次训练完成后,用验证数据计算误差与准确率,验证集没有分批
    loss,acc = sess.run([loss_function,accuracy],
                       feed_dict= {x: mnist.validation.images, y: mnist.validation.labels})
    
    #打印训练过程中的详细信息
    if(epoch+1)%display_step == 0:
        print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9f}".format(loss),\
             "Accuracy=","{:.4f}".format(acc))
print("Train Finished!")

#显示运行总时间
duration = time() - startTime
print("Train Finished takes:", "{:.2f}".format(duration))

Train Epoch: 01 Loss= 0.149148121 Accuracy= 0.9594
...
Train Epoch: 40 Loss= 0.608235240 Accuracy= 0.9742
Train Finished!

Train Finished takes: 180.56

三、评估模型

#测试集
accu_test = sess.run(accuracy,
                    feed_dict = {x: mnist.test.images, y: mnist.test.labels})
print("Test Accuracy:",accu_test)
# Test Accuracy: 0.9743

四、保存模型

1. 初始化参数和文件目录

#存储模型的粒度
save_step = 5
#创建保存模型文件的目录
import os
ckpt_dir = "./ckpt_dir/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

2. 训练时存储模型

#声明完所有变量后,调用tf.train.Saver
saver = tf.train.Saver()
    #打印训练过程中的详细信息
    if(epoch+1)%display_step == 0:
        print("Train Epoch:",'%02d'%(epoch+1),
        "Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
    if(epoch+1)%save_step == 0:
        saver.save(sess, os.path.join(ckpt_dir,
        'mnist_h256_model_{:06d}.ckpt'.format(epoch+1))) #存储模型
        print('mnist_h256_model_{:06d}.ckpt saved'.format(epoch+1))
saver.save(sess, os.path.join(ckpt_dir,'mnist_h256_model.ckpt'))
print("Model saved!")

#每训练 5 轮保存一次模型(前面设置的 save_step=5 )
# Train Epoch: 01 Loss= 0.138733894 Accuracy= 0.9616
# Train Epoch: 02 Loss= 0.129554003 Accuracy= 0.9666
# Train Epoch: 03 Loss= 0.147694156 Accuracy= 0.9636
# Train Epoch: 04 Loss= 0.156693459 Accuracy= 0.9630
# Train Epoch: 05 Loss= 0.206912994 Accuracy= 0.9594
# mnist_h256_model_000005.ckpt saved

五、还原模型

1. 设置模型文件的存放目录

#必须指定为模型文件的存放目录,缺省最多保留最近5份
ckpt_dir = "./ckpt_dir/"

2. 读取模型

saver = tf.train.Saver()
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

ckpt = tf.train.get_checkpoint_state(ckpt_dir)

if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path) #从已保存模型中读取参数
    print("Restore model from "+ckpt.model_checkpoint_path)

3. 输出还原模型的准确率

输出模型准确率,发现和最终存盘的模型准确率一致,说明恢复的就是最新存盘文件模型

print("Accuracy:", accuracy.eval(session=sess,
feed_dict={x: mnist.test.images, y: mnist.test.labels}))

# Accuracy: 0.9744
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,456评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,370评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,337评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,583评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,596评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,572评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,936评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,595评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,850评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,601评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,685评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,371评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,951评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,934评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,167评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,636评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,411评论 2 342