def ptb_producer(raw_data, batch_size, num_steps, name=None):
  """Build input/target batch Tensors over the raw PTB data.

  The flat sequence in raw_data is trimmed to a multiple of batch_size,
  reshaped into batch_size rows, and then read out one window of num_steps
  columns at a time.

  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
    name: the name of this operation (optional).

  Returns:
    A pair of Tensors, each shaped [batch_size, num_steps]. The second
    Tensor covers the same rows as the first, but its window starts one
    column later.

  Raises:
    tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
  """
  with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
    raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

    total_len = tf.size(raw_data)
    row_len = total_len // batch_size
    # Drop the tail that does not fill a complete row, then lay the data out
    # as one row per batch element.
    trimmed = raw_data[0 : batch_size * row_len]
    rows = tf.reshape(trimmed, [batch_size, row_len])

    # One column is held back so the target window, which starts one column
    # later than the input window, never runs past the end of a row.
    num_windows = (row_len - 1) // num_steps
    positive_check = tf.assert_positive(
        num_windows,
        message="epoch_size == 0, decrease batch_size or num_steps")
    with tf.control_dependencies([positive_check]):
      num_windows = tf.identity(num_windows, name="epoch_size")

    # Index of the window to read, taken from an unshuffled range producer.
    window = tf.train.range_input_producer(num_windows, shuffle=False).dequeue()

    # Input window: columns [window * num_steps, (window + 1) * num_steps).
    x_begin = [0, window * num_steps]
    x_end = [batch_size, (window + 1) * num_steps]
    x = tf.strided_slice(rows, x_begin, x_end)
    x.set_shape([batch_size, num_steps])

    # Target window: the same span moved one column further along each row.
    y_begin = [0, window * num_steps + 1]
    y_end = [batch_size, (window + 1) * num_steps + 1]
    y = tf.strided_slice(rows, y_begin, y_end)
    y.set_shape([batch_size, num_steps])
    return x, y

Figure 1: Flow chart of the ptb_producer.py file

results matching ""

    No results matching ""