def _build_rnn_graph_cudnn(self, inputs, config, is_training)

def _build_rnn_graph_cudnn(self, inputs, config, is_training):
    """Build the recurrent part of the inference graph with a fused cuDNN LSTM.

    Args:
      inputs: Input tensor. The first two axes are swapped before it is fed
        to the cell. NOTE(review): presumably (batch, time, hidden_size);
        confirm against the caller.
      config: Model configuration. Reads `num_layers`, `hidden_size`,
        `keep_prob` and `init_scale`.
      is_training: When true, dropout of `1 - config.keep_prob` is applied
        and the flag is forwarded to the cell call. Otherwise dropout is 0.

    Side effects:
      Sets `self._cell`, `self._rnn_params` and `self._initial_state`.

    Returns:
      A pair `(outputs, state)`:
      - `outputs` is reshaped to `[-1, config.hidden_size]`.
      - `state` is a 1-tuple holding an `LSTMStateTuple` of the final `h`/`c`
        returned by the cell.
    """
    # CudnnLSTM consumes time-major input, so swap the leading two axes.
    time_major_inputs = tf.transpose(inputs, [1, 0, 2])

    if is_training:
        dropout_rate = 1 - config.keep_prob
    else:
        dropout_rate = 0
    self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
        num_layers=config.num_layers,
        num_units=config.hidden_size,
        input_size=config.hidden_size,
        dropout=dropout_rate,
    )

    # The cell's weights live in a single flat parameter buffer sized by
    # params_size(). validate_shape=False presumably because that size is not
    # a static shape (TODO confirm against the CudnnLSTM API in use).
    flat_param_count = self._cell.params_size()
    scale = config.init_scale
    self._rnn_params = tf.get_variable(
        "lstm_params",
        initializer=tf.random_uniform([flat_param_count], -scale, scale),
        validate_shape=False,
    )

    # Zero initial cell/hidden states, shaped [num_layers, batch, hidden].
    # c is created before h to keep the original op-creation order.
    state_shape = [config.num_layers, self.batch_size, config.hidden_size]
    zero_c = tf.zeros(state_shape, tf.float32)
    zero_h = tf.zeros(state_shape, tf.float32)
    self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=zero_h, c=zero_c),)

    time_major_outputs, final_h, final_c = self._cell(
        time_major_inputs, zero_h, zero_c, self._rnn_params, is_training)

    # Undo the axis swap, then flatten to one row per (batch, time) step.
    batch_major_outputs = tf.transpose(time_major_outputs, [1, 0, 2])
    flat_outputs = tf.reshape(batch_major_outputs, [-1, config.hidden_size])
    return flat_outputs, (tf.contrib.rnn.LSTMStateTuple(h=final_h, c=final_c),)

results matching ""

    No results matching ""