def _build_rnn_graph_lstm
def _build_rnn_graph_lstm(self, inputs, config, is_training):
    """Build the inference graph by unrolling a stack of canonical LSTM cells.

    One cell per layer is obtained from ``self._get_lstm_cell``. While
    training with ``config.keep_prob < 1`` each layer's output is wrapped in
    dropout. The zero state of the stacked cell is stored on
    ``self._initial_state`` as a side effect.

    Args:
        inputs: Tensor sliced as ``inputs[:, t, :]`` for each step ``t``, so
            axis 1 is the time axis. Presumably shaped
            (batch_size, num_steps, features) -- TODO confirm against caller.
        config: Object providing ``keep_prob``, ``num_layers``,
            ``batch_size`` and ``hidden_size``.
        is_training: Truthy when building the training graph; gates dropout
            and is forwarded to ``self._get_lstm_cell``.

    Returns:
        A pair ``(output, state)``: the per-step cell outputs concatenated
        along axis 1 and reshaped to ``[-1, config.hidden_size]``, and the
        cell state after the last step.
    """
    # Assemble the per-layer cells one at a time, optionally with dropout on
    # each layer's output.
    layer_cells = []
    for _ in range(config.num_layers):
        layer_cell = self._get_lstm_cell(config, is_training)
        if is_training and config.keep_prob < 1:
            layer_cell = tf.contrib.rnn.DropoutWrapper(
                layer_cell, output_keep_prob=config.keep_prob)
        layer_cells.append(layer_cell)
    stacked_cell = tf.contrib.rnn.MultiRNNCell(layer_cells, state_is_tuple=True)

    self._initial_state = stacked_cell.zero_state(
        config.batch_size, data_type())
    state = self._initial_state

    # Manually unroll the network over the time axis.
    step_outputs = []
    with tf.variable_scope("RNN"):
        for step in range(self.num_steps):
            if step != 0:
                # Variables are created on the first step; every later step
                # must reuse them instead of creating new ones.
                tf.get_variable_scope().reuse_variables()
            step_output, state = stacked_cell(inputs[:, step, :], state)
            step_outputs.append(step_output)

    joined = tf.concat(step_outputs, 1)
    output = tf.reshape(joined, [-1, config.hidden_size])
    return output, state