TensorFlow -MNIST识别-1

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 用于设置将记录哪些消息的阈值
old_v = tf.logging.get_verbosity()
# 设置日志反馈模式
tf.logging.set_verbosity(tf.logging.ERROR)

# 载入数据集
mnist = input_data.read_data_sets('/文件路径/MNIST_data', one_hot=True)
# 设置batch大小
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# 神经元保留率
keep_prob = tf.placeholder(tf.float32)
# 学习率
LR = tf.Variable(0.001, dtype=tf.float32)

# 神经网络1
W1 = tf.Variable(tf.truncated_normal([784, 600], stddev=0.1))
b1 = tf.Variable(tf.zeros([600]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

# 神经网络2
W2 = tf.Variable(tf.truncated_normal([600, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
# softmax预测分类
prediction = tf.nn.softmax(tf.matmul(L2_drop, W3) + b3)

# 交叉商
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))

# AdamOptimizer()优化器
train_step = tf.train.AdamOptimizer(LR).minimize(loss)

# 初始化变量
init_op = tf.global_variables_initializer()

# 这里是返回一个储存布尔类型的矩阵,
# tf.equal,对比两个矩阵(向量)相相等的元素,
# tf.argmax返回最大值索引(对应的数字)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# 求准确率:tf.cast转化数据格式, tf.reduce_mean求平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init_op)
    for epoch in range(31):
        sess.run(tf.assign(LR, 0.001*(0.9**epoch)))
        for batch in range(n_batch):
            # 分批次训练,next_batch记录上一个结尾,进行下一个开始
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.9})
        learning_rate = sess.run(LR)
        # 准确率
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1})
        # 打印准确率和学习率
        print("Iter" + str(epoch) + ",Testing Accuracy=" + str(acc) +
              "   Learning Rate=" + str(learning_rate))

# 参考文档:https://www.w3cschool.cn/tensorflow_python
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容