这个函数我已经查了三遍不止了,每次看网上说的那些都没看明白。也怪我自己不动脑子,现成的代码我复制粘贴了,连改都不改这怎么能懂其中的意思呢。我一开始看了这篇博客,觉得写得还真好,详细解读了内部,什么两次reverse后再拼接啥的(如下图)
理解bilstm的关键在于反向lstm并不是像看起来那样从右往左传递信息,而是先将原来的输入逆序排列输入到正向lstm中,再将得到的输出结果逆序排列,便得到了所谓的“反向lstm”的输出
例一
import tensorflow as tf
import numpy as np
# 创建一个batch为2的三维数组,值全为1
X = np.ones((2, 10, 8))
# 指定每个batch的真实长度,这是bidirectional_dynamic_rnn中的一个参数,如果不指定,默认为batch的最大长度
X_lengths = [10, 10]
cell = tf.nn.rnn_cell.LSTMCell(num_units=20, state_is_tuple=True)
outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell, cell_bw=cell, dtype=tf.float64,
sequence_length=X_lengths, inputs=X)
# bidirectional_dynamic_rnn输出两个元组,第一个元组为输出值元组,第二个为状态元组
output_fw, output_bw = outputs
states_fw, states_bw = states
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 状态元组中的每个元素也还是一个元组,包含两个元素c和h,c是hidden state,h是output state
c_fw, h_fw = states_fw
c_bw, h_bw = states_bw
print('***********print h_fw')
print(sess.run(h_fw))
print('***********print h_bw')
print(sess.run(h_bw))
print('***********print c_fw')
print(sess.run(c_fw))
print('***********print c_bw')
print(sess.run(c_bw))
print('***********print output_fw')
print(sess.run(output_fw))
print('***********print output_bw')
print(sess.run(output_bw))
输出结果如下
***********print h_fw
[[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248 -0.37795411
-0.38890486 0.24909768 0.24172305 -0.27529398 -0.18520042 -0.39698435
0.0636031 -0.26719873 0.18002771 0.09966553 0.02702067 -0.29513662
-0.09883436 -0.41679873]
[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248 -0.37795411
-0.38890486 0.24909768 0.24172305 -0.27529398 -0.18520042 -0.39698435
0.0636031 -0.26719873 0.18002771 0.09966553 0.02702067 -0.29513662
-0.09883436 -0.41679873]]
***********print h_bw
[[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248 -0.37795411
-0.38890486 0.24909768 0.24172305 -0.27529398 -0.18520042 -0.39698435
0.0636031 -0.26719873 0.18002771 0.09966553 0.02702067 -0.29513662
-0.09883436 -0.41679873]
[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248 -0.37795411
-0.38890486 0.24909768 0.24172305 -0.27529398 -0.18520042 -0.39698435
0.0636031 -0.26719873 0.18002771 0.09966553 0.02702067 -0.29513662
-0.09883436 -0.41679873]]
***********print c_fw
[[ 0.63050086 -0.14962187 0.17986396 -0.27891692 -0.83289912 -1.38633692
-0.92382506 0.47125799 0.71944087 -0.55796188 -0.4651182 -1.34933772
0.1275259 -0.79503812 0.44752065 0.23413849 0.04473469 -0.82337067
-0.20446274 -0.80014184]
[ 0.63050086 -0.14962187 0.17986396 -0.27891692 -0.83289912 -1.38633692
-0.92382506 0.47125799 0.71944087 -0.55796188 -0.4651182 -1.34933772
0.1275259 -0.79503812 0.44752065 0.23413849 0.04473469 -0.82337067
-0.20446274 -0.80014184]]
***********print c_bw
[[ 0.63050086 -0.14962187 0.17986396 -0.27891692 -0.83289912 -1.38633692
-0.92382506 0.47125799 0.71944087 -0.55796188 -0.4651182 -1.34933772
0.1275259 -0.79503812 0.44752065 0.23413849 0.04473469 -0.82337067
-0.20446274 -0.80014184]
[ 0.63050086 -0.14962187 0.17986396 -0.27891692 -0.83289912 -1.38633692
-0.92382506 0.47125799 0.71944087 -0.55796188 -0.4651182 -1.34933772
0.1275259 -0.79503812 0.44752065 0.23413849 0.04473469 -0.82337067
-0.20446274 -0.80014184]]
***********print output_fw
[[[ 0.09213119 -0.00920132 -0.00480822 -0.06614824 -0.10940994
-0.13237866 -0.06281013 0.08518363 0.07377088 -0.00439786
-0.03767817 -0.1139195 0.01487621 -0.05797102 0.09557128
0.00161788 -0.06577087 -0.08184791 -0.05674727 -0.22975524]
[ 0.15855342 -0.02402754 -0.00301006 -0.09785019 -0.17732144
-0.22491602 -0.1349754 0.13963072 0.12761351 -0.02347015
-0.07433172 -0.19753812 0.02575886 -0.11248398 0.15490154
0.0149613 -0.10369941 -0.14653386 -0.08158803 -0.33888571]
[ 0.20518182 -0.03977035 0.00220437 -0.11144359 -0.21780033
-0.28412027 -0.20131872 0.17511601 0.16677614 -0.05516816
-0.10460625 -0.25808741 0.03480547 -0.15743644 0.18638093
0.03129929 -0.11566092 -0.19303162 -0.09320168 -0.38787705]
[ 0.23793991 -0.0540012 0.00920335 -0.11647142 -0.24172754
-0.32073244 -0.25595129 0.19880824 0.19462017 -0.09352138
-0.127896 -0.30146283 0.04269355 -0.1917475 0.20078535
0.04700247 -0.10889583 -0.22526698 -0.09881287 -0.40967251]
[ 0.26108442 -0.06564608 0.0171266 -0.1179926 -0.25581146
-0.34330216 -0.2981817 0.21500322 0.21370547 -0.13296083
-0.14530196 -0.33239821 0.04941772 -0.21687859 0.20531662
0.06073976 -0.09068629 -0.24765488 -0.10130358 -0.41883299]
[ 0.27751227 -0.07446752 0.02544216 -0.11856258 -0.26406581
-0.35734935 -0.32964097 0.22637017 0.22625771 -0.16993574
-0.15821977 -0.35452075 0.05483244 -0.23490698 0.20414891
0.07220202 -0.06670979 -0.26350145 -0.10203061 -0.42187412]
[ 0.28920624 -0.08068884 0.03378522 -0.1193405 -0.26889405
-0.36621415 -0.35256554 0.23458705 0.2341066 -0.20278181
-0.1678622 -0.37046562 0.05886875 -0.24768901 0.19974503
0.08152418 -0.04094146 -0.27499945 -0.10175494 -0.4219268 ]
[ 0.2975349 -0.08473076 0.04189705 -0.12071914 -0.27173138
-0.37188732 -0.3690315 0.24071022 0.2386569 -0.23108288
-0.17515872 -0.3820775 0.061575 -0.25666779 0.19364629
0.08900016 -0.01591296 -0.28355518 -0.10094973 -0.42059815]
[ 0.30344973 -0.08705629 0.04959632 -0.12270135 -0.27342634
-0.37555987 -0.38071883 0.24540596 0.24094529 -0.25508262
-0.18078124 -0.39062611 0.06309015 -0.26291155 0.18685659
0.09494983 0.0069545 -0.29007129 -0.09991513 -0.41874011]
[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248
-0.37795411 -0.38890486 0.24909768 0.24172305 -0.27529398
-0.18520042 -0.39698435 0.0636031 -0.26719873 0.18002771
0.09966553 0.02702067 -0.29513662 -0.09883436 -0.41679873]]
[[ 0.09213119 -0.00920132 -0.00480822 -0.06614824 -0.10940994
-0.13237866 -0.06281013 0.08518363 0.07377088 -0.00439786
-0.03767817 -0.1139195 0.01487621 -0.05797102 0.09557128
0.00161788 -0.06577087 -0.08184791 -0.05674727 -0.22975524]
[ 0.15855342 -0.02402754 -0.00301006 -0.09785019 -0.17732144
-0.22491602 -0.1349754 0.13963072 0.12761351 -0.02347015
-0.07433172 -0.19753812 0.02575886 -0.11248398 0.15490154
0.0149613 -0.10369941 -0.14653386 -0.08158803 -0.33888571]
[ 0.20518182 -0.03977035 0.00220437 -0.11144359 -0.21780033
-0.28412027 -0.20131872 0.17511601 0.16677614 -0.05516816
-0.10460625 -0.25808741 0.03480547 -0.15743644 0.18638093
0.03129929 -0.11566092 -0.19303162 -0.09320168 -0.38787705]
[ 0.23793991 -0.0540012 0.00920335 -0.11647142 -0.24172754
-0.32073244 -0.25595129 0.19880824 0.19462017 -0.09352138
-0.127896 -0.30146283 0.04269355 -0.1917475 0.20078535
0.04700247 -0.10889583 -0.22526698 -0.09881287 -0.40967251]
[ 0.26108442 -0.06564608 0.0171266 -0.1179926 -0.25581146
-0.34330216 -0.2981817 0.21500322 0.21370547 -0.13296083
-0.14530196 -0.33239821 0.04941772 -0.21687859 0.20531662
0.06073976 -0.09068629 -0.24765488 -0.10130358 -0.41883299]
[ 0.27751227 -0.07446752 0.02544216 -0.11856258 -0.26406581
-0.35734935 -0.32964097 0.22637017 0.22625771 -0.16993574
-0.15821977 -0.35452075 0.05483244 -0.23490698 0.20414891
0.07220202 -0.06670979 -0.26350145 -0.10203061 -0.42187412]
[ 0.28920624 -0.08068884 0.03378522 -0.1193405 -0.26889405
-0.36621415 -0.35256554 0.23458705 0.2341066 -0.20278181
-0.1678622 -0.37046562 0.05886875 -0.24768901 0.19974503
0.08152418 -0.04094146 -0.27499945 -0.10175494 -0.4219268 ]
[ 0.2975349 -0.08473076 0.04189705 -0.12071914 -0.27173138
-0.37188732 -0.3690315 0.24071022 0.2386569 -0.23108288
-0.17515872 -0.3820775 0.061575 -0.25666779 0.19364629
0.08900016 -0.01591296 -0.28355518 -0.10094973 -0.42059815]
[ 0.30344973 -0.08705629 0.04959632 -0.12270135 -0.27342634
-0.37555987 -0.38071883 0.24540596 0.24094529 -0.25508262
-0.18078124 -0.39062611 0.06309015 -0.26291155 0.18685659
0.09494983 0.0069545 -0.29007129 -0.09991513 -0.41874011]
[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248
-0.37795411 -0.38890486 0.24909768 0.24172305 -0.27529398
-0.18520042 -0.39698435 0.0636031 -0.26719873 0.18002771
0.09966553 0.02702067 -0.29513662 -0.09883436 -0.41679873]]]
***********print output_bw
[[[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248
-0.37795411 -0.38890486 0.24909768 0.24172305 -0.27529398
-0.18520042 -0.39698435 0.0636031 -0.26719873 0.18002771
0.09966553 0.02702067 -0.29513662 -0.09883436 -0.41679873]
[ 0.30344973 -0.08705629 0.04959632 -0.12270135 -0.27342634
-0.37555987 -0.38071883 0.24540596 0.24094529 -0.25508262
-0.18078124 -0.39062611 0.06309015 -0.26291155 0.18685659
0.09494983 0.0069545 -0.29007129 -0.09991513 -0.41874011]
[ 0.2975349 -0.08473076 0.04189705 -0.12071914 -0.27173138
-0.37188732 -0.3690315 0.24071022 0.2386569 -0.23108288
-0.17515872 -0.3820775 0.061575 -0.25666779 0.19364629
0.08900016 -0.01591296 -0.28355518 -0.10094973 -0.42059815]
[ 0.28920624 -0.08068884 0.03378522 -0.1193405 -0.26889405
-0.36621415 -0.35256554 0.23458705 0.2341066 -0.20278181
-0.1678622 -0.37046562 0.05886875 -0.24768901 0.19974503
0.08152418 -0.04094146 -0.27499945 -0.10175494 -0.4219268 ]
[ 0.27751227 -0.07446752 0.02544216 -0.11856258 -0.26406581
-0.35734935 -0.32964097 0.22637017 0.22625771 -0.16993574
-0.15821977 -0.35452075 0.05483244 -0.23490698 0.20414891
0.07220202 -0.06670979 -0.26350145 -0.10203061 -0.42187412]
[ 0.26108442 -0.06564608 0.0171266 -0.1179926 -0.25581146
-0.34330216 -0.2981817 0.21500322 0.21370547 -0.13296083
-0.14530196 -0.33239821 0.04941772 -0.21687859 0.20531662
0.06073976 -0.09068629 -0.24765488 -0.10130358 -0.41883299]
[ 0.23793991 -0.0540012 0.00920335 -0.11647142 -0.24172754
-0.32073244 -0.25595129 0.19880824 0.19462017 -0.09352138
-0.127896 -0.30146283 0.04269355 -0.1917475 0.20078535
0.04700247 -0.10889583 -0.22526698 -0.09881287 -0.40967251]
[ 0.20518182 -0.03977035 0.00220437 -0.11144359 -0.21780033
-0.28412027 -0.20131872 0.17511601 0.16677614 -0.05516816
-0.10460625 -0.25808741 0.03480547 -0.15743644 0.18638093
0.03129929 -0.11566092 -0.19303162 -0.09320168 -0.38787705]
[ 0.15855342 -0.02402754 -0.00301006 -0.09785019 -0.17732144
-0.22491602 -0.1349754 0.13963072 0.12761351 -0.02347015
-0.07433172 -0.19753812 0.02575886 -0.11248398 0.15490154
0.0149613 -0.10369941 -0.14653386 -0.08158803 -0.33888571]
[ 0.09213119 -0.00920132 -0.00480822 -0.06614824 -0.10940994
-0.13237866 -0.06281013 0.08518363 0.07377088 -0.00439786
-0.03767817 -0.1139195 0.01487621 -0.05797102 0.09557128
0.00161788 -0.06577087 -0.08184791 -0.05674727 -0.22975524]]
[[ 0.30761863 -0.08809378 0.05676232 -0.12512271 -0.27447248
-0.37795411 -0.38890486 0.24909768 0.24172305 -0.27529398
-0.18520042 -0.39698435 0.0636031 -0.26719873 0.18002771
0.09966553 0.02702067 -0.29513662 -0.09883436 -0.41679873]
[ 0.30344973 -0.08705629 0.04959632 -0.12270135 -0.27342634
-0.37555987 -0.38071883 0.24540596 0.24094529 -0.25508262
-0.18078124 -0.39062611 0.06309015 -0.26291155 0.18685659
0.09494983 0.0069545 -0.29007129 -0.09991513 -0.41874011]
[ 0.2975349 -0.08473076 0.04189705 -0.12071914 -0.27173138
-0.37188732 -0.3690315 0.24071022 0.2386569 -0.23108288
-0.17515872 -0.3820775 0.061575 -0.25666779 0.19364629
0.08900016 -0.01591296 -0.28355518 -0.10094973 -0.42059815]
[ 0.28920624 -0.08068884 0.03378522 -0.1193405 -0.26889405
-0.36621415 -0.35256554 0.23458705 0.2341066 -0.20278181
-0.1678622 -0.37046562 0.05886875 -0.24768901 0.19974503
0.08152418 -0.04094146 -0.27499945 -0.10175494 -0.4219268 ]
[ 0.27751227 -0.07446752 0.02544216 -0.11856258 -0.26406581
-0.35734935 -0.32964097 0.22637017 0.22625771 -0.16993574
-0.15821977 -0.35452075 0.05483244 -0.23490698 0.20414891
0.07220202 -0.06670979 -0.26350145 -0.10203061 -0.42187412]
[ 0.26108442 -0.06564608 0.0171266 -0.1179926 -0.25581146
-0.34330216 -0.2981817 0.21500322 0.21370547 -0.13296083
-0.14530196 -0.33239821 0.04941772 -0.21687859 0.20531662
0.06073976 -0.09068629 -0.24765488 -0.10130358 -0.41883299]
[ 0.23793991 -0.0540012 0.00920335 -0.11647142 -0.24172754
-0.32073244 -0.25595129 0.19880824 0.19462017 -0.09352138
-0.127896 -0.30146283 0.04269355 -0.1917475 0.20078535
0.04700247 -0.10889583 -0.22526698 -0.09881287 -0.40967251]
[ 0.20518182 -0.03977035 0.00220437 -0.11144359 -0.21780033
-0.28412027 -0.20131872 0.17511601 0.16677614 -0.05516816
-0.10460625 -0.25808741 0.03480547 -0.15743644 0.18638093
0.03129929 -0.11566092 -0.19303162 -0.09320168 -0.38787705]
[ 0.15855342 -0.02402754 -0.00301006 -0.09785019 -0.17732144
-0.22491602 -0.1349754 0.13963072 0.12761351 -0.02347015
-0.07433172 -0.19753812 0.02575886 -0.11248398 0.15490154
0.0149613 -0.10369941 -0.14653386 -0.08158803 -0.33888571]
[ 0.09213119 -0.00920132 -0.00480822 -0.06614824 -0.10940994
-0.13237866 -0.06281013 0.08518363 0.07377088 -0.00439786
-0.03767817 -0.1139195 0.01487621 -0.05797102 0.09557128
0.00161788 -0.06577087 -0.08184791 -0.05674727 -0.22975524]]]
从输出中可以看出以下几点
-
每个batch的hidden state和output_state都只包含一个tensor,在正向lstm中,h_fw就是output的最后一个tensor,在反向lstm中,h_bw就是output的第一个tensor。hidden_state和output_state是同样的机制,它们都表示最后的状态。
再看结果,可以看到output_bw和output_fw是完全相反的,如果只做一次reverse,应该是和原来相同(因为值全为1,所以reverse后还是和原来的输入相等)。
例二
直接上代码
import tensorflow as tf
import numpy as np
# 创建一个batch为2的三维数组
X = np.array([[[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]], [[3, 4, 5, 6, 7], [2, 3, 4, 5, 6], [1, 2, 3, 4, 5]]],
dtype=np.float64)
# 指定每个batch的真实长度,这是bidirectional_dynamic_rnn中的一个参数,如果不指定,默认为batch的最大长度
X_lengths = [3, 3]
cell = tf.nn.rnn_cell.LSTMCell(num_units=10, state_is_tuple=True)
outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell, cell_bw=cell, dtype=tf.float64,
sequence_length=X_lengths, inputs=X)
# bidirectional_dynamic_rnn输出两个元组,第一个元组为输出值元组,第二个为状态元组
output_fw, output_bw = outputs
states_fw, states_bw = states
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 状态元组中的每个元素也还是一个元组,包含两个元素c和h,c是hidden state,h是output state
c_fw, h_fw = states_fw
c_bw, h_bw = states_bw
print('***********print h_fw')
print(sess.run(h_fw))
print('***********print h_bw')
print(sess.run(h_bw))
print('***********print c_fw')
print(sess.run(c_fw))
print('***********print c_bw')
print(sess.run(c_bw))
print('***********print output_fw')
print(sess.run(output_fw))
print('***********print output_bw')
print(sess.run(output_bw))
输出结果
***********print h_fw
[[-0.90904258 -0.72417344 -0.00340704 0.47999107 -0.6172267 0.16119028
0.96314568 0.59865986 -0.23445085 0.14730728]
[-0.80015177 -0.65880448 -0.01965226 0.42119995 -0.43309781 0.18892933
0.90951544 0.43914896 -0.26462378 0.1727273 ]]
***********print h_bw
[[-0.80015177 -0.65880448 -0.01965226 0.42119995 -0.43309781 0.18892933
0.90951544 0.43914896 -0.26462378 0.1727273 ]
[-0.90904258 -0.72417344 -0.00340704 0.47999107 -0.6172267 0.16119028
0.96314568 0.59865986 -0.23445085 0.14730728]]
***********print c_fw
[[-2.00750908 -1.19615351 -0.28910374 0.99731872 -1.66288234 1.35553407
2.10854366 0.72069906 -0.27357135 0.9675592 ]
[-1.97285668 -1.16508798 -0.28735088 1.00169352 -1.40325954 1.3543562
1.94142731 0.52746365 -0.34427745 0.93367482]]
***********print c_bw
[[-1.97285668 -1.16508798 -0.28735088 1.00169352 -1.40325954 1.3543562
1.94142731 0.52746365 -0.34427745 0.93367482]
[-2.00750908 -1.19615351 -0.28910374 0.99731872 -1.66288234 1.35553407
2.10854366 0.72069906 -0.27357135 0.9675592 ]]
***********print output_fw
[[[-0.52647366 -0.41398232 -0.01570638 0.16715023 -0.25770299
0.11789769 0.60452984 0.19651944 -0.07048413 0.18267517]
[-0.8034695 -0.62766681 -0.00827474 0.34252566 -0.47038781
0.15523849 0.87603592 0.43703746 -0.16224657 0.18179719]
[-0.90904258 -0.72417344 -0.00340704 0.47999107 -0.6172267
0.16119028 0.96314568 0.59865986 -0.23445085 0.14730728]]
[[-0.64887064 -0.49435156 -0.00098317 0.22632825 -0.45637361
0.10613126 0.6985296 0.4481981 -0.04567536 0.16216273]
[-0.8229839 -0.64132454 -0.00468272 0.3630846 -0.49435403
0.15683851 0.89261432 0.51742868 -0.15192277 0.16823444]
[-0.80015177 -0.65880448 -0.01965226 0.42119995 -0.43309781
0.18892933 0.90951544 0.43914896 -0.26462378 0.1727273 ]]]
***********print output_bw
[[[-0.80015177 -0.65880448 -0.01965226 0.42119995 -0.43309781
0.18892933 0.90951544 0.43914896 -0.26462378 0.1727273 ]
[-0.8229839 -0.64132454 -0.00468272 0.3630846 -0.49435403
0.15683851 0.89261432 0.51742868 -0.15192277 0.16823444]
[-0.64887064 -0.49435156 -0.00098317 0.22632825 -0.45637361
0.10613126 0.6985296 0.4481981 -0.04567536 0.16216273]]
[[-0.90904258 -0.72417344 -0.00340704 0.47999107 -0.6172267
0.16119028 0.96314568 0.59865986 -0.23445085 0.14730728]
[-0.8034695 -0.62766681 -0.00827474 0.34252566 -0.47038781
0.15523849 0.87603592 0.43703746 -0.16224657 0.18179719]
[-0.52647366 -0.41398232 -0.01570638 0.16715023 -0.25770299
0.11789769 0.60452984 0.19651944 -0.07048413 0.18267517]]]
注意看正向和反向的输出
图解(字难看,图也乱,但是还能将就看。)
参考资料
[1] https://blog.csdn.net/u012436149/article/details/71080601
[2] https://blog.csdn.net/qq_41424519/article/details/82112904