def _build_rnn_graph_cudnn(self, inputs, config, is_training):
    """Build the inference graph using a CUDNN LSTM cell.

    Args:
      inputs: Rank-3 tensor whose first two axes are swapped before it is
        fed to the cell. NOTE(review): presumably batch-major on entry and
        time-major for CUDNN; confirm against the caller.
      config: Object providing ``num_layers``, ``hidden_size``,
        ``keep_prob`` and ``init_scale``.
      is_training: When true, dropout of ``1 - config.keep_prob`` is
        configured on the cell. Otherwise dropout is 0. The flag is also
        forwarded to the cell call.

    Side effects:
      Sets ``self._cell`` (the CudnnLSTM object).
      Sets ``self._rnn_params`` (the variable named "lstm_params").
      Sets ``self._initial_state`` (a 1-tuple holding a zero-filled
      ``LSTMStateTuple``).

    Returns:
      A pair ``(outputs, state)``:
        ``outputs``: the cell outputs with their first two axes swapped
          back, reshaped to ``[-1, config.hidden_size]``.
        ``state``: a 1-tuple holding the final ``LSTMStateTuple(h, c)``.
    """
    # Dropout is only switched on while training.
    if is_training:
        dropout_rate = 1 - config.keep_prob
    else:
        dropout_rate = 0

    # Swap axes 0 and 1 before handing the sequence to the CUDNN cell.
    time_major = tf.transpose(inputs, [1, 0, 2])

    self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
        num_layers=config.num_layers,
        num_units=config.hidden_size,
        input_size=config.hidden_size,
        dropout=dropout_rate,
    )

    # All cell parameters live in one 1-D variable of length params_size().
    # That size presumably comes back as a tensor rather than a static int,
    # which is why shape validation is disabled on the variable below.
    param_count = self._cell.params_size()
    scale = config.init_scale
    init_values = tf.random_uniform([param_count], -scale, scale)
    self._rnn_params = tf.get_variable(
        "lstm_params",
        initializer=init_values,
        validate_shape=False,
    )

    # Zero initial cell/hidden state, one slice per layer.
    state_shape = [config.num_layers, self.batch_size, config.hidden_size]
    init_c = tf.zeros(state_shape, tf.float32)
    init_h = tf.zeros(state_shape, tf.float32)
    self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=init_h, c=init_c),)

    raw_outputs, final_h, final_c = self._cell(
        time_major, init_h, init_c, self._rnn_params, is_training
    )

    # Undo the axis swap, then flatten everything but the feature axis.
    batch_major = tf.transpose(raw_outputs, [1, 0, 2])
    flat_outputs = tf.reshape(batch_major, [-1, config.hidden_size])

    final_state = (tf.contrib.rnn.LSTMStateTuple(h=final_h, c=final_c),)
    return flat_outputs, final_state