def export_state_tuples(state_tuples, name):
    """Append the ``c`` and ``h`` members of each state tuple to a collection.

    Args:
        state_tuples: Iterable of objects exposing ``c`` and ``h`` attributes.
        name: Collection key handed to ``tf.add_to_collection``.

    Entries are added in iteration order, ``c`` before ``h`` for every tuple,
    so the collection ends up holding alternating c/h values.
    """
    for lstm_state in state_tuples:
        # Keep the c-then-h order: import_state_tuples reads the entries back
        # by position (even index -> c, odd index -> h).
        for field in ("c", "h"):
            tf.add_to_collection(name, getattr(lstm_state, field))
5
6
def import_state_tuples(state_tuples, name, num_replicas):
    """Rebuild LSTM state tuples from the entries of a TensorFlow collection.

    The collection is read as a flat list of alternating ``c``/``h`` entries
    (the layout ``export_state_tuples`` writes): entries ``2*i`` and
    ``2*i + 1`` form the i-th restored tuple.

    Args:
        state_tuples: Sized container; only its length is used.
        name: Key of the collection read with ``tf.get_collection_ref``.
        num_replicas: Multiplier applied to ``len(state_tuples)`` to obtain
            the total number of tuples to restore.

    Returns:
        A tuple of ``tf.contrib.rnn.LSTMStateTuple`` objects in collection
        order; an empty tuple when there is nothing to restore.

    Raises:
        IndexError: If the collection holds fewer than two entries per
            requested tuple.
    """
    indices = range(len(state_tuples) * num_replicas)
    if not indices:
        # Nothing to restore: return without touching the collection.
        return ()

    # Look the collection up once instead of twice per restored tuple.
    collection = tf.get_collection_ref(name)

    # Fail up front with a descriptive message rather than a bare
    # "list index out of range" part-way through the rebuild.
    required = 2 * len(indices)
    if len(collection) < required:
        raise IndexError(
            "collection %r holds %d entries, but %d are needed to restore "
            "%d state tuples"
            % (name, len(collection), required, len(indices))
        )

    make_state = tf.contrib.rnn.LSTMStateTuple
    return tuple(
        make_state(collection[2 * i], collection[2 * i + 1]) for i in indices
    )
Figure 1: Flowchart of the export_state_tuples function.
Figure 2: Flowchart of the import_state_tuples function.