A summary of tf.nn.bidirectional_dynamic_rnn, TensorFlow's key function for building bi-directional LSTMs, and of its input/output behavior.
sample code:
import numpy as np
import tensorflow as tf

# Create input data: 2 examples, 10 time steps, 8 features each
X = np.random.randn(2, 10, 8)
# The second example is of length 6; zero out the padded steps
X[1, 6:] = 0
X_lengths = [10, 6]
cell = tf.nn.rnn_cell.LSTMCell(num_units=20, state_is_tuple=True)
outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell, cell_bw=cell, dtype=tf.float64, sequence_length=X_lengths, inputs=X)
output_fw, output_bw = outputs
states_fw, states_bw = states
The sequence_length argument. This is a vector with one entry per example in the batch, giving the number of valid time steps in that example (the size of its first dimension before padding). In the input above there are two examples, each padded to (10, 8), but the second one only contains 6 real steps, so sequence_length is [10, 6]. This argument matters: the RNN stops at each example's length, and a wrong value produces wrong results.
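As a sanity check, the per-example lengths can be recovered from the zero-padded input with plain NumPy. This is a sketch that assumes padded steps are exactly zero (true for the X above):

```python
import numpy as np

np.random.seed(0)
X = np.random.randn(2, 10, 8)
X[1, 6:] = 0  # second example is only 6 steps long

# A step counts as valid if any of its features is nonzero.
used = np.sign(np.max(np.abs(X), axis=2))    # shape (2, 10): 1.0 for real steps, 0.0 for padding
lengths = used.sum(axis=1).astype(np.int32)  # shape (2,)
print(lengths)  # [10  6]
```

This trick is handy when the lengths are not stored alongside the padded batch.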
LSTMStateTuple. LSTMStateTuple is the format TensorFlow uses to represent LSTM state; state_size, zero_state, and the returned output state are all expressed as LSTMStateTuples. An LSTMStateTuple is a pair (c, h), where c is the cell (memory) state and h is the output (hidden) state.
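To make the (c, h) distinction concrete, here is a single LSTM step written out in NumPy. This is a minimal sketch with made-up weights, not TensorFlow's exact implementation (gate ordering and bias initialization differ):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c_prev, h_prev, W, b):
    """One LSTM step: returns (c, h). c is the memory carried between
    steps; h is what the cell emits as its output."""
    z = np.concatenate([x, h_prev]) @ W + b            # all four gates in one matmul
    i, f, g, o = np.split(z, 4)                        # input, forget, candidate, output gates
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # cell state update
    h = sigmoid(o) * np.tanh(c)                        # output / hidden state
    return c, h

num_units, input_size = 3, 2
rng = np.random.RandomState(0)
W = rng.randn(input_size + num_units, 4 * num_units)
b = np.zeros(4 * num_units)
c, h = lstm_step(rng.randn(input_size), np.zeros(num_units), np.zeros(num_units), W, b)
print(c.shape, h.shape)  # (3,) (3,)
```

Only h leaves the cell at each step; c is internal, which is why the output dump later in this note matches h but not c.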
Using the example above, let's print the c and h of states_fw (the forward cell's state):
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    states_shape = tf.shape(states)
    print(states_shape.eval())
    c, h = states_fw
    print('print c')
    print(sess.run(c))
    print('print c[1]')
    print(sess.run(c[1]))
Output of c:
[[ 0.10020639 0.04123133 0.1607345 0.04516676 0.14993827 0.04191888
0.37377536 0.155445 0.0250121 0.19496255 0.14872166 -0.27649689
-0.04177214 -0.64201794 0.33243907 0.19359499 -0.36642676 0.21112514
-0.22745071 -0.26275169]
[-0.11904546 0.31178861 -0.07443244 0.09425232 0.49845703 -0.03854894
-0.19182923 0.21352997 -0.46194925 0.16779708 -0.33193142 -0.19566778
-0.06612427 -0.11908099 -0.00726331 0.33335469 0.01462403 -0.58337865
-0.07321394 -0.05640467]]
Output of h:
[[ 0.15347487 -0.11674097 -0.09588729 0.00890338 0.12735097 0.09076786
0.06652527 0.00207616 0.0500825 0.12345199 -0.01998148 0.08035006
0.13353144 -0.01102215 -0.09175959 0.1455235 0.11431857 0.15356262
-0.0725327 -0.03418285]
[ 0.15409209 0.09243608 -0.05506054 -0.07781891 0.0971408 0.06453969
0.03290611 -0.09163908 0.01231368 0.15045137 0.06517371 0.1267243
0.02084229 -0.16214882 -0.20116859 0.24669899 0.04287307 0.04801212
-0.02658258 -0.10132215]]
Both c and h have shape (2, 20): 2 because the batch contains two examples, and 20 because the cell has 20 units. So for each example in the batch, the LSTM produces one forward state, split into c (the cell state) and h (the output state), each a vector of unit size. The backward state works the same way.
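When these final states feed a downstream model, the forward and backward h vectors are usually concatenated into one feature of size 2 * num_units per example (in TensorFlow: tf.concat([states_fw.h, states_bw.h], 1)). A NumPy shape sketch with placeholder values:

```python
import numpy as np

batch, num_units = 2, 20
h_fw = np.zeros((batch, num_units))  # stand-in for states_fw.h
h_bw = np.zeros((batch, num_units))  # stand-in for states_bw.h
features = np.concatenate([h_fw, h_bw], axis=1)
print(features.shape)  # (2, 40)
```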
- Output
The output is likewise split into a forward part and a backward part; here we look only at the forward cell's output.
[[[-1.06669103e-01 -3.88756185e-02 2.60744107e-02 2.43657585e-03
9.88481327e-02 -1.15583728e-01 -1.56563918e-01 1.58782306e-01
4.85588806e-02 -5.77454750e-02 3.36419355e-02 -1.98287352e-02
-1.10931715e-01 3.82711717e-02 -9.20855234e-02 -9.49620258e-02
6.52089598e-02 -3.53805118e-02 1.30414250e-03 6.59167148e-02]
[ 1.59046099e-02 4.41911904e-02 -1.85633770e-02 -1.02646872e-01
5.67889560e-02 -1.80578232e-01 1.16289962e-02 3.43832302e-02
-9.97417632e-02 8.52767259e-03 4.19908236e-02 -2.92100595e-02
-7.40990631e-03 -5.57406104e-02 -9.63055482e-03 6.64506140e-02
-1.60809539e-01 -1.13207020e-01 6.67947362e-02 -1.01494835e-01]
[ 1.29653547e-01 1.64123612e-01 1.14319183e-01 -3.13978750e-01
3.21875813e-02 -1.20610558e-01 1.65674751e-01 9.67606455e-02
-7.53250048e-02 1.64018976e-01 8.14044036e-02 -2.66330649e-01
-8.32540716e-02 -2.30083245e-01 -3.66429645e-02 2.69508256e-01
-3.16302908e-01 -7.56739776e-02 -7.03734257e-02 -7.27750202e-02]
[ 2.10778175e-01 1.29119904e-01 -5.28395789e-02 -1.56958078e-01
-5.97345897e-02 1.20157188e-01 2.17347619e-01 5.99727875e-02
-1.64570565e-01 8.08612044e-02 2.07909278e-02 -2.07703283e-01
-1.63712849e-02 -2.01749788e-01 1.14587708e-01 2.06933175e-01
-2.55334961e-01 -6.03161794e-02 -1.97578049e-01 -1.72242306e-01]
[ 3.00083183e-01 1.51446013e-01 4.82394269e-02 -3.39348109e-01
-3.38322196e-01 -2.43085569e-01 1.82949930e-01 2.40813002e-01
-2.80735872e-01 2.41043685e-01 -9.54472693e-02 -2.93748317e-01
5.88980390e-02 -1.35271864e-01 2.89699583e-01 8.75721403e-02
-3.55877522e-01 -1.33211501e-01 -7.57881317e-02 -2.64836250e-01]
[ 2.41987682e-01 1.47695318e-02 -5.03866150e-02 -1.44482469e-01
-1.44640164e-02 -6.75506747e-02 2.32746312e-01 1.47519780e-01
-6.56321553e-02 1.60696093e-01 5.45594683e-03 -1.88477690e-01
5.45150185e-02 -1.77628408e-01 1.05268972e-01 1.41610215e-01
-1.60580096e-01 -7.24836242e-03 -1.00759851e-01 -1.47514166e-01]
[ 2.12808506e-01 -3.27166227e-02 4.31225953e-03 -1.02816763e-01
3.25901490e-03 3.66293909e-02 1.54212310e-01 1.69800784e-01
1.10284434e-02 1.74149175e-01 3.41514708e-02 -1.91123912e-01
3.02006378e-02 -1.99656590e-01 2.26262571e-02 2.52412919e-01
-7.68398775e-02 5.94911484e-02 -1.31153846e-01 -1.20808211e-01]
[ 3.81763504e-01 2.68575596e-02 -1.08793781e-01 -8.00019483e-02
-3.68294635e-02 1.71446728e-01 1.18992211e-01 -2.13071169e-02
-4.61473814e-02 1.82351966e-01 -8.44481138e-02 1.19407754e-02
4.08584125e-02 -1.80411471e-01 -7.74698125e-03 1.93662041e-01
-1.20557645e-01 -2.10084183e-02 -2.27600119e-01 -1.63846952e-01]
[ 1.64898924e-01 6.37549641e-02 1.66957306e-02 -1.59360332e-01
-1.51426048e-01 1.28056643e-01 2.85171791e-01 3.50425360e-04
-1.99342025e-02 1.89266634e-01 -6.53912049e-02 -6.65443778e-02
7.44334992e-02 4.85016031e-02 9.49079271e-02 3.24700401e-01
-1.00750007e-01 2.18841138e-02 -1.61316038e-01 1.52122726e-02]
[ 1.53474875e-01 -1.16740969e-01 -9.58872865e-02 8.90338491e-03
1.27350972e-01 9.07678551e-02 6.65252728e-02 2.07615713e-03
5.00825032e-02 1.23451987e-01 -1.99814751e-02 8.03500562e-02
1.33531444e-01 -1.10221475e-02 -9.17595887e-02 1.45523503e-01
1.14318569e-01 1.53562622e-01 -7.25326998e-02 -3.41828502e-02]]
[[ 2.24818909e-01 2.21235678e-02 1.01460432e-01 -1.41914365e-01
-1.21404939e-01 7.36078879e-02 1.38471242e-01 -1.17533437e-01
-5.21530141e-03 1.67706170e-01 -7.17727515e-02 -1.06750419e-01
7.13189845e-02 -9.07184818e-02 2.11111214e-02 1.81368716e-01
-1.46839530e-01 2.14554598e-02 -7.90004557e-02 8.87259097e-02]
[ 2.11738134e-01 2.06868083e-02 -1.42999066e-01 -5.44685789e-02
-6.31460261e-02 1.86872216e-01 1.22599483e-01 -1.82293974e-01
8.76017957e-02 7.64068221e-02 -5.74839315e-02 8.27909362e-02
5.49907143e-02 -8.06081683e-02 -4.65603130e-02 1.07644840e-01
-8.45501653e-02 -4.02021538e-02 -8.38841808e-02 3.63420987e-02]
[ 2.65380968e-01 -5.11942699e-03 -1.10961564e-02 -1.96348422e-01
-1.11433399e-01 1.53275799e-02 6.00570999e-02 -2.05297778e-01
8.52545915e-02 1.97091206e-01 -7.42037228e-02 1.25797496e-01
1.08283714e-01 -1.29675158e-01 -1.32684022e-01 1.19353210e-01
-1.22913400e-01 -1.09450277e-01 -1.97762286e-02 2.60753532e-02]
[ 2.29800970e-01 6.64311636e-02 -3.45172340e-02 -1.56474836e-01
-4.81899131e-02 9.00044045e-02 8.26513916e-02 -1.33626283e-01
1.37496640e-01 1.72760619e-01 9.74954132e-03 2.40818003e-02
1.28755599e-02 -2.39148477e-01 -2.11945339e-01 1.92631382e-01
-1.23300797e-01 -1.74345945e-02 -7.96618285e-02 -6.94683079e-03]
[ 1.95140846e-01 1.06431901e-01 -9.20244228e-02 -2.09311995e-01
-5.64830252e-03 9.53098517e-02 4.49136154e-02 -1.55642596e-01
-4.00256764e-02 1.03820451e-01 -9.09035922e-02 1.30894101e-01
-1.71891357e-02 -8.76164608e-02 -8.98778574e-02 7.59155122e-02
-7.54771617e-02 -1.34889843e-01 -4.58820217e-02 -7.81068266e-02]
[ 1.54092089e-01 9.24360827e-02 -5.50605385e-02 -7.78189060e-02
9.71408042e-02 6.45396884e-02 3.29061070e-02 -9.16390804e-02
1.23136831e-02 1.50451370e-01 6.51737096e-02 1.26724300e-01
2.08422868e-02 -1.62148824e-01 -2.01168589e-01 2.46698990e-01
4.28730731e-02 4.80121224e-02 -2.65825827e-02 -1.01322148e-01]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]]
The shape of output_fw is (2, 10, 20). For each input vector of each example in the batch (each example here is 10 × 8, i.e. 10 vectors), the LSTM emits one vector of size 20 (the unit size).
Notice that the h we saw in the state is exactly the last valid vector in the output! For the second example, h is not zero even though the trailing output rows are, because of the sequence_length argument: the LSTM stops after 6 steps, so h is taken from step 6.
So if you only need the LSTM's final output and don't care about the intermediate steps (say, for a classifier rather than a seq2seq model), you can simply take the state output.
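Conversely, if you do want the final output taken from the outputs tensor itself, each example must be indexed at sequence_length - 1, not at the last row. A NumPy sketch with placeholder data (the variable names are illustrative):

```python
import numpy as np

batch, max_time, num_units = 2, 10, 20
lengths = np.array([10, 6])
outputs_fw = np.random.randn(batch, max_time, num_units)
outputs_fw[1, 6:] = 0  # padded steps emit zeros, as in the dump above

# Pick row lengths[i] - 1 of each example: this is what states_fw.h holds.
last_valid = outputs_fw[np.arange(batch), lengths - 1]
print(last_valid.shape)  # (2, 20)
```

Taking states_fw.h saves you exactly this gather step.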