这一部分先留着,等有时间再好好整理:由于 TensorFlow 更新了,GRU 的内部实现也随之改变,定制 cell 的部分等有时间再看。
- 使用标准的 LSTM 和 GRU cell 很容易,那么我们能不能定制自己的 RNN cell 呢?可以的。
-
先上图,有时间再翻译……
class GRUCell(tf.nn.rnn_cell.RNNCell):
    """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

    The new hidden state is returned both as the step output and as the
    next state, so ``state_size`` and ``output_size`` are both ``num_units``.
    """

    def __init__(self, num_units):
        """Store the width of the hidden state."""
        self._num_units = num_units

    @property
    def state_size(self):
        """Size of the recurrent state (``num_units``)."""
        return self._num_units

    @property
    def output_size(self):
        """Size of the per-step output (``num_units``)."""
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Run one GRU step.

        Args:
            inputs: input tensor for this step.
            state: previous hidden state.
            scope: optional variable scope; defaults to the class name.

        Returns:
            A pair ``(new_h, new_h)``: the same tensor serves as the output
            and as the next state.
        """
        with tf.variable_scope(scope or type(self).__name__):  # "GRUCell"
            with tf.variable_scope("Gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                ru = tf.nn.rnn_cell._linear([inputs, state],
                                            2 * self._num_units, True, 1.0)
                ru = tf.nn.sigmoid(ru)
                # NOTE: legacy argument order (axis, number of splits, value);
                # a single projection is cut in half to give both gates.
                r, u = tf.split(1, 2, ru)
            with tf.variable_scope("Candidate"):
                # The reset gate scales the previous state before it is
                # mixed with the input to form the candidate.
                c = tf.nn.tanh(tf.nn.rnn_cell._linear([inputs, r * state],
                                                      self._num_units, True))
            # The update gate interpolates between the old state and the
            # candidate.
            new_h = u * state + (1 - u) * c
        return new_h, new_h
def __init__(self, num_units, num_weights):
    """Store both sizing arguments on the instance.

    Standalone constructor snippet; it takes the same two arguments and
    sets the same two attributes as the constructor of ``CustomCell``.
    """
    self._num_units = num_units
    self._num_weights = num_weights
class CustomCell(tf.nn.rnn_cell.RNNCell):
    """GRU-style cell whose candidate input projection is a learned mixture.

    The reset and update gates are computed as in a GRU (cf.
    http://arxiv.org/abs/1406.1078). For the candidate, the input is
    multiplied by ``num_weights`` separate matrices and the results are
    combined with softmax weights computed from the current input and state.
    """

    def __init__(self, num_units, num_weights):
        """Store the hidden-state width and the number of input matrices."""
        self._num_units = num_units
        self._num_weights = num_weights

    @property
    def state_size(self):
        """Size of the recurrent state (``num_units``)."""
        return self._num_units

    @property
    def output_size(self):
        """Size of the per-step output (``num_units``)."""
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell.

        Args:
            inputs: input tensor for this step; its second dimension must be
                statically known because it sizes the ``Ws`` variable.
            state: previous hidden state.
            scope: optional variable scope; defaults to the class name.

        Returns:
            A pair ``(new_h, new_h)``: the same tensor serves as the output
            and as the next state.
        """
        with tf.variable_scope(scope or type(self).__name__):  # "CustomCell"
            with tf.variable_scope("Gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                ru = tf.nn.rnn_cell._linear([inputs, state],
                                            2 * self._num_units, True, 1.0)
                ru = tf.nn.sigmoid(ru)
                # NOTE: legacy argument order (axis, number of splits, value).
                r, u = tf.split(1, 2, ru)
            with tf.variable_scope("Candidate"):
                # Softmax mixing weights, one per input matrix, split into a
                # list so each can scale its own projection.
                lambdas = tf.nn.rnn_cell._linear([inputs, state],
                                                 self._num_weights, True)
                lambdas = tf.split(1, self._num_weights,
                                   tf.nn.softmax(lambdas))

                Ws = tf.get_variable(
                    "Ws",
                    shape=[self._num_weights,
                           inputs.get_shape()[1],
                           self._num_units])
                # Squeeze ONLY the leading split axis. A bare tf.squeeze
                # would also drop the input or unit dimension whenever that
                # dimension is 1, leaving a slice that is no longer a matrix.
                Ws = [tf.squeeze(W, [0])
                      for W in tf.split(0, self._num_weights, Ws)]

                # Weighted sum of the per-matrix input projections.
                Wx = tf.add_n([tf.matmul(inputs, W) * lam
                               for W, lam in zip(Ws, lambdas)])

                # scope="second" keeps this projection's variables apart from
                # those of the mixing-weight projection created above in the
                # same "Candidate" scope.
                c = tf.nn.tanh(Wx + tf.nn.rnn_cell._linear(
                    [r * state], self._num_units, True, scope="second"))
            # The update gate interpolates between the old state and the
            # candidate.
            new_h = u * state + (1 - u) * c
        return new_h, new_h
- 我们的定制 RNN cell 已经编写完成,下面把它和 GRU cell 比较一下(使用 num_steps = 30,因为它的效果比 num_steps = 200 更好):
def build_multilayer_graph_with_custom_cell(
        cell_type=None,
        num_weights_for_custom_cell=5,
        state_size=100,
        num_classes=vocab_size,
        batch_size=32,
        num_steps=200,
        num_layers=3,
        learning_rate=1e-4):
    """Build a multi-layer RNN training graph for the chosen cell type.

    Args:
        cell_type: 'Custom', 'GRU' or 'LSTM'; any other value (including the
            default ``None``) selects ``BasicRNNCell``.
        num_weights_for_custom_cell: ``num_weights`` passed to ``CustomCell``;
            only used when ``cell_type == 'Custom'``.
        state_size: width of the embeddings and of every RNN layer.
        num_classes: number of embedding rows and of output logits.
        batch_size: first dimension of both placeholders.
        num_steps: second dimension of both placeholders.
        num_layers: number of stacked layers.
        learning_rate: learning rate handed to the Adam optimizer.

    Returns:
        dict with the keys ``x``, ``y``, ``init_state``, ``final_state``,
        ``total_loss`` and ``train_step``.
    """
    # Helper defined elsewhere in the file; presumably clears the default
    # graph before a new one is built -- confirm against its definition.
    reset_graph()

    x = tf.placeholder(tf.int32, [batch_size, num_steps],
                       name='input_placeholder')
    y = tf.placeholder(tf.int32, [batch_size, num_steps],
                       name='labels_placeholder')

    embeddings = tf.get_variable('embedding_matrix',
                                 [num_classes, state_size])
    rnn_inputs = tf.nn.embedding_lookup(embeddings, x)

    if cell_type == 'Custom':
        cell = CustomCell(state_size, num_weights_for_custom_cell)
    elif cell_type == 'GRU':
        cell = tf.nn.rnn_cell.GRUCell(state_size)
    elif cell_type == 'LSTM':
        cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
    else:
        cell = tf.nn.rnn_cell.BasicRNNCell(state_size)

    # The same cell object is repeated for every layer.
    # NOTE(review): newer TensorFlow releases expect a distinct cell object
    # per layer -- confirm against the installed version.
    if cell_type == 'LSTM':
        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers,
                                           state_is_tuple=True)
    else:
        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers)

    init_state = cell.zero_state(batch_size, tf.float32)
    rnn_outputs, final_state = tf.nn.dynamic_rnn(
        cell, rnn_inputs, initial_state=init_state)

    with tf.variable_scope('softmax'):
        W = tf.get_variable('W', [state_size, num_classes])
        b = tf.get_variable('b', [num_classes],
                            initializer=tf.constant_initializer(0.0))

    # Flatten batch and time so one matmul produces the logits for every
    # position, lined up with the flattened labels.
    rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
    y_reshaped = tf.reshape(y, [-1])

    logits = tf.matmul(rnn_outputs, W) + b

    total_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=y_reshaped))
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

    return dict(
        x=x,
        y=y,
        init_state=init_state,
        final_state=final_state,
        total_loss=total_loss,
        train_step=train_step
    )