def main(_)
Fig.1 main( ) 函數的前半段
Fig.1 是 main( ) 函數的前半段流程。
如果沒有 FLAGS.data_path 變數,則程式出現 ValueError( ) 後結束。否則繼續執行。
"Get gpus list" 將取得系統所有的 GPU,存在 gpus 變數中
如果參數 FLAGS.num_gpus 大於系統的 GPU 個數 gpus,則程式結束,(換句話說,ValueError( ) ) 。否則,繼續執行。
再來呼叫 Class reader 中的 ptb_raw_data( ) 函式,輸出 raw_data。
raw_data 再轉成 train_data、valid_data、test_data。
呼叫 get_config( ) 函式,將產生 config 與 eval_config 變數。
config 未來會在 Training 與 Validate 時使用
eval_config 未來會在 Testing 時使用
Fig.2 main( ) 函數的中段 (在程式第21~55行),執行 with tf.Graph().as_defalut()
Fig.2 main( ) 函數的中段流程
輸入 -config.init_scale 與 config.init_scale 進 tf.random_uniform_initializer(),產生 initializer。initializer 在接下來各個 tf.name_scope() 中會使用到。
分別在 "tf.name_scope() = Train", "tf.name_scope() = Valid", "tf.name_scope() = Test" 中產生 m, mvalid, mtest。
m, mvalid 與 mtest 都是 class PTBModel object。
Models 變數將 m, mvalid 與 mtest 包成一個 dictionary
code block is related to "utility.py"
Fig.3 main( ) 函數的後段(在程式第57~80行),執行 with tf.Graph().as_defalut()
Fig.3 main( ) 函數的段流程。
輸入 FLAGS.savepath 進 tf.train.Supervisor( ),得到 sv
輸入 soft_placement 進 tf.ConfigProto( ),得到 config_proto
for loop 迴圈 train and valid。
i = 0 時,輸入 session, m, m.train_op, True 進 run_epoch,得到 train_perplexity,且輸入 session, mvalid 進 ,得到 valid_perplexity,接著判斷 i 是否等於 max_epoch。i <= max_epoch 則 i = i + 1,繼續執行迴圈,直到 i = max_epoch 後停止
接著輸入session, mtest 進 run_epoch( ),得到 test_perplexity
1 def main(_):
2 if not FLAGS.data_path:
3 raise ValueError("Must set --data_path to PTB data directory")
4 gpus = [
5 x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"
6 ]
7 if FLAGS.num_gpus > len(gpus):
8 raise ValueError(
9 "Your machine has only %d gpus "
10 "which is less than the requested --num_gpus=%d."
11 % (len(gpus), FLAGS.num_gpus))
12
13 raw_data = reader.ptb_raw_data(FLAGS.data_path)
14 train_data, valid_data, test_data, _ = raw_data
15
16 config = get_config()
17 eval_config = get_config()
18 eval_config.batch_size = 1
19 eval_config.num_steps = 1
20
21 with tf.Graph().as_default():
22 initializer = tf.random_uniform_initializer(-config.init_scale,
23 config.init_scale)
24
25 with tf.name_scope("Train"):
26 train_input = PTBInput(config=config, data=train_data, name="TrainInput")
27 with tf.variable_scope("Model", reuse=None, initializer=initializer):
28 m = PTBModel(is_training=True, config=config, input_=train_input)
29 tf.summary.scalar("Training Loss", m.cost)
30 tf.summary.scalar("Learning Rate", m.lr)
31
32 with tf.name_scope("Valid"):
33 valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
34 with tf.variable_scope("Model", reuse=True, initializer=initializer):
35 mvalid = PTBModel(is_training=False, config=config, input_=valid_input)
36 tf.summary.scalar("Validation Loss", mvalid.cost)
37
38 with tf.name_scope("Test"):
39 test_input = PTBInput(
40 config=eval_config, data=test_data, name="TestInput")
41 with tf.variable_scope("Model", reuse=True, initializer=initializer):
42 mtest = PTBModel(is_training=False, config=eval_config,
43 input_=test_input)
44
45 models = {"Train": m, "Valid": mvalid, "Test": mtest}
46 for name, model in models.items():
47 model.export_ops(name)
48 metagraph = tf.train.export_meta_graph()
49 if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1:
50 raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
51 "below 1.1.0")
52 soft_placement = False
53 if FLAGS.num_gpus > 1:
54 soft_placement = True
55 util.auto_parallel(metagraph, m)
56
57 with tf.Graph().as_default():
58 tf.train.import_meta_graph(metagraph)
59 for model in models.values():
30 model.import_ops()
61 sv = tf.train.Supervisor(logdir=FLAGS.save_path)
62 config_proto = tf.ConfigProto(allow_soft_placement=soft_placement)
63 with sv.managed_session(config=config_proto) as session:
64 for i in range(config.max_max_epoch):
65 lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
66 m.assign_lr(session, config.learning_rate * lr_decay)
67
68 print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
69 train_perplexity = run_epoch(session, m, eval_op=m.train_op,
70 verbose=True)
71 print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
72 valid_perplexity = run_epoch(session, mvalid)
73 print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
74
75 test_perplexity = run_epoch(session, mtest)
76 print("Test Perplexity: %.3f" % test_perplexity)
77
78 if FLAGS.save_path:
79 print("Saving model to %s." % FLAGS.save_path)
80 sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)
81