main.py

Fig.1 main.py 的流程

Fig.1 為 main.py 的流程。第1~8行 import module;

第1行 from model import MODEL 是從 model 資料夾中,將 MODEL.py 檔案匯入程式中;第2行 from utility import utility 是從 utility 資料夾中,將 utility.py 檔案匯入程式中;第3~8行匯入 Python 的 modules。

第11~34行 initialize parameters (BATCHSIZE, X_DIMENSION...等等);第36~49行 print parameters;

第51, 52行開啟 "output_data/question/qa.train.json" 檔案,以 data_file 變數存放檔案內容,將data_filejson_load()[1] 的方式讀出,存在 train_q 變數中,train_q的 type 是 dictionary

第53, 54行開啟 "output_data/question/qa.val.json" 檔案,以 data_file 變數存放檔案內容,將data_filejson_load()[1] 的方式讀出,存在 val_q 變數中,val_q的 type 是 dictionary。

註[1] : json 是javascript Object Notation的縮寫,是一種文字格式。裡面儲存是以dictionary type方式進行,所以json_load()會output dictionary。

第58, 59行傳入 BATCHSIZE, X_DIMENSION... 等13個參數,產生一個 Model() 類別的物件,存在變數 acm_net 中,接著執行 acm_net.initializa()

第60行傳入 train_q, val_q... 等9個參數給 utility/utility.py 檔案中的 train( ) 執行,將會訓練 model 並產生一個 train 與 valid 時的正確率,存在變數 accuracy 中 。accuracy 的 type 是 dictionary。

第61~65行分別取出 accuracy 變數中,train 與 valid 時的值,分別存在變數 accuracy_train 與 accuracy_val 中,並 print out。

第56, 67, 68行,計時程式執行時間後並 print out。


1   from model import MODEL
2   from utility import utility
3   import time
4   import json
5   import numpy as np
6   import random
7   import os
8   import sys
9
10
11  EPOCH = 50
12  X_DIMENSION = 300
13  LEARNING_RATE = 0.001
14  BATCHSIZE = 20
15  DROPOUT = 0.8
16  choice = 5
17  max_plot_num = 101
18  max_len = [100,50,50]
19  parameter_size = {
20      'cnn_filterSize':{'filter1':[1,3,5],'filter2':[1,3,5]},    
21      'cnn_filterNum':128,
22    'cnn_filterNum2':128,
23    'dnn_hiddenUnits':128
24  }
25
26  CNN_FILTER_SIZE = parameter_size['cnn_filterSize']['filter1']
27
28  CNN_FILTER_SIZE2 = parameter_size['cnn_filterSize']['filter2']
29  CNN_FILTER_NUM = int(parameter_size['cnn_filterNum'])
30  CNN_FILTER_NUM2 = int(parameter_size['cnn_filterNum2'])
31  DNN_WIDTH = int(parameter_size['dnn_hiddenUnits'])
32
33  plotFilePath = "output_data/plot/"
34  parameterPath = 'parameter/'
35
36  print('###############################################################')
37  print('Epoch             :',EPOCH)
38  print('X Dimension       :', X_DIMENSION)
39  print('Learning Rate     :', LEARNING_RATE)
40  print('Drop Out          :', DROPOUT)
41  print('Plot Sentence Lum :', max_plot_num)
42  print('Plot Sentence Len :', max_len[0])
43  print('Q Sentence Len    :', max_len[1])
44  print('Ans Sentence Len  :', max_len[2])
45  print('CNN Filter Size   :', CNN_FILTER_SIZE)
46  print('CNN Filter Size2   :', CNN_FILTER_SIZE2)
47  print('CNN Filter Num    :', CNN_FILTER_NUM)
48  print('DNN Output Size   :', DNN_WIDTH,'->',1)
49  print('###############################################################','\n')
50
51  with open('output_data/question/qa.train.json') as data_file:    
52      train_q = json.load(data_file)
53  with open('output_data/question/qa.val.json') as data_file:
54    val_q = json.load(data_file)
55
56  start = time.time()
57
58  acm_net = MODEL.MODEL(BATCHSIZE,X_DIMENSION,DNN_WIDTH,CNN_FILTER_SIZE,CNN_FILTER_SIZE2,CNN_FILTER_NUM,CNN_FILTER_NUM2,LEARNING_RATE,DROPOUT,choice,max_plot_num,max_len,parameterPath)
59  acm_net.initialize()
60  accuracy = utility.train(train_q,val_q,plotFilePath,acm_net,EPOCH,LEARNING_RATE,BATCHSIZE,DROPOUT,choice,max_len)
61  accuracy_train = np.array(accuracy['train'])
62  accuracy_val = np.array(accuracy['val'])
63
64  print('val: ', accuracy_val, np.amax(accuracy_val))
65  print('train: ',accuracy_train, np.amax(accuracy_train))
66
67  eval_time = time.time()-start
68  print('use time: ',eval_time)

results matching ""

    No results matching ""