【转载】TensorFlow入门之MNIST最佳实践

原文链接:www.cnblogs.com/yinzm/p/7123639.html


在上一篇《TensorFlow入门之MNIST样例代码分析》中,我们讲解了如果来用一个三层全连接网络实现手写数字识别。但是在实际运用中我们需要更有效率,更加灵活的代码。在TensorFlow实战这本书中给出了更好的实现,他将程序分为三个模块,分别是前向传播过程模块,训练模块和验证检测模块。并且在这个版本中添加了模型持久化功能,我们可以将模型保存下来,方便之后的模型检验,并且我们可以一边训练新的模型,一边来检验模型,代码更加的灵活高效。

前向传播模块

首先将前向传播过程抽象出来,作为一个可以作为训练测试共享的模块,取名为mnist_inference.py,将这个过程抽象出来的好处是,一是可以保证在训练或者测试的过程中前向传播的一致性,提高代码的复用性。还有一点是我们可以更好地将其与滑动平均模型与模型持久化功能结合,更加灵活的来检验新的模型。mnist_inference.py代码如下:

# -*- coding: utf-8 -*-

import tensorflow as tf

# 定义神经网络结构相关的参数

INPUT_NODE = 784

OUTPUT_NODE = 10

LAYER1_NODE = 500

# 通过tf.get_variable函数来获取变量。在训练神经网络时会创建这些变量;在测试时会通

# 过保存的模型加载这些变量的取值。而且更加方便的是,因为可以在变量加载时将滑动平均变

# 量重命名,所以可以直接通过相同的名字在训练时使用变量自身,而在测试时使用变量的滑动

# 平均值。在这个函数中也会将变量的正则化损失加入到损失集合。

def get_weight_variable(shape, regularizer):

weights = tf.get_variable(

"weights", shape,

initializer=tf.truncated_normal_initializer(stddev=0.1)

)

# 当给出了正则化生成函数时,将当前变量的正则化损失加入名字为losses的集合。在这里

# 使用了add_to_collection函数将一个张量加入一个集合,而这个集合的名称为losses。

# 这是自定义的集合,不在TensorFlow自动管理的集合列表中。

if regularizer != None:

tf.add_to_collection('losses', regularizer(weights))

return weights

# 定义神经网络的前向传播过程

def inference(input_tensor, regularizer):

# 声明第一层神经网络的变量并完成前向传播过程。

with tf.variable_scope('layer1'):

# 这里通过tf.get_variable或者tf.Variable没有本质区别,因为在训练或者测试

# 中没有在同一个程序中多次调用这个函数。如果在同一个程序中多次调用,在第一次

# 调用之后需要将reuse参数设置为True。

weights = get_weight_variable(

[INPUT_NODE, LAYER1_NODE], regularizer

)

biases = tf.get_variable(

"biases", [LAYER1_NODE],

initializer=tf.constant_initializer(0.0)

)

layer1 = tf.nn.relu(tf.matmul(input_tensor, weights)+biases)

# 类似的声明第二层神经网络的变量并完成前向传播过程。

with tf.variable_scope('layer2'):

weights = get_weight_variable(

[LAYER1_NODE, OUTPUT_NODE], regularizer

)

biases = tf.get_variable(

"biases", [OUTPUT_NODE],

initializer=tf.constant_initializer(0.0)

)

layer2 = tf.matmul(layer1, weights) + biases

# 返回最后前向传播的结果

return layer2

训练模块

将训练模型的模块提取出来,训练模块命名为mnist_train.py,在下面的代码中每过1000个step我们就保存一次模型。代码如下:

# -*- coding: utf-8 -*-

import os

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 加载mnist_inference.py中定义的常量和前向传播的函数。

import mnist_inference

# 配置神经网络的参数。

BATCH_SIZE = 100

LEARNING_RATE_BASE = 0.8

LEARNING_RATE_DECAY = 0.99

REGULARIZATION_RATE = 0.0001

TRAINING_STEPS = 30000

MOVING_AVERAGE_DECAY = 0.99

# 模型保存的路径和文件名

MODEL_SAVE_PATH = "./model/"

MODEL_NAME = "model.ckpt"

def train(mnist):

# 定义输入输出placeholder。

x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')

y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)

# 直接使用mnist_inference.py中定义的前向传播过程

y = mnist_inference.inference(x, regularizer)

global_step = tf.Variable(0, trainable=False)

# 定义损失函数、学习率、滑动平均操作以及训练过程

variable_averages = tf.train.ExponentialMovingAverage(

MOVING_AVERAGE_DECAY, global_step

)

variable_averages_op = variable_averages.apply(

tf.trainable_variables()

)

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(

logits=y, labels=tf.argmax(y_, 1)

)

cross_entropy_mean = tf.reduce_mean(cross_entropy)

loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

learning_rate = tf.train.exponential_decay(

LEARNING_RATE_BASE,

global_step,

mnist.train.num_examples / BATCH_SIZE,

LEARNING_RATE_DECAY

)

train_step = tf.train.GradientDescentOptimizer(learning_rate)\

.minimize(loss, global_step=global_step)

with tf.control_dependencies([train_step, variable_averages_op]):

train_op = tf.no_op(name='train')

# 初始化TensorFlow持久化类

saver = tf.train.Saver()

with tf.Session() as sess:

tf.global_variables_initializer().run()

# 在训练过程中不再测试模型在验证数据上的表现,验证和测试的过程将会有一个独

# 立的程序来完成。

for i in range(TRAINING_STEPS):

xs, ys = mnist.train.next_batch(BATCH_SIZE)

_, loss_value, step = sess.run([train_op, loss, global_step],

feed_dict={x: xs, y_: ys})

# 每1000轮保存一次模型

if i % 1000 == 0:

# 输出当前的训练情况。这里只输出了模型在当前训练batch上的损失

# 函数大小。通过损失函数的大小可以大概了解训练的情况。在验证数

# 据集上正确率的信息会有一个单独的程序来生成

print("After %d training step(s), loss on training "

"batch is %g." % (step, loss_value))

# 保存当前的模型。注意这里给出了global_step参数,这样可以让每个

# 被保存的模型的文件名末尾加上训练的轮数,比如“model.ckpt-1000”,

# 表示训练1000轮之后得到的模型。

saver.save(

sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),

global_step=global_step

)

def main(argv=None):

mnist = input_data.read_data_sets("./data", one_hot=True)

train(mnist)

if __name__ == "__main__":

tf.app.run()

验证与测试模块

验证模块与测试模块可以对保存好的训练模型进行验证与测试,在下面的代码中我们选择每过10秒钟验证一个最新的模型。这样做的好处是可以将训练与验证或者测试分割开来,同时进行。该模块命名为mnist_eval.py。

# -*- coding: utf-8 -*-

import time

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 加载mnist_inference.py 和mnist_train.py中定义的常量和函数。

import mnist_inference

import mnist_train

# 每10秒加载一次最新的模型,并且在测试数据上测试最新模型的正确率

EVAL_INTERVAL_SECS = 10

def evaluate(mnist):

with tf.Graph().as_default() as g:

# 定义输入输出的格式。

x = tf.placeholder(

tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input'

)

y_ = tf.placeholder(

tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input'

)

validate_feed = {x: mnist.validation.images,

y_: mnist.validation.labels}

# 直接通过调用封装好的函数来计算前向传播的结果。因为测试时不关注ze正则化损失的值

# 所以这里用于计算正则化损失的函数被设置为None。

y = mnist_inference.inference(x, None)

# 使用前向传播的结果计算正确率。如果需要对未知的样例进行分类,那么使用

# tf.argmax(y,1)就可以得到输入样例的预测类别了。

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 通过变量重命名的方式来加载模型,这样在前向传播的过程中就不需要调用求滑动平均

# 的函数来获取平均值了。这样就可以完全共用mnist_inference.py中定义的

# 前向传播过程。

variable_averages = tf.train.ExponentialMovingAverage(

mnist_train.MOVING_AVERAGE_DECAY

)

variables_to_restore = variable_averages.variables_to_restore()

saver = tf.train.Saver(variables_to_restore)

# 每隔EVAL_INTERVAL_SECS秒调用一次计算正确率的过程以检验训练过程中正确率的

# 变化。

while True:

with tf.Session() as sess:

# tf.train.get_checkpoint_state函数会通过checkpoint文件自动

# 找到目录中最新模型的文件名。

ckpt = tf.train.get_checkpoint_state(

mnist_train.MODEL_SAVE_PATH

)

if ckpt and ckpt.model_checkpoint_path:

# 加载模型。

saver.restore(sess, ckpt.model_checkpoint_path)

# 通过文件名得到模型保存时迭代的轮数。

global_step = ckpt.model_checkpoint_path\

.split('/')[-1].split('-')[-1]

accuracy_score = sess.run(accuracy,

feed_dict=validate_feed)

print("After %s training step(s), validation "

"accuracy = %g" % (global_step, accuracy_score))

else:

print("No checkpoint file found")

return

time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):

mnist = input_data.read_data_sets("./data", one_hot=True)

evaluate(mnist)

if __name__ == "__main__":

tf.app.run()


总结

这个样例是一个非常好的可以用来理解TensorFlow的程序,特别是TensorFlow的计算图的理解,还有模型持久化与恢复,变量的管理,滑动平均模型的实现等等。还有这种灵活的模块分块的思想也值得学习。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,686评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,668评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,160评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,736评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,847评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,043评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,129评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,872评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,318评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,645评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,777评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,861评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,589评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,687评论 2 351

推荐阅读更多精彩内容