利用RNN进行文章生成

# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 16:40:04 2018

@author: yanghe
"""
import tensorflow as tf
import numpy as np
import io

text = io.open('shakespeare.txt', encoding='utf-8').read().lower()

chars = sorted(list(set(text)))
learning_rate = 0.01
training_steps = 30
batch_size = 64
display_step = 10
n_steps = 38
n_input = 38
n_hidden = 256
n_classes = 38
# 构建数据集 
#  X:[len(text)-n_steps,n_steps]
#
def build_data(text, n_steps = 40, stride = 3):
    X = []
    Y = []
    for i in range(0, len(text) - n_steps, stride):
        X.append(text[i: i + n_steps])
        Y.append(text[i + n_steps])
    print('number of training examples:', len(X))
    
    return X, Y


def vectorization(X, Y, n_input, char_indices, n_steps = n_steps):
    m = len(X)
    x = np.zeros((m, n_steps, n_input), dtype=np.bool)
    y = np.zeros((m, n_input), dtype=np.bool)
    for i, sentence in enumerate(X):
        for t, char in enumerate(sentence):
            x[i, t, char_indices[char]] = 1
        y[i, char_indices[Y[i]]] = 1
        
    return x, y 
    
    
def sample(preds, temperature=1.0):
    e_x = np.exp(preds - np.max(preds))
    preds = e_x / e_x.sum(axis=1)
    out = np.random.choice(range(len(chars)), p = preds.ravel())
    return out
    

    
def random_mini_batches(x,y,mini_bath_size =64,seed =0):
    np.random.seed(seed)
    m = x.shape[0]
    mini_batches = []
    permutation = list(np.random.permutation(m))
    shuffled_x = x[permutation]
    shuffled_y = y[permutation]
    
    num_complete_minibatches = int(np.floor(m/mini_bath_size))
    for k in range(0,num_complete_minibatches):
        mini_batch_x = shuffled_x[k*mini_bath_size:(k+1)*mini_bath_size,:,:]
        mini_batch_y = shuffled_y[k*mini_bath_size:(k+1)*mini_bath_size,]
        mini_batch = (mini_batch_x,mini_batch_y)
        mini_batches.append(mini_batch)
    if m % mini_bath_size  != 0:
        mini_batch_x = shuffled_x[(k+1)*mini_bath_size:,:,:]
        mini_batch_y = shuffled_y[(k+1)*mini_bath_size:,:]      
        mini_batch = (mini_batch_x,mini_batch_y)
        mini_batches.append(mini_batch)
    return mini_batches

def BiRNN(x):
    with tf.variable_scope('Bi_RNN'):
        weights = tf.get_variable('weights', shape=[2*n_hidden, n_classes], initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable('biases', shape=[n_classes], initializer=tf.truncated_normal_initializer(stddev=0.1))
        x = tf.transpose(x, [1, 0, 2])
        x = tf.reshape(x, [-1, n_steps])
        x = tf.split(x, n_steps)
        lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
        lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
        outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell,
                                                                lstm_bw_cell,
                                                                x,
                                                                dtype=tf.float32)
        
    return tf.matmul(outputs[-1], weights) + biases

def train():
    y_ = BiRNN(x)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_, labels=y))
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    saver = tf.train.Saver()
    with tf.Session() as sess :
        tf.global_variables_initializer().run()
        mini_batches = random_mini_batches(X_, Y_, mini_bath_size=batch_size) 
        for i in range(training_steps):
            for x_train,y_train in mini_batches:
                _ = sess.run(train_op, feed_dict={x:x_train , y:y_train })
            if i % 5 == 0 :
                j = np.random.randint(len(mini_batches[0]))
                validate_acc = sess.run([accuracy], feed_dict={x:mini_batches[j][0] , y:mini_batches[j][1]})
                print("After %d , validation accuracy is %s " % (i,  validate_acc))
        test_acc=sess.run(accuracy,feed_dict={x:mini_batches[-1][0] , y:mini_batches[-1][1]})
        print(("After %d training step(s), test accuracy using average model is %.2f" %(training_steps, test_acc)))    
        saver.save(sess , 'saver/moedl1.ckpt')
    
def generate_output():
    #with tf.variable_scope('Bi_RNN') as scope:
    #scope.reuse_variables()
    y_ = BiRNN(x)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess,'saver/moedl1.ckpt')
        generated = ''
        usr_input = input("Write the beginning of your poem, the Shakespeare machine will complete it. Your input is: ")
        sentence = ('{0:0>' + str(n_steps) + '}').format(usr_input).lower()
        generated += usr_input 
        for i in range(400):
            x_pred = np.zeros((1, n_steps, len(chars)))
    
            for t, char in enumerate(sentence):
                if char != '0':
                    x_pred[0, t, char_indices[char]] = 1.
            preds = sess.run(y_ ,feed_dict={x:x_pred})
    
            
            next_index = sample(preds, temperature = 1.0)
            next_char = indices_char[next_index]
            generated += next_char
            sentence = sentence[1:] + next_char
            if next_char == '\n':
                continue
        print(generated)
        
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
X, Y = build_data(text, n_steps, stride = 3)
X_, Y_ = vectorization(X, Y, n_input = len(chars), char_indices = char_indices) 

x = tf.placeholder(tf.float32, [None, n_steps,n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

#train()
generate_output()

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,014评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,796评论 3 386
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,484评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,830评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,946评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,114评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,182评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,927评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,369评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,678评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,832评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,533评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,166评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,885评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,128评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,659评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,738评论 2 351

推荐阅读更多精彩内容