def main(_)

Fig.1 main( ) 函數的前半段

Fig.1 是 main( ) 函數的前半段流程。

如果沒有 FLAGS.data_path 變數,則程式出現 ValueError( ) 後結束。否則繼續執行。

"Get gpus list" 將取得系統所有的 GPU,存在 gpus 變數中

如果參數 FLAGS.num_gpus 大於系統的 GPU 個數 gpus,則程式結束,(換句話說,ValueError( ) ) 。否則,繼續執行。

再來呼叫 Class reader 中的 ptb_raw_data( ) 函式,輸出 raw_data。

raw_data 再轉成 train_data、valid_data、test_data。

呼叫 get_config( ) 函式,將產生 config 與 eval_config 變數。

config 未來會在 Training 與 Validate 時使用

eval_config 未來會在 Testing 時使用

Fig.2 main( ) 函數的中段 (在程式第21~55行),執行 with tf.Graph().as_defalut()

Fig.2 main( ) 函數的中段流程

輸入 -config.init_scale 與 config.init_scale 進 tf.random_uniform_initializer(),產生 initializer。initializer 在接下來各個 tf.name_scope() 中會使用到。

分別在 "tf.name_scope() = Train", "tf.name_scope() = Valid", "tf.name_scope() = Test" 中產生 m, mvalid, mtest。

m, mvalid 與 mtest 都是 class PTBModel object。

Models 變數將 m, mvalid 與 mtest 包成一個 dictionary

code block is related to "utility.py"

Fig.3 main( ) 函數的後段(在程式第57~80行),執行 with tf.Graph().as_defalut()

Fig.3 main( ) 函數的段流程。

輸入 FLAGS.savepath 進 tf.train.Supervisor( ),得到 sv

輸入 soft_placement 進 tf.ConfigProto( ),得到 config_proto

for loop 迴圈 train and valid。

i = 0 時,輸入 session, m, m.train_op, True 進 run_epoch,得到 train_perplexity,且輸入 session, mvalid 進 ,得到 valid_perplexity,接著判斷 i 是否等於 max_epoch。i <= max_epoch 則 i = i + 1,繼續執行迴圈,直到 i = max_epoch 後停止

接著輸入session, mtest 進 run_epoch( ),得到 test_perplexity

1   def main(_):
2     if not FLAGS.data_path:
3         raise ValueError("Must set --data_path to PTB data directory")
4     gpus = [
5         x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"
6     ]
7     if FLAGS.num_gpus > len(gpus):
8       raise ValueError(
9           "Your machine has only %d gpus "
10          "which is less than the requested --num_gpus=%d."
11          % (len(gpus), FLAGS.num_gpus))
12
13    raw_data = reader.ptb_raw_data(FLAGS.data_path)
14    train_data, valid_data, test_data, _ = raw_data
15
16    config = get_config()
17    eval_config = get_config()
18    eval_config.batch_size = 1
19    eval_config.num_steps = 1
20
21    with tf.Graph().as_default():
22      initializer = tf.random_uniform_initializer(-config.init_scale,
23                                                config.init_scale)
24
25      with tf.name_scope("Train"):
26        train_input = PTBInput(config=config, data=train_data, name="TrainInput")
27        with tf.variable_scope("Model", reuse=None, initializer=initializer):
28          m = PTBModel(is_training=True, config=config, input_=train_input)
29        tf.summary.scalar("Training Loss", m.cost)
30        tf.summary.scalar("Learning Rate", m.lr)
31
32      with tf.name_scope("Valid"):
33        valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
34        with tf.variable_scope("Model", reuse=True, initializer=initializer):
35          mvalid = PTBModel(is_training=False, config=config, input_=valid_input)
36        tf.summary.scalar("Validation Loss", mvalid.cost)
37
38      with tf.name_scope("Test"):
39        test_input = PTBInput(
40            config=eval_config, data=test_data, name="TestInput")
41        with tf.variable_scope("Model", reuse=True, initializer=initializer):
42          mtest = PTBModel(is_training=False, config=eval_config,
43                         input_=test_input)
44
45      models = {"Train": m, "Valid": mvalid, "Test": mtest}
46      for name, model in models.items():
47        model.export_ops(name)
48      metagraph = tf.train.export_meta_graph()
49      if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1:
50        raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
51                       "below 1.1.0")
52      soft_placement = False
53      if FLAGS.num_gpus > 1:
54        soft_placement = True
55        util.auto_parallel(metagraph, m)
56
57    with tf.Graph().as_default():
58      tf.train.import_meta_graph(metagraph)
59      for model in models.values():
30        model.import_ops()
61      sv = tf.train.Supervisor(logdir=FLAGS.save_path)
62      config_proto = tf.ConfigProto(allow_soft_placement=soft_placement)
63      with sv.managed_session(config=config_proto) as session:
64        for i in range(config.max_max_epoch):
65          lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
66          m.assign_lr(session, config.learning_rate * lr_decay)
67
68          print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
69          train_perplexity = run_epoch(session, m, eval_op=m.train_op,
70                                     verbose=True)
71          print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
72          valid_perplexity = run_epoch(session, mvalid)
73          print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
74
75        test_perplexity = run_epoch(session, mtest)
76        print("Test Perplexity: %.3f" % test_perplexity)
77
78        if FLAGS.save_path:
79          print("Saving model to %s." % FLAGS.save_path)
80          sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)
81

results matching ""

    No results matching ""