class LSTM(object):
    """LSTM layer using dynamic_rnn.

    Stacks `num_layers` `BasicLSTMCell`s inside a `MultiRNNCell` and unrolls
    them with `tf.nn.dynamic_rnn`. The first call creates the variables under
    the variable scope `name`; later calls on the same instance reuse them.

    Exposes variables in `trainable_weights` property.
    """

    def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
        """Store the layer configuration; no graph ops are created here.

        Args:
            cell_size: Number of units passed to each `BasicLSTMCell`.
            num_layers: How many cells are stacked in the `MultiRNNCell`.
            keep_prob: Keep probability handed to `tf.nn.dropout` on the
                output; dropout is skipped unless this is below 1.
            name: Variable scope name used when the layer is called.
        """
        self.cell_size = cell_size
        self.num_layers = num_layers
        self.keep_prob = keep_prob
        # None until the first call; flipped to True afterwards so that
        # subsequent calls share the variables created by the first one.
        self.reuse = None
        # Populated on the first call with the variables of the scope.
        self.trainable_weights = None
        self.name = name

    def __call__(self, x, initial_state, seq_length):
        """Run the stacked LSTM over `x`.

        Args:
            x: Input sequence tensor forwarded to `tf.nn.dynamic_rnn`
                (presumably (batch_size, timesteps, input_dim), matching the
                output shape noted below -- confirm against the caller).
            initial_state: Initial state forwarded to `tf.nn.dynamic_rnn`.
            seq_length: Forwarded as `sequence_length` to
                `tf.nn.dynamic_rnn`.

        Returns:
            A `(lstm_out, next_state)` tuple as produced by
            `tf.nn.dynamic_rnn`, with dropout applied to `lstm_out` when
            `keep_prob < 1`.
        """
        with tf.variable_scope(self.name, reuse=self.reuse) as vs:
            # `range` instead of `xrange`: identical here and valid on both
            # Python 2 and Python 3 without relying on a compat import.
            cell = tf.contrib.rnn.MultiRNNCell([
                tf.contrib.rnn.BasicLSTMCell(
                    self.cell_size,
                    forget_bias=0.0,
                    reuse=tf.get_variable_scope().reuse)
                for _ in range(self.num_layers)
            ])

            lstm_out, next_state = tf.nn.dynamic_rnn(
                cell, x, initial_state=initial_state,
                sequence_length=seq_length)

            # shape(lstm_out) = (batch_size, timesteps, cell_size)

            # NOTE: there is no training/inference switch; dropout is applied
            # on every call whenever keep_prob < 1.
            if self.keep_prob < 1.:
                lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)

            # Only the first call creates variables, so capture them once.
            if self.reuse is None:
                self.trainable_weights = vs.global_variables()

        self.reuse = True

        return lstm_out, next_state
class Actionselect(object):
    """Dense projection head with `action_class` output units.

    Wraps a single `K.layers.Dense` layer and applies it when the instance
    is called.
    """

    def __init__(self,
                 action_class,
                 **kwargs):
        """Create the underlying dense layer.

        Args:
            action_class: Number of units given to `K.layers.Dense`.
            **kwargs: Accepted for call-site compatibility but currently
                ignored; they are NOT forwarded to the dense layer.
        """
        self.multiclass_dense_layer = K.layers.Dense(action_class)

    def __call__(self, input_data):
        """Apply the dense layer to `input_data` and return its output."""
        return self.multiclass_dense_layer(input_data)
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。