class LSTM(object):
"""LSTM layer using dynamic_rnn.
Exposes variables in `trainable_weights` property.
"""
def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
self.cell_size = cell_size
self.num_layers = num_layers
self.keep_prob = keep_prob
self.reuse = None
self.trainable_weights = None
self.name = name
print("You got one LSTM")
def __call__(self, x, initial_state, seq_length):
with tf.variable_scope(self.name, reuse=self.reuse) as vs:
cell = tf.contrib.rnn.MultiRNNCell([
tf.contrib.rnn.BasicLSTMCell(
self.cell_size,
forget_bias=0.0,
reuse=tf.get_variable_scope().reuse)
for _ in xrange(self.num_layers)
])
# shape(x) = (batch_size, num_timesteps, embedding_dim)
lstm_out, next_state = tf.nn.dynamic_rnn(
cell, x, initial_state=initial_state, sequence_length=seq_length)
# shape(lstm_out) = (batch_size, timesteps, cell_size)
if self.keep_prob < 1.:
lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)
if self.reuse is None:
self.trainable_weights = vs.global_variables()
self.reuse = True
return lstm_out, next_state
[tf]进行变量和层的重用
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 一、图形变换 保存一个状态、恢复到上一个保存状态 因为Canvas的图形变换 '不是'基于状态的,下一个变换会在上...