main.py
Fig.1 main.py 的流程
Fig.1 為 main.py 的流程。第1~8行 import module;
第1行 from model import MODEL
是從 model 資料夾中,將 MODEL.py 檔案匯入程式中;第2行 from utility import utility
是從 utility 資料夾中,將 utility.py 檔案匯入程式中;第3~8行匯入 Python 的 modules。
第11~34行 initialize parameters (BATCHSIZE, X_DIMENSION...等等);第36~49行 print parameters;
第51, 52行開啟 "output_data/question/qa.train.json" 檔案,以 data_file
變數存放檔案內容,將data_file
以 json_load()
[1] 的方式讀出,存在 train_q
變數中,train_q
的 type 是 dictionary
第53, 54行開啟 "output_data/question/qa.val.json" 檔案,以 data_file
變數存放檔案內容,將data_file
以 json_load()
[1] 的方式讀出,存在 val_q
變數中,val_q
的 type 是 dictionary。
註[1] : json 是javascript Object Notation的縮寫,是一種文字格式。裡面儲存是以dictionary type方式進行,所以json_load()會output dictionary。
第58, 59行傳入 BATCHSIZE, X_DIMENSION... 等13個參數,產生一個 Model() 類別的物件,存在變數 acm_net 中,接著執行 acm_net.initializa()
。
第60行傳入 train_q, val_q... 等9個參數給 utility/utility.py 檔案中的 train( ) 執行,將會訓練 model 並產生一個 train 與 valid 時的正確率,存在變數 accuracy 中 。accuracy 的 type 是 dictionary。
第61~65行分別取出 accuracy 變數中,train 與 valid 時的值,分別存在變數 accuracy_train 與 accuracy_val 中,並 print out。
第56, 67, 68行,計時程式執行時間後並 print out。
1 from model import MODEL
2 from utility import utility
3 import time
4 import json
5 import numpy as np
6 import random
7 import os
8 import sys
9
10
11 EPOCH = 50
12 X_DIMENSION = 300
13 LEARNING_RATE = 0.001
14 BATCHSIZE = 20
15 DROPOUT = 0.8
16 choice = 5
17 max_plot_num = 101
18 max_len = [100,50,50]
19 parameter_size = {
20 'cnn_filterSize':{'filter1':[1,3,5],'filter2':[1,3,5]},
21 'cnn_filterNum':128,
22 'cnn_filterNum2':128,
23 'dnn_hiddenUnits':128
24 }
25
26 CNN_FILTER_SIZE = parameter_size['cnn_filterSize']['filter1']
27
28 CNN_FILTER_SIZE2 = parameter_size['cnn_filterSize']['filter2']
29 CNN_FILTER_NUM = int(parameter_size['cnn_filterNum'])
30 CNN_FILTER_NUM2 = int(parameter_size['cnn_filterNum2'])
31 DNN_WIDTH = int(parameter_size['dnn_hiddenUnits'])
32
33 plotFilePath = "output_data/plot/"
34 parameterPath = 'parameter/'
35
36 print('###############################################################')
37 print('Epoch :',EPOCH)
38 print('X Dimension :', X_DIMENSION)
39 print('Learning Rate :', LEARNING_RATE)
40 print('Drop Out :', DROPOUT)
41 print('Plot Sentence Lum :', max_plot_num)
42 print('Plot Sentence Len :', max_len[0])
43 print('Q Sentence Len :', max_len[1])
44 print('Ans Sentence Len :', max_len[2])
45 print('CNN Filter Size :', CNN_FILTER_SIZE)
46 print('CNN Filter Size2 :', CNN_FILTER_SIZE2)
47 print('CNN Filter Num :', CNN_FILTER_NUM)
48 print('DNN Output Size :', DNN_WIDTH,'->',1)
49 print('###############################################################','\n')
50
51 with open('output_data/question/qa.train.json') as data_file:
52 train_q = json.load(data_file)
53 with open('output_data/question/qa.val.json') as data_file:
54 val_q = json.load(data_file)
55
56 start = time.time()
57
58 acm_net = MODEL.MODEL(BATCHSIZE,X_DIMENSION,DNN_WIDTH,CNN_FILTER_SIZE,CNN_FILTER_SIZE2,CNN_FILTER_NUM,CNN_FILTER_NUM2,LEARNING_RATE,DROPOUT,choice,max_plot_num,max_len,parameterPath)
59 acm_net.initialize()
60 accuracy = utility.train(train_q,val_q,plotFilePath,acm_net,EPOCH,LEARNING_RATE,BATCHSIZE,DROPOUT,choice,max_len)
61 accuracy_train = np.array(accuracy['train'])
62 accuracy_val = np.array(accuracy['val'])
63
64 print('val: ', accuracy_val, np.amax(accuracy_val))
65 print('train: ',accuracy_train, np.amax(accuracy_train))
66
67 eval_time = time.time()-start
68 print('use time: ',eval_time)