(8) Sequence to Sequence - 4

Implementing a bidirectional dynamic LSTM encoder with beam-search decoding

A Seq2seq implementation based on TensorFlow 1.4.

The encoder is a bidirectional LSTM.

import helpers
import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.contrib import seq2seq,rnn

tf.__version__

tf.reset_default_graph()
sess = tf.InteractiveSession()

PAD = 0
EOS = 1


vocab_size = 10
input_embedding_size = 20
encoder_hidden_units = 25

decoder_hidden_units = encoder_hidden_units  # the decoder cell below is actually built with encoder_hidden_units*2, to match the concatenated bidirectional encoder state

import helpers as data_helpers
batch_size = 10

# a generator that yields one random minibatch per call to next()

batches = data_helpers.random_sequences(length_from=3, length_to=8,
                                        vocab_lower=2, vocab_upper=10,
                                        batch_size=batch_size)
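The `helpers` module is imported but never shown in the post. A minimal stand-in, assuming `random_sequences` is an endless generator of random-length sequences whose tokens lie in `[vocab_lower, vocab_upper)` so that PAD=0 and EOS=1 never appear in the data:

```python
import random

def random_sequences(length_from, length_to, vocab_lower, vocab_upper, batch_size):
    """Endless generator of minibatches of random token sequences.
    A stand-in for helpers.random_sequences (assumed behaviour: tokens are
    drawn from [vocab_lower, vocab_upper), lengths from [length_from, length_to])."""
    while True:
        yield [[random.randint(vocab_lower, vocab_upper - 1)
                for _ in range(random.randint(length_from, length_to))]
               for _ in range(batch_size)]

batches = random_sequences(3, 8, 2, 10, batch_size=10)
batch = next(batches)
```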

print('Generated %d sequences of varying length (min 3, max 8); the first ten are:' % batch_size)
for seq in next(batches)[:min(batch_size, 10)]:
    print(seq)

Generated 10 sequences of varying length (min 3, max 8); the first ten are:
[6, 5, 7, 2]
[3, 7, 2, 9, 7, 8]
[2, 9, 2, 8, 9]
[7, 7, 8]
[6, 5, 6, 7, 9, 2, 7]
[6, 2, 3, 6]
[4, 7, 7]
[6, 7, 7, 4, 8, 3, 2, 3]
[8, 4, 5, 4]
[3, 3, 5, 9, 4]

tf.reset_default_graph()
sess = tf.InteractiveSession()
mode = tf.contrib.learn.ModeKeys.TRAIN

1. Building the seq2seq model with the seq2seq library

1. Placeholders for the graph inputs

with tf.name_scope('minibatch'):
    encoder_inputs = tf.placeholder(tf.int32, [None, None], name='encoder_inputs')
    
    encoder_inputs_length = tf.placeholder(tf.int32, [None], name='encoder_inputs_length')
    
    decoder_targets = tf.placeholder(tf.int32, [None, None], name='decoder_targets')
    
    decoder_inputs = tf.placeholder(shape=(None, None),dtype=tf.int32,name='decoder_inputs')
    
    # decoder_inputs_length is identical to decoder_targets_length
    decoder_inputs_length = tf.placeholder(shape=(None,),
                                            dtype=tf.int32,
                                            name='decoder_inputs_length')
  

2. The embedding layer

# build the embedding matrix; the encoder and decoder share it
embedding = tf.get_variable('embedding', [vocab_size,input_embedding_size])
encoder_inputs_embedded = tf.nn.embedding_lookup(embedding,encoder_inputs)
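`tf.nn.embedding_lookup` is just row indexing into the embedding matrix; a NumPy sketch of the shape transformation (the numbers mirror the hyperparameters above):

```python
import numpy as np

np.random.seed(0)
emb = np.random.randn(10, 20)      # [vocab_size, input_embedding_size]
ids = np.array([[6, 5, 7, 2],
                [7, 7, 8, 0]])     # [batch_size, max_time], 0 = PAD
looked_up = emb[ids]               # NumPy fancy indexing == embedding_lookup
# result: [batch_size, max_time, input_embedding_size]
```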

3. Defining the LSTM cells

# use separate cell instances for the two directions, so they don't share weights
fw_cell = rnn.LSTMCell(encoder_hidden_units)
bw_cell = rnn.LSTMCell(encoder_hidden_units)

4. The encoder

with tf.variable_scope('encoder'):
    ((encoder_fw_outputs, encoder_bw_outputs),
     (encoder_fw_final_state, encoder_bw_final_state)) = (
        tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell,
                                        cell_bw=bw_cell,
                                        inputs=encoder_inputs_embedded,
                                        sequence_length=encoder_inputs_length,
                                        dtype=tf.float32, time_major=False))

encoder_fw_outputs
<tf.Tensor 'encoder/bidirectional_rnn/fw/fw/transpose:0' shape=(?, ?, 25) dtype=float32>
encoder_bw_outputs
<tf.Tensor 'encoder/ReverseSequence:0' shape=(?, ?, 25) dtype=float32>
encoder_fw_final_state
LSTMStateTuple(c=<tf.Tensor 'encoder/bidirectional_rnn/fw/fw/while/Exit_2:0' shape=(?, 25) dtype=float32>, h=<tf.Tensor 'encoder/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?, 25) dtype=float32>)
encoder_bw_final_state
LSTMStateTuple(c=<tf.Tensor 'encoder/bidirectional_rnn/bw/bw/while/Exit_2:0' shape=(?, 25) dtype=float32>, h=<tf.Tensor 'encoder/bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(?, 25) dtype=float32>)

Merging the encoder outputs

Return values of `bidirectional_dynamic_rnn`:
outputs is a tuple (output_fw, output_bw); output_fw and output_bw both have shape [batch_size, sequence_length, num_units].

output_states is a tuple (output_state_fw, output_state_bw), the final states of the forward and backward cells. Both are of type LSTMStateTuple, a class with two attributes, c and h, for the memory cell and the hidden state, as shown in the figure below:


[figure: basic_lstm_cell.png]

encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)

encoder_final_state_h = tf.concat((encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
encoder_final_state_c = tf.concat((encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
encoder_final_state = rnn.LSTMStateTuple(
    c=encoder_final_state_c,
    h=encoder_final_state_h
)

encoder_final_state
LSTMStateTuple(c=<tf.Tensor 'concat_2:0' shape=(?, 50) dtype=float32>, h=<tf.Tensor 'concat_1:0' shape=(?, 50) dtype=float32>)
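The shape arithmetic behind these concats, with NumPy stand-ins for the state tensors: each direction contributes 25 units, so the merged state is 50-dimensional, which is why the decoder cell below is built with `encoder_hidden_units*2`.

```python
import numpy as np

batch, units = 10, 25                       # encoder_hidden_units = 25
# stand-ins for the forward / backward final states
h_fw, h_bw = np.zeros((batch, units)), np.ones((batch, units))
c_fw, c_bw = np.zeros((batch, units)), np.ones((batch, units))

h = np.concatenate((h_fw, h_bw), axis=1)    # axis 1 = feature axis
c = np.concatenate((c_fw, c_bw), axis=1)
# both now [batch, 2 * units] = [10, 50], matching the printed LSTMStateTuple
```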

5. The decoder

def _create_rnn_cell2():
    def single_rnn_cell(encoder_hidden_units):
        # Create a single cell. A factory function like single_rnn_cell is
        # required here: putting one pre-built cell object directly into the
        # MultiRNNCell list makes every layer share it and breaks the model.
        single_cell = rnn.LSTMCell(encoder_hidden_units*2)
        # add dropout
        single_cell = rnn.DropoutWrapper(single_cell, output_keep_prob=0.5)
        return single_cell
    # each list element is a separate call to single_rnn_cell
    # cell = rnn.MultiRNNCell([single_rnn_cell() for _ in range(self.num_layers)])
    cell = rnn.MultiRNNCell([single_rnn_cell(encoder_hidden_units) for _ in range(1)])
    return cell
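The reason the factory matters is ordinary Python object identity: a list built from one pre-created cell puts the same object (and therefore the same weights) in every layer, while calling the factory per layer creates independent cells. A minimal illustration with a dummy class standing in for a cell:

```python
class FakeCell:
    """Dummy stand-in for an RNN cell."""
    pass

cell = FakeCell()
shared = [cell for _ in range(2)]            # same object in every slot
fresh = [FakeCell() for _ in range(2)]       # factory call per slot
# shared[0] and shared[1] are one object; fresh[0] and fresh[1] are two
```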

with tf.variable_scope('decoder'):
    #single_cell = rnn.LSTMCell(encoder_hidden_units)
    #decoder_cell = rnn.MultiRNNCell([single_cell for _ in range(1)])
    decoder_cell = rnn.LSTMCell(encoder_hidden_units*2)
    # the decoder starts from the merged final encoder state
    decoder_initial_state = encoder_final_state

    # projection from cell outputs to vocabulary logits
    output_layer = tf.layers.Dense(vocab_size,
                                   kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

    decoder_inputs_embedded = tf.nn.embedding_lookup(embedding, decoder_inputs)

    # Training phase: the TrainingHelper + BasicDecoder combination. This pairing
    # is the usual choice; you can also write a custom Helper class for other behaviour.
    training_helper = seq2seq.TrainingHelper(inputs=decoder_inputs_embedded,
                                             sequence_length=decoder_inputs_length,
                                             time_major=False, name='training_helper')
    training_decoder = seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper,
                                            initial_state=decoder_initial_state,
                                            output_layer=output_layer)

    # dynamic_decode performs the decoding; decoder_outputs is a namedtuple
    # with two fields (rnn_output, sample_id):
    # rnn_output: [batch_size, decoder_targets_length, vocab_size], the per-step
    #             vocabulary logits, used to compute the loss
    # sample_id: [batch_size, decoder_targets_length], tf.int32, the argmax
    #            token ids, i.e. the decoded answer
    max_target_sequence_length = tf.reduce_max(decoder_inputs_length, name='max_target_len')
    decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=training_decoder,
                                                   impute_finished=True,
                                                   maximum_iterations=max_target_sequence_length)
    decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
    sample_id = decoder_outputs.sample_id
    mask = tf.sequence_mask(decoder_inputs_length, max_target_sequence_length,
                            dtype=tf.float32, name='masks')
    print('\t%s' % repr(decoder_logits_train))
    print('\t%s' % repr(decoder_targets))
    print('\t%s' % repr(sample_id))
    loss = seq2seq.sequence_loss(logits=decoder_logits_train, targets=decoder_targets, weights=mask)

    <tf.Tensor 'decoder/Identity:0' shape=(?, ?, 10) dtype=float32>
    <tf.Tensor 'minibatch/decoder_targets:0' shape=(?, ?) dtype=int32>
    <tf.Tensor 'decoder/decoder/transpose_1:0' shape=(?, ?) dtype=int32>
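The mask passed to `sequence_loss` gives PAD positions zero weight, so they contribute nothing to the loss. A NumPy equivalent of `tf.sequence_mask`:

```python
import numpy as np

def sequence_mask(lengths, maxlen):
    """NumPy equivalent of tf.sequence_mask: mask[i, t] = 1.0 iff t < lengths[i]."""
    return (np.arange(maxlen)[None, :] < np.asarray(lengths)[:, None]).astype(np.float32)

mask = sequence_mask([2, 4], maxlen=4)
# [[1. 1. 0. 0.]
#  [1. 1. 1. 1.]]
```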
with tf.variable_scope('decoder', reuse=True):
    start_tokens = tf.ones([batch_size, ], tf.int32) * 1  # [batch_size], all 1 (EOS doubles as the GO token)
    # tile the encoder state beam_width times: each beam hypothesis needs its own copy
    encoder_state = nest.map_structure(lambda s: seq2seq.tile_batch(s, 3),
                                       encoder_final_state)
    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=embedding,
                                                             start_tokens=start_tokens,
                                                             end_token=1,
                                                             initial_state=encoder_state,
                                                             beam_width=3,
                                                             output_layer=output_layer)
    beam_decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=inference_decoder, maximum_iterations=10)
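What `BeamSearchDecoder` does internally can be sketched in pure Python: keep the `beam_width` highest-scoring partial hypotheses, extend each by every candidate token, re-rank, and stop a hypothesis when it emits the end token. This is a toy over a fixed token-to-token log-probability table, not the `tf.contrib.seq2seq` API; the table and tokens are made up for illustration.

```python
import math

def beam_search(next_log_probs, beam_width, max_len, start, end):
    """Toy beam search. next_log_probs[token] maps each possible next token
    to its log-probability. Returns beams as (sequence, total_log_prob),
    best first."""
    beams = [([start], 0.0)]
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == end:                     # finished hypothesis: carry over
                candidates.append((seq, score))
                continue
            for tok, lp in next_log_probs[seq[-1]].items():
                candidates.append((seq + [tok], score + lp))
        beams = sorted(candidates, key=lambda b: -b[1])[:beam_width]
        if all(seq[-1] == end for seq, _ in beams):
            break
    return beams

# hypothetical model: from GO (0) it prefers 6, then 5, then EOS (1)
table = {
    0: {6: math.log(0.7), 5: math.log(0.3)},
    6: {5: math.log(0.6), 1: math.log(0.4)},
    5: {1: math.log(0.9), 6: math.log(0.1)},
}
best = beam_search(table, beam_width=3, max_len=10, start=0, end=1)[0]
# best path: 0 -> 6 -> 5 -> 1, probability 0.7 * 0.6 * 0.9
```

This also shows why `tile_batch` is needed above: the real decoder advances all `beam_width` hypotheses through the cell in parallel, so every hypothesis needs its own copy of the encoder state.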
    
train_op = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss)
sess.run(tf.global_variables_initializer())
def next_feed():
    batch = next(batches)
    
    encoder_inputs_, encoder_inputs_length_ = data_helpers.batch(batch)
    decoder_targets_, decoder_targets_length_ = data_helpers.batch(
        [(sequence) + [EOS] for sequence in batch]
    )
    decoder_inputs_, decoder_inputs_length_ = data_helpers.batch(
        [[EOS] + (sequence) for sequence in batch]
    )
    
    # in the feed dict, a key may be a Tensor (the placeholder itself)
    return {
        encoder_inputs: encoder_inputs_.T,
        decoder_inputs: decoder_inputs_.T,
        decoder_targets: decoder_targets_.T,
        encoder_inputs_length: encoder_inputs_length_,
        decoder_inputs_length: decoder_inputs_length_
    }

x = next_feed()
print('encoder_inputs:')
print(x[encoder_inputs][0,:])
print('encoder_inputs_length:')
print(x[encoder_inputs_length][0])
print('decoder_inputs:')
print(x[decoder_inputs][0,:])
print('decoder_inputs_length:')
print(x[decoder_inputs_length][0])
print('decoder_targets:')
print(x[decoder_targets][0,:])
encoder_inputs:
[6 6 3 4 9 7 4 7]
encoder_inputs_length:
8
decoder_inputs:
[1 6 6 3 4 9 7 4 7]
decoder_inputs_length:
9
decoder_targets:
[6 6 3 4 9 7 4 7 1]
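What `next_feed` builds, sketched in NumPy (`batch_sketch` is an assumed reimplementation of `data_helpers.batch`, which the post imports but never shows): the decoder input is the target sequence shifted right, with EOS doubling as the GO token, and the target is the sequence followed by EOS.

```python
import numpy as np

PAD, EOS = 0, 1

def batch_sketch(sequences):
    """Pad variable-length sequences into a time-major [max_time, batch_size]
    int32 matrix, plus the original lengths (assumed data_helpers.batch behaviour)."""
    lengths = [len(s) for s in sequences]
    out = np.full((max(lengths), len(sequences)), PAD, dtype=np.int32)
    for j, seq in enumerate(sequences):
        out[:len(seq), j] = seq
    return out, lengths

seq = [6, 6, 3, 4, 9, 7, 4, 7]
enc, _ = batch_sketch([seq, [7, 7, 8]])     # shorter sequence gets PAD-filled
dec_in, _ = batch_sketch([[EOS] + seq])     # decoder input: EOS-as-GO, then sequence
dec_tgt, _ = batch_sketch([seq + [EOS]])    # decoder target: sequence, then EOS
# after .T (batch-major), row 0 reproduces the printout above:
#   encoder_inputs  [6 6 3 4 9 7 4 7]
#   decoder_inputs  [1 6 6 3 4 9 7 4 7]
#   decoder_targets [6 6 3 4 9 7 4 7 1]
```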
loss_track = []
max_batches = 6001
batches_in_epoch = 200

try:
    # training loop: max_batches minibatches drawn from the generator
    for batch in range(max_batches):
        fd = next_feed()
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)
        
        if batch == 0 or batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(sess.run(loss, fd)))
            predict_ = sess.run(beam_decoder_outputs.predicted_ids, fd)
            #print(predict_)
            for i, (inp, pred) in enumerate(zip(fd[encoder_inputs], predict_)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    predicted > {}'.format(pred))
                if i >= 2:
                    break
            print()
        
except KeyboardInterrupt:
    print('training interrupted')
batch 0
  minibatch loss: 2.3011417388916016
  sample 1:
    input     > [4 3 6 6 3 7 7 5]
    predicted > [[9 9 9]
 [9 9 9]
 [6 6 6]
 [6 6 6]
 [6 6 6]
 [5 6 6]
 [6 5 5]
 [6 6 6]
 [6 6 6]
 [5 5 8]]
  sample 2:
    input     > [3 5 3 4 7 0 0 0]
    predicted > [[ 0  0  0]
 [ 4  4  4]
 [ 4  4  4]
 [ 4  4  4]
 [ 9  9  9]
 [ 4  4  4]
 [ 9  9  9]
 [ 7  7  7]
 [ 1  7  1]
 [ 1  7 -1]]
  sample 3:
    input     > [7 4 8 9 9 2 0 0]
    predicted > [[ 9  9  9]
 [ 9  9  9]
 [ 5  5  5]
 [ 0  0  0]
 [ 7  6  6]
 [ 1  6  6]
 [-1  7  7]
 [-1  7  7]
 [-1  7  7]
 [-1  2  7]]

batch 200
  minibatch loss: 1.5333470106124878
  sample 1:
    input     > [6 7 4 2 7 3 0]
    predicted > [[6 6 6]
 [6 6 6]
 [7 7 3]
 [3 3 7]
 [7 3 3]
 [3 3 3]
 [1 1 1]]
  sample 2:
    input     > [7 5 8 5 0 0 0]
    predicted > [[ 5  5  7]
 [ 5  5  5]
 [ 5  5  5]
 [ 1  1  5]
 [ 1 -1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [7 9 7 6 7 2 0]
    predicted > [[5 5 5]
 [7 7 7]
 [7 3 7]
 [3 7 3]
 [7 7 3]
 [3 3 3]
 [1 1 1]]

batch 400
  minibatch loss: 1.137063980102539
  sample 1:
    input     > [2 5 4 8 9 5 2 8]
    predicted > [[9 9 9]
 [5 5 5]
 [8 8 8]
 [4 4 4]
 [5 5 5]
 [4 4 4]
 [4 5 5]
 [5 8 4]
 [1 1 1]]
  sample 2:
    input     > [3 9 3 9 0 0 0 0]
    predicted > [[ 3  9  3]
 [ 9  3  9]
 [ 9  3  3]
 [ 3  9  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [9 4 2 8 5 0 0 0]
    predicted > [[ 9  9  9]
 [ 4  4  4]
 [ 5  5  5]
 [ 4  4  8]
 [ 4  8  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 600
  minibatch loss: 0.8631468415260315
  sample 1:
    input     > [9 5 2 9 6 8 6 9]
    predicted > [[5 5 5]
 [9 9 9]
 [9 9 9]
 [6 6 6]
 [9 9 2]
 [4 6 9]
 [6 4 6]
 [7 5 8]
 [1 1 1]]
  sample 2:
    input     > [5 7 9 0 0 0 0 0]
    predicted > [[ 5  7  9]
 [ 7  9  7]
 [ 9  5  5]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [8 2 9 0 0 0 0 0]
    predicted > [[ 8  4  4]
 [ 9  9  9]
 [ 2  7  2]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 800
  minibatch loss: 0.7218129634857178
  sample 1:
    input     > [3 4 6 7 2 6 0 0]
    predicted > [[ 6  6  6]
 [ 2  3  3]
 [ 3  2  2]
 [ 2  2  2]
 [ 6  6  6]
 [ 7  5  2]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [5 5 2 8 0 0 0 0]
    predicted > [[ 5  5  5]
 [ 5  5  5]
 [ 2  5  5]
 [ 8  6  8]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [4 3 9 0 0 0 0 0]
    predicted > [[ 4  8  9]
 [ 3  9  6]
 [ 9  3  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 1000
  minibatch loss: 0.42369818687438965
  sample 1:
    input     > [8 8 4 0 0 0 0]
    predicted > [[ 8  8  4]
 [ 8  4  8]
 [ 4  8  8]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [6 3 5 4 0 0 0]
    predicted > [[ 6  6  3]
 [ 3  5  4]
 [ 5  3  6]
 [ 4  4  5]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [9 5 5 0 0 0 0]
    predicted > [[ 5  9  5]
 [ 9  5  9]
 [ 5  5  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 1200
  minibatch loss: 0.43877652287483215
  sample 1:
    input     > [4 6 8 9 2 0 0 0]
    predicted > [[ 4  8  8]
 [ 6  4  6]
 [ 8  6  4]
 [ 9  2  9]
 [ 2  9  2]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [8 9 2 0 0 0 0 0]
    predicted > [[ 8  9  4]
 [ 9  8  9]
 [ 2  4  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [4 4 5 5 4 7 7 0]
    predicted > [[ 4  4  4]
 [ 4  4  4]
 [ 5  5  5]
 [ 5  4  7]
 [ 4  7  4]
 [ 7  5  5]
 [ 7  7  5]
 [ 1  1  1]
 [-1 -1 -1]]

batch 1400
  minibatch loss: 0.37541431188583374
  sample 1:
    input     > [3 9 7 5 8 6 0]
    predicted > [[ 3  3  3]
 [ 9  5  5]
 [ 5  3  9]
 [ 7  9  3]
 [ 6  8  6]
 [ 8  7  5]
 [ 1  1  8]
 [-1 -1  1]]
  sample 2:
    input     > [4 6 9 7 4 9 6]
    predicted > [[4 4 4]
 [6 6 4]
 [9 9 3]
 [4 4 6]
 [7 7 9]
 [9 3 7]
 [6 4 4]
 [1 1 1]]
  sample 3:
    input     > [8 8 2 2 7 8 9]
    predicted > [[8 8 8]
 [2 8 2]
 [8 2 8]
 [8 2 8]
 [7 8 2]
 [2 7 7]
 [9 9 9]
 [1 1 1]]

batch 1600
  minibatch loss: 0.32577282190322876
  sample 1:
    input     > [2 9 7 3 5 3 6 7]
    predicted > [[9 9 9]
 [2 5 7]
 [7 6 2]
 [3 7 3]
 [7 3 6]
 [6 2 5]
 [5 2 3]
 [3 3 4]
 [1 1 1]]
  sample 2:
    input     > [4 3 9 5 0 0 0 0]
    predicted > [[ 4  4  4]
 [ 3  3  9]
 [ 9  9  3]
 [ 5  9  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [5 7 9 8 5 0 0 0]
    predicted > [[ 5  5  5]
 [ 7  7  5]
 [ 9  8  3]
 [ 8  9  8]
 [ 5  5  7]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 1800
  minibatch loss: 0.36575061082839966
  sample 1:
    input     > [2 7 4 3 4 6 5 0]
    predicted > [[ 2  2  2]
 [ 7  4  7]
 [ 4  7  4]
 [ 3  3  6]
 [ 4  6  9]
 [ 6  5  3]
 [ 5  4  5]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 2:
    input     > [9 4 4 0 0 0 0 0]
    predicted > [[ 9  4  9]
 [ 4  9  4]
 [ 4  9  2]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [7 3 3 2 9 4 7 4]
    predicted > [[ 3  3  3]
 [ 7  7  7]
 [ 3  2  2]
 [ 2  3  3]
 [ 9  9  9]
 [ 4  4  4]
 [ 7  7  7]
 [ 4  4  1]
 [ 1  1 -1]]

batch 2000
  minibatch loss: 0.19473139941692352
  sample 1:
    input     > [8 6 3 0 0 0 0]
    predicted > [[ 8  8  6]
 [ 6  3  8]
 [ 3  6  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [6 4 4 8 5 2 0]
    predicted > [[ 6  6  4]
 [ 4  4  6]
 [ 4  4  4]
 [ 8  8  8]
 [ 5  2  2]
 [ 2  5  5]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 3:
    input     > [7 2 7 6 0 0 0]
    predicted > [[ 7  7  7]
 [ 2  7  7]
 [ 7  2  2]
 [ 6  6  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 2200
  minibatch loss: 0.22542116045951843
  sample 1:
    input     > [7 3 4 0 0 0 0 0]
    predicted > [[ 7  3  7]
 [ 3  7  6]
 [ 4  4  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [3 9 7 3 0 0 0 0]
    predicted > [[ 3  3  3]
 [ 9  3  7]
 [ 7  9  9]
 [ 3  7  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [9 7 6 7 6 3 2 8]
    predicted > [[7 7 9]
 [9 9 7]
 [6 6 6]
 [3 3 7]
 [2 2 6]
 [8 6 3]
 [6 8 8]
 [7 7 2]
 [1 1 1]]

batch 2400
  minibatch loss: 0.236276313662529
  sample 1:
    input     > [5 3 7 8 7 3 0 0]
    predicted > [[ 5  5  5]
 [ 3  7  7]
 [ 7  3  3]
 [ 8  8  8]
 [ 7  3  3]
 [ 3  7  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [4 8 3 5 9 3 0 0]
    predicted > [[ 4  4  4]
 [ 8  8  8]
 [ 3  9  5]
 [ 5  3  3]
 [ 9  5  3]
 [ 3  3  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [8 2 4 8 7 2 0 0]
    predicted > [[ 8  8  8]
 [ 2  2  4]
 [ 8  4  2]
 [ 4  8  2]
 [ 2  7  8]
 [ 7  2  7]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 2600
  minibatch loss: 0.18354903161525726
  sample 1:
    input     > [2 8 5 5 3 0 0 0]
    predicted > [[ 2  4  2]
 [ 8  5  5]
 [ 5  2  8]
 [ 5  3  8]
 [ 3  8  5]
 [ 1  5  3]
 [-1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [6 8 5 5 5 0 0 0]
    predicted > [[ 6  6  6]
 [ 8  5  8]
 [ 5  8  5]
 [ 5  8  5]
 [ 5  5  2]
 [ 1  5  1]
 [-1  1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [6 2 6 4 2 4 6 3]
    predicted > [[6 6 6]
 [2 2 2]
 [6 4 4]
 [4 6 6]
 [4 6 6]
 [2 2 4]
 [6 4 2]
 [3 9 9]
 [1 1 1]]

batch 2800
  minibatch loss: 0.20125198364257812
  sample 1:
    input     > [9 7 6 2 6 3 0 0]
    predicted > [[ 9  2  2]
 [ 7  9  9]
 [ 6  6  7]
 [ 2  7  6]
 [ 6  3  6]
 [ 3  6  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [9 7 5 6 5 9 6 0]
    predicted > [[ 9  5  5]
 [ 7  9  9]
 [ 5  6  6]
 [ 6  7  7]
 [ 5  9  9]
 [ 2  5  7]
 [ 6  6  9]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 3:
    input     > [9 2 2 0 0 0 0 0]
    predicted > [[ 9  2  9]
 [ 2  9  2]
 [ 2  9  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 3000
  minibatch loss: 0.14697885513305664
  sample 1:
    input     > [6 2 3 0 0 0 0]
    predicted > [[ 6  6  6]
 [ 2  3  2]
 [ 3  2  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [3 9 8 4 9 2 2]
    predicted > [[3 3 3]
 [9 8 9]
 [8 9 8]
 [4 9 9]
 [9 4 4]
 [2 2 2]
 [2 2 2]
 [1 1 1]]
  sample 3:
    input     > [7 7 3 2 0 0 0]
    predicted > [[ 7  7  7]
 [ 7  3  3]
 [ 3  7  7]
 [ 2  2  7]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 3200
  minibatch loss: 0.19483646750450134
  sample 1:
    input     > [3 7 8 8 8 6 0 0]
    predicted > [[ 3  3  3]
 [ 7  8  8]
 [ 8  7  7]
 [ 8  8  8]
 [ 8  6  6]
 [ 6  8  5]
 [ 1  1  8]
 [-1 -1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [6 2 5 8 5 3 3 3]
    predicted > [[ 6  6  6]
 [ 5  5  5]
 [ 2  4  8]
 [ 8  2  2]
 [ 3  3  3]
 [ 5  3  2]
 [ 3  5  5]
 [ 2  8  3]
 [ 1  3  1]
 [-1  1 -1]]
  sample 3:
    input     > [6 4 4 4 3 2 8 0]
    predicted > [[ 6  6  6]
 [ 4  4  4]
 [ 4  4  4]
 [ 4  4  4]
 [ 3  3  3]
 [ 2  8  4]
 [ 8  2  2]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 3400
  minibatch loss: 0.14759384095668793
  sample 1:
    input     > [2 2 6 0 0 0 0 0]
    predicted > [[ 2  2  2]
 [ 2  2  2]
 [ 6  3  3]
 [ 1  1  6]
 [-1 -1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [5 6 6 2 8 5 0 0]
    predicted > [[ 5  2  2]
 [ 6  6  8]
 [ 6  5  6]
 [ 2  8  5]
 [ 8  6  6]
 [ 5  8  5]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [6 9 2 0 0 0 0 0]
    predicted > [[ 6  6  6]
 [ 9  2  9]
 [ 2  9  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 3600
  minibatch loss: 0.13171222805976868
  sample 1:
    input     > [2 6 3 9 3 7 7 0]
    predicted > [[ 2  6  6]
 [ 6  9  9]
 [ 3  2  2]
 [ 9  3  3]
 [ 3  7  7]
 [ 7  3  6]
 [ 7  7  7]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 2:
    input     > [3 6 3 3 5 3 6 0]
    predicted > [[ 3  3  3]
 [ 6  3  3]
 [ 3  6  6]
 [ 3  6  3]
 [ 5  5  4]
 [ 3  3  7]
 [ 6  3  3]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 3:
    input     > [2 2 8 9 6 5 5 7]
    predicted > [[2 2 2]
 [2 2 2]
 [8 8 8]
 [9 9 6]
 [6 6 9]
 [5 5 5]
 [5 7 5]
 [7 5 7]
 [1 1 1]]

batch 3800
  minibatch loss: 0.058824554085731506
  sample 1:
    input     > [7 9 9 8 7 7 4 3]
    predicted > [[7 7 9]
 [9 9 7]
 [9 9 7]
 [8 7 8]
 [7 8 9]
 [7 4 7]
 [4 3 4]
 [3 7 3]
 [1 1 1]]
  sample 2:
    input     > [6 4 3 9 5 3 0 0]
    predicted > [[ 6  6  6]
 [ 4  4  4]
 [ 3  3  3]
 [ 9  9  5]
 [ 5  3  9]
 [ 3  5  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [8 9 9 7 0 0 0 0]
    predicted > [[ 8  9  9]
 [ 9  8  8]
 [ 9  8  8]
 [ 7  7  7]
 [ 1  9  1]
 [-1  1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 4000
  minibatch loss: 0.09603714197874069
  sample 1:
    input     > [6 3 4 3 7 3 0 0]
    predicted > [[ 6  6  6]
 [ 3  3  3]
 [ 4  4  4]
 [ 3  7  3]
 [ 7  3  3]
 [ 3  3  7]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [8 3 3 6 0 0 0 0]
    predicted > [[ 8  8  3]
 [ 3  3  8]
 [ 3  6  8]
 [ 6  3  6]
 [ 1  1  3]
 [-1 -1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [2 6 4 8 9 9 2 3]
    predicted > [[2 2 2]
 [6 6 6]
 [4 8 4]
 [8 4 9]
 [9 9 8]
 [9 2 9]
 [2 9 3]
 [3 3 2]
 [1 1 1]]

batch 4200
  minibatch loss: 0.18101732432842255
  sample 1:
    input     > [5 4 8 9 8 5 5 0]
    predicted > [[ 5  5  5]
 [ 4  4  4]
 [ 8  8  8]
 [ 9  9  9]
 [ 8  5  4]
 [ 5  8  5]
 [ 5  5  8]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 2:
    input     > [2 6 4 8 8 3 0 0]
    predicted > [[ 2  2  2]
 [ 6  6  6]
 [ 4  4  8]
 [ 8  8  4]
 [ 8  3  8]
 [ 3  8  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [6 6 7 2 6 4 6 0]
    predicted > [[ 6  6  6]
 [ 6  6  6]
 [ 7  7  7]
 [ 2  2  6]
 [ 6  6  2]
 [ 6  4  4]
 [ 4  6  6]
 [ 1  1  1]
 [-1 -1 -1]]

batch 4400
  minibatch loss: 0.13958677649497986
  sample 1:
    input     > [4 2 9 4 9 7 2 2]
    predicted > [[4 4 4]
 [2 9 9]
 [9 2 2]
 [9 4 4]
 [4 7 2]
 [7 2 7]
 [2 9 9]
 [2 2 2]
 [1 1 1]]
  sample 2:
    input     > [5 6 7 2 9 2 2 6]
    predicted > [[5 5 5]
 [6 6 6]
 [7 2 7]
 [2 7 2]
 [9 9 9]
 [2 7 2]
 [2 2 6]
 [6 6 2]
 [1 1 1]]
  sample 3:
    input     > [2 7 5 8 9 3 5 0]
    predicted > [[ 2  2  2]
 [ 7  7  7]
 [ 5  5  8]
 [ 8  9  5]
 [ 9  8  9]
 [ 3  3  5]
 [ 5  5  3]
 [ 1  1  1]
 [-1 -1 -1]]

batch 4600
  minibatch loss: 0.06325320154428482
  sample 1:
    input     > [3 5 7 9 0 0 0]
    predicted > [[ 3  7  3]
 [ 5  3  7]
 [ 7  9  5]
 [ 9  5  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [9 6 8 2 0 0 0]
    predicted > [[ 9  9  9]
 [ 6  8  6]
 [ 8  6  6]
 [ 2  2  5]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [9 9 6 0 0 0 0]
    predicted > [[ 9  9  9]
 [ 9  9  9]
 [ 6  6  6]
 [ 1  2  6]
 [-1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 4800
  minibatch loss: 0.08858782052993774
  sample 1:
    input     > [6 5 3 0 0 0 0 0]
    predicted > [[ 6  6  6]
 [ 5  5  7]
 [ 3  6  5]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [4 9 4 6 9 4 0 0]
    predicted > [[ 4  4  4]
 [ 9  9  9]
 [ 4  4  6]
 [ 6  6  4]
 [ 9  4  9]
 [ 4  9  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [5 5 4 9 9 0 0 0]
    predicted > [[ 5  5  5]
 [ 5  5  5]
 [ 4  4  4]
 [ 9  9  9]
 [ 9  5  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 5000
  minibatch loss: 0.07043668627738953
  sample 1:
    input     > [3 5 3 8 4 7 4 0]
    predicted > [[ 3  3  3]
 [ 5  5  5]
 [ 3  3  3]
 [ 8  8  4]
 [ 4  4  8]
 [ 7  4  7]
 [ 4  7  8]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 2:
    input     > [4 9 2 2 5 4 0 0]
    predicted > [[ 4  4  4]
 [ 9  2  2]
 [ 2  9  9]
 [ 2  5  9]
 [ 5  4  5]
 [ 4  2  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [4 4 9 7 3 4 9 4]
    predicted > [[4 4 4]
 [4 4 4]
 [9 7 9]
 [7 9 7]
 [3 9 3]
 [4 4 4]
 [9 3 4]
 [4 4 9]
 [1 1 1]]

batch 5200
  minibatch loss: 0.09076255559921265
  sample 1:
    input     > [7 6 9 2 5 6 3 0]
    predicted > [[ 7  7  7]
 [ 6  6  2]
 [ 9  9  6]
 [ 2  2  9]
 [ 5  5  8]
 [ 6  6  3]
 [ 3  7  5]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [5 4 9 4 2 3 3 0]
    predicted > [[ 5  5  4]
 [ 4  4  5]
 [ 9  9  9]
 [ 4  2  2]
 [ 2  4  9]
 [ 3  3  6]
 [ 3  3  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [6 6 5 6 2 6 8 2]
    predicted > [[ 6  6  6]
 [ 6  6  3]
 [ 5  5  6]
 [ 6  6  5]
 [ 2  2  6]
 [ 6  6  4]
 [ 8  8  2]
 [ 2  6  6]
 [ 1  1  5]
 [-1 -1  1]]

batch 5400
  minibatch loss: 0.06568838655948639
  sample 1:
    input     > [5 6 3 8 9 2 0 0]
    predicted > [[ 5  5  5]
 [ 6  6  3]
 [ 3  3  6]
 [ 8  8  8]
 [ 9  2  4]
 [ 2  9  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [6 2 4 3 9 7 8 8]
    predicted > [[6 6 6]
 [2 2 2]
 [4 4 3]
 [3 3 4]
 [9 9 4]
 [7 8 9]
 [8 7 7]
 [8 8 8]
 [1 1 1]]
  sample 3:
    input     > [4 5 4 4 5 5 8 0]
    predicted > [[ 4  4  4]
 [ 5  4  4]
 [ 4  5  5]
 [ 4  5  5]
 [ 5  5  4]
 [ 5  4  5]
 [ 8  8  8]
 [ 1  1  1]
 [-1 -1 -1]]

batch 5600
  minibatch loss: 0.10547281056642532
  sample 1:
    input     > [5 9 5 4 9 0 0 0]
    predicted > [[ 5  5  5]
 [ 9  5  9]
 [ 5  9  5]
 [ 4  4  4]
 [ 9  9  4]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [5 9 2 0 0 0 0 0]
    predicted > [[ 5  9  5]
 [ 9  5  2]
 [ 2  2  9]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [2 3 6 5 7 8 5 2]
    predicted > [[2 2 2]
 [3 3 3]
 [6 6 6]
 [5 5 5]
 [7 7 8]
 [8 8 7]
 [5 2 5]
 [2 5 2]
 [1 1 1]]

batch 5800
  minibatch loss: 0.033192142844200134
  sample 1:
    input     > [3 5 4 6 8 0 0]
    predicted > [[ 3  3  3]
 [ 5  5  5]
 [ 4  6  6]
 [ 6  4  4]
 [ 8  5  8]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [4 4 3 2 5 8 0]
    predicted > [[ 4  4  4]
 [ 4  4  2]
 [ 3  3  4]
 [ 2  2  3]
 [ 5  5  8]
 [ 8  4  5]
 [ 1  1  1]
 [-1 -1 -1]]
  sample 3:
    input     > [9 6 2 0 0 0 0]
    predicted > [[ 9  9  2]
 [ 6  2  9]
 [ 2  6  6]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]

batch 6000
  minibatch loss: 0.05001354217529297
  sample 1:
    input     > [3 8 5 7 0 0 0 0]
    predicted > [[ 3  3  3]
 [ 8  8  8]
 [ 5  7  5]
 [ 7  5  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 2:
    input     > [5 2 9 3 4 3 0 0]
    predicted > [[ 5  5  2]
 [ 2  2  5]
 [ 9  3  9]
 [ 3  9  3]
 [ 4  4  4]
 [ 3  3  3]
 [ 1  1  1]
 [-1 -1 -1]
 [-1 -1 -1]]
  sample 3:
    input     > [4 9 7 6 0 0 0 0]
    predicted > [[ 4  4  4]
 [ 9  9  9]
 [ 7  6  3]
 [ 6  7  7]
 [ 1  1  6]
 [-1 -1  1]
 [-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(loss_track)
print('loss {:.4f} after {} examples (batch_size={})'.format(loss_track[-1], 
                                                             len(loss_track)*batch_size, batch_size))
loss 0.0543 after 60010 examples (batch_size=10)

[figure: training loss curve]

