手把手教你用Tensorflow搭建RNN

自从跟导师吹了波牛之后,开始将自己的研究重心转为NLP方向。那么RNN就成了必须要了解的一个DL模型(导师暑假去了趟西安,学习一波DL,回来神神叨叨地说CNN到瓶颈期了,别瞎折腾了)。对于RNN,我就默认大家都懂得其中的原理,有不明白的可以去看NG 的视频教学:https://mooc.study.163.com/smartSpec/detail/1001319001.htm。话不多说,开搞开搞!!!


Step1 搭建环境

系统:Windows7+TensorFlow1.9.0(cpu)+Python3.6


Step2 加载数据集

MNIST是一个手写数字数据库,它有55000个训练样本集和10000个测试样本集。它是MNIST数据库的一个子集。其中每张图片固定大小为28×28的黑白图片。如下图所示:

MNIST数据集

使用Tensorflow加载内置的MNIST数据集,具体方法展示如下。

加载方法

加载完成后,打印训练样本数(ntrain),测试样本数(ntest),样本 总维度(dim)以及分类数(nlasses)

打印结果

Step3 RNN模型

接下来我们就需要去看看,我们所要搭建的RNN模型到底长啥样。

这是从NG那边搞来的图,相信大家都能明白其中的奥秘吧。但当时我在学的时候很困惑,这东西实际中咋用,搞成这副鬼样子,真有那么神?em...下面这张图就是利用MNIST数据集来做的一个实验。

简单叙述下该模型的建立过程,每一个样本都可以看出是一个[28,28]的矩阵,那么将矩阵的每一行作为一个输入向量,大小为[1,28],那么整个模型就拥有28个输入神经元,这28个神经元我们将其统称为输入层。完成后我再输入层与RNN层之间增加一个含有128个神经元的隐藏层,用于对输入层进行特征 提取,形成一个[1,128]的向量传入RNN中。RNN中的内部构造参见LSTM,形成两个向量分别为LSTM_O,LSTM_S,大小都为[1,128]。其中LSTM_O为RNN模型的输出,LSTM_S为RNN模型的内部记忆向量,传递到下一个RNN神经元。最后对LSTM_O进行Softmax处理,通过概率分析出该样本的类别。

接下来我们对模型中所涉及的权重、偏置以及各层神经元数量的设置。其中W["h1"]为输入层到隐藏层的权重,大小为[28,128],W["h2"]为隐藏层到RNN的权重,大小为[128,10]。b["b1"]与b["b2"]同理。

权重、偏置

Step4 创建RNN模型(关键步骤)

创建RNN

42~43行:对输入数据进行预处理操作。这里涉及到batch_size的问题,在训练时我们通常是将一批数据导入模型来提高模型的效率,那么批次的大小就是batch_size。即我们可以理解为我们是将一个batch_size*28*28的三维矩阵导入了我们的RNN模型,那么我们就要对该矩阵进行变换从而满足我们[None,28]的要求。

44~45行:隐藏层处理好输入数据后形成一个[None,128]的矩阵。然后对该矩阵进行切割,我们的RNN一共有28个输入单元,那就切成28个咯。

46~50行:将切好的矩阵依次传入RNN中。接下来是对RNN内部的设置,这边使用的是LSTM(tf.nn.rnn_cell.BasicLSTMCell()),当然tf.nn也为我们实现好了其他的内部设置方便我们调用。


Step5 超参数的定义(损,优,学,准,初)

该步骤定义模型中我们所需要的一些超参数。Tensorflow拥有现成的方法,方便我们调用。

学习率:leraning_rate

损失:cost

优化方法:optm

参数初始化:init

超参数的定义

Step6 训练测试

最后一步,设置好迭代次数(training_epoch)、批次大小(batch_size)等后,将MNIST的数据集加载到Step5中设置好的X、Y中,完成训练测试。没什么好多说的,我每个模型最后的训练测试都这样,照着写吧。

训练测试

Step7 结果展示

总共迭代了5次,电脑跑太慢了...

准确率达到93.9%,还彳亍 口 巴!

结果

附上所有代码:

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

from tensorflow.contrib import rnn

# 加载数据

mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

trainimgs, trainlabels, testimgs, testlabels \

= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

ntrain, ntest, dim, nclasses\

=trainimgs.shape[0],testimgs.shape[0],trainimgs.shape[1],trainlabels.shape[1]

#print(ntrain, ntest, dim, nclasses)

print ("MNIST loaded")

#设置参数,权重,偏置

diminput = 28

dimhidden = 128

dimoutput = nclasses

nsteps = 28

W = {"h1" : tf.Variable(tf.random_normal([diminput,dimhidden])),

    "h2" : tf.Variable(tf.random_normal([dimhidden,dimoutput]))}

b = {"b1" : tf.Variable(tf.random_normal([dimhidden])),

    "b2" : tf.Variable(tf.random_normal([dimoutput]))}

# 创建模型

def RNN(X,W,b,nsteps):

    X = tf.transpose(X,[1,0,2])

    X = tf.reshape(X,[-1,diminput])

    H_1 = tf.matmul(X,W["h1"])+b["b1"]

    H_1 = tf.split(H_1,nsteps,0)

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias=1.0)

    LSTM_O,LSTM_S = rnn.static_rnn(lstm_cell,H_1,dtype=tf.float32)

    O = tf.matmul(LSTM_O[-1],W["h2"])+b["b2"]

    return {"X":X,"H_1":H_1,"LSTM_O":LSTM_O,"LSTM_S":LSTM_S,"O":O} 

print ("Network ready")

# 设置损失,优化,学习率,准确率,参数初始化

learning_rate = 0.001

x      = tf.placeholder("float", [None, nsteps, diminput])

y      = tf.placeholder("float", [None, dimoutput])

myrnn  = RNN(x, W, b, nsteps)

pred  = myrnn['O']

cost  = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=pred))

optm  = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

accr  = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))

init  = tf.global_variables_initializer()

print ("Network Ready!")

# 训练,测试

#所有样本迭代(epoch)5次

training_epochs = 5

#每进行一次迭代选择的样本数

batch_size      = 16

#展示

display_step    = 1

sess = tf.Session()

sess.run(init)

print ("Start optimization")

for epoch in range(training_epochs):

    avg_cost = 0.

    total_batch = int(mnist.train.num_examples/batch_size)

    #total_batch = 100

    # Loop over all batches

    for i in range(total_batch):

        batch_xs, batch_ys = mnist.train.next_batch(batch_size)

        batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))

        # Fit training using batch data

        feeds = {x: batch_xs, y: batch_ys}

        sess.run(optm, feed_dict=feeds)

        # Compute average loss

        avg_cost += sess.run(cost, feed_dict=feeds)/total_batch

    # Display logs per epoch step

    if epoch % display_step == 0:

        print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))

        feeds = {x: batch_xs, y: batch_ys}

        train_acc = sess.run(accr, feed_dict=feeds)

        print (" Training accuracy: %.3f" % (train_acc))

        testimgs = testimgs.reshape((ntest, nsteps, diminput))

        feeds = {x: testimgs, y: testlabels}

        test_acc = sess.run(accr, feed_dict=feeds)

        print (" Test accuracy: %.3f" % (test_acc))

print ("Optimization Finished.")

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,634评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,951评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,427评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,770评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,835评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,799评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,768评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,544评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,979评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,271评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,427评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,121评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,756评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,375评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,579评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,410评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,315评论 2 352