-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
133 lines (104 loc) · 4.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import argparse
import config
from common.utils import str2bool
#from ReinforcementLearning.dqnAgent import DQNAgent
from ImitationLearning.ImitationModel import ImitationModel
# Model names accepted by Main; each is handled by ImitationModel.
_imitationLearningList = ['Basic', 'Multimodal', 'Codevilla18', 'Codevilla19',
                          'Kim2017', 'Experimental', 'ExpBranch', 'Approach']


class Main():
    """Facade that creates the selected model and dispatches execution modes.

    Every mode method delegates to ``self.model``, the ``ImitationModel``
    instance created in ``__init__``.
    """

    def __init__(self, init, setting):
        """Seed the run and create the model named by ``setting.model``.

        Args:
            init: Runtime configuration; its ``set_seed()`` is called before
                the model is constructed.
            setting: Experiment settings; ``setting.model`` must be one of
                ``_imitationLearningList``.

        Raises:
            NameError: If ``setting.model`` is not a supported model name
                (kept as ``NameError`` for compatibility with existing callers).
        """
        self.model = None
        # Seed before the model is constructed.
        init.set_seed()
        if setting.model in _imitationLearningList:
            self.model = ImitationModel(init, setting)
        else:
            raise NameError(
                f"ERROR 404: Model not found: {setting.model!r}. "
                f"Valid models: {', '.join(_imitationLearningList)}"
            )

    def train(self):
        """Build the pipeline, create a fresh model and call ``execute()``."""
        self.model.build()
        self.model.create_model()
        self.model.execute()

    def to_continue(self, name, epoch=None):
        """Restore a saved run via ``model.to_continue`` and call ``execute()``.

        Args:
            name: Identifier of the saved run (``--name`` on the CLI).
            epoch: Optional epoch, passed through unchanged (may be ``None``).
        """
        self.model.build()
        self.model.to_continue(name, epoch)
        self.model.execute()

    def load(self, path):
        """Build the pipeline, create the model, then load it from ``path``."""
        self.model.build()
        self.model.create_model()
        self.model.load(path)

    def study(self, name, epoch):
        """Same sequence as ``to_continue`` but with ``study=True`` at every step.

        Args:
            name: Identifier of the saved run.
            epoch: Epoch passed through to ``model.to_continue``.
        """
        self.model.build(study=True)
        self.model.to_continue(name, epoch, study=True)
        self.model.execute(study=True)

    def plot(self, name):
        """Build the pipeline and delegate to ``model.plot(name)``."""
        self.model.build()
        self.model.plot(name)

    def play(self):
        """Placeholder for the ``play`` mode; currently does nothing."""
        pass
if __name__ == "__main__":
    # Command-line interface. Each option, when given, overrides the value
    # loaded from the JSON init/setting files.
    parser = argparse.ArgumentParser(description="SelfDriving")
    # Paths
    parser.add_argument("--trainpath", type=str, help="Data for train")
    parser.add_argument("--validpath", type=str, help="Data for validation")
    parser.add_argument("--savedpath", type=str, help="Path for saved data")
    parser.add_argument("--modelpath", type=str, help="Model file path")
    parser.add_argument("--init", type=str, help="Init json path")
    parser.add_argument("--setting", type=str, help="Setting json path")
    # Training
    parser.add_argument("--epoch", type=int, help="Number of epoch for train")
    parser.add_argument("--batch_size", type=int, help="Batch size for train")
    parser.add_argument("--model", type=str,
                        help="End-to-End model: " + ", ".join(_imitationLearningList))
    parser.add_argument("--workers", type=int, help="Number of CPU workers")
    parser.add_argument("--optimizer", type=str,
                        help="Optimizer method: Adam, RAdam, Ranger, DiffGrad, DiffRGrad, DeepMemory.")
    parser.add_argument("--scheduler", type=str2bool, help="Use scheduler (boolean)")
    parser.add_argument("--name", type=str,
                        help="Code model (required for continue and plot modes).")
    # `choices` rejects unknown modes at parse time instead of silently
    # falling through after the model has been built.
    parser.add_argument("--mode", default="train", type=str,
                        choices=["train", "study", "play", "continue", "plot"],
                        help="Select execution mode: train, study, play, continue, plot")
    args = parser.parse_args()

    # These modes operate on a named run. Fail fast, before settings are
    # loaded or a model is built. (The original code constructed a NameError
    # here without raising it, so the check never fired.)
    if args.mode in ("continue", "plot") and args.name is None:
        parser.error("Undefined model. Please define with --name")

    # Settings: create, optionally load from JSON, then apply CLI overrides.
    init = config.Init()
    setting = config.Setting()
    if args.init is not None:
        init.load(args.init)
    if args.setting is not None:
        setting.load(args.setting)
    if args.workers is not None:
        init.num_workers = args.workers
    # Model
    if args.model is not None:
        setting.model_(args.model)
    # Paths
    if args.trainpath is not None:
        setting.general.trainPath = args.trainpath
    if args.validpath is not None:
        setting.general.validPath = args.validpath
    if args.savedpath is not None:
        setting.general.savedPath = args.savedpath
    # Train
    if args.epoch is not None:
        setting.train.n_epoch = args.epoch
    if args.batch_size is not None:
        setting.train.batch_size = args.batch_size
    if args.optimizer is not None:
        setting.train.optimizer.type = args.optimizer
    if args.scheduler is not None:
        setting.train.scheduler.available = args.scheduler
    # Print settings
    setting.print()

    # Modes that start from a previously saved run.
    # NOTE(review): `study` also calls `to_continue` but is not flagged
    # here — confirm that is intended.
    if args.mode in ("continue", "plot"):
        init.is_loadedModel = True

    # Main program
    main = Main(init, setting)

    # Load model
    if args.modelpath is not None:
        # NOTE(review): `train` mode calls build()/create_model() again after
        # this load — confirm the loaded weights are not discarded.
        main.load(args.modelpath)

    # Execute mode (argparse `choices` guarantees one of these matches).
    if args.mode == "train":
        main.train()
    elif args.mode == "study":
        main.study(args.name, args.epoch)
    elif args.mode == "play":
        main.play()
    elif args.mode == "continue":
        main.to_continue(args.name, args.epoch)
    elif args.mode == "plot":
        main.plot(args.name)