def _get_lstm_cell(self, config, is_training):
    """Build one LSTM cell of the implementation selected by ``config.rnn_mode``.

    Args:
        config: Model configuration. Only ``config.rnn_mode`` (compared
            against the ``BASIC`` / ``BLOCK`` constants) and
            ``config.hidden_size`` (passed as the cell's number of units)
            are read.
        is_training: Whether the cell is built for training. For the
            ``BASIC`` cell it controls variable reuse (``reuse=not is_training``).

    Returns:
        A ``tf.contrib.rnn.BasicLSTMCell`` when ``config.rnn_mode == BASIC``,
        or a ``tf.contrib.rnn.LSTMBlockCell`` when ``config.rnn_mode == BLOCK``.
        Both are created with ``forget_bias=0.0``.

    Raises:
        ValueError: If ``config.rnn_mode`` matches neither ``BASIC`` nor ``BLOCK``.
    """
    # NOTE(review): `self` is not used by this method.
    if config.rnn_mode == BASIC:
        # Reuse existing variables for any non-training cell.
        # NOTE(review): presumably eval/test models share the weights created
        # by the training model — confirm against the caller.
        return tf.contrib.rnn.BasicLSTMCell(
            config.hidden_size,
            forget_bias=0.0,
            state_is_tuple=True,
            reuse=not is_training,
        )
    if config.rnn_mode == BLOCK:
        # NOTE(review): unlike the BASIC branch, no `reuse` is passed here;
        # presumably sharing is handled by an enclosing variable scope — confirm.
        return tf.contrib.rnn.LSTMBlockCell(config.hidden_size, forget_bias=0.0)
    # Format via a 1-tuple so a tuple-valued rnn_mode cannot turn the
    # intended ValueError into a TypeError from %-formatting.
    raise ValueError("rnn_mode %s not supported" % (config.rnn_mode,))