-
Notifications
You must be signed in to change notification settings - Fork 0
/
settings.py
63 lines (55 loc) · 2.59 KB
/
settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch

# Global random seed shared by every experiment configuration in this module.
seed = 97

# Project-wide configuration: result/output directories, data location,
# defaults, and the registries mapping string names to loss / optimizer
# classes (resolved elsewhere via the `get` helper).
settings = {
    'bohb_results_dir': "./BOHB_results/",
    'es_results_dir': './ES_results',
    'data_dir': './data',
    # name -> loss class; instantiated by the consumer, not here
    'loss_dict': {'cross_entropy': torch.nn.CrossEntropyLoss},
    # name -> optimizer class; instantiated by the consumer, not here
    'opti_dict': {
        'adam': torch.optim.Adam,
        'adad': torch.optim.Adadelta,
        'sgd': torch.optim.SGD,
    },
    'run_dir': './stored_models/',
    'init_channels': 8,
    'batch_size': 64,
    'seed': seed,
}
def get(variable):
    """Return the value stored under *variable* in the module-level ``settings``.

    Args:
        variable: key into the ``settings`` dict (e.g. ``'opti_dict'``).

    Raises:
        KeyError: if *variable* is not a configured setting. The missing key
            is included in the exception (the original raised a bare,
            message-less ``KeyError`` and performed a redundant double lookup).
    """
    # EAFP: a plain subscript raises KeyError(variable) on a miss, so the
    # explicit membership test and bare `raise KeyError` are unnecessary.
    return settings[variable]
# Hyper-parameter defaults for the DARTS architecture search, keyed by the
# option name the search code expects.
darts_args = {
    'data': './data',                # location of the data corpus
    'set': 'cifar10',                # which dataset to use
    'batch_size': 32,                # batch size
    'learning_rate': 0.1,            # initial learning rate
    'learning_rate_min': 0.001,      # minimum learning rate
    'momentum': 0.9,                 # SGD momentum
    'weight_decay': 3e-4,            # weight decay
    'report_freq': 50,               # report frequency
    'gpu': 'cuda:0',                 # gpu device id
    'epochs': 20,                    # number of training epochs
    'init_channels': 8,              # number of initial channels
    'layers': 4,                     # total number of layers
    'model_path': 'saved_models',    # path to save the model
    'cutout': False,                 # use cutout augmentation
    'cutout_length': 16,             # cutout length
    'drop_path_prob': 0.3,           # drop-path probability
    'save': 'EXP',                   # experiment name
    'seed': seed,                    # random seed (module-wide)
    'grad_clip': 5,                  # gradient clipping threshold
    'train_portion': 0.5,            # portion of data used for training
    'unrolled': False,               # use one-step unrolled validation loss
    'arch_learning_rate': 6e-4,      # learning rate for arch encoding
    'arch_weight_decay': 1e-3,       # weight decay for arch encoding
}
# Default arguments for the main training entry point (mirrors the former
# command-line options).
main_args = {
    'seed': seed,           # random seed (module-wide), int
    'batch_size': 64,       # batch size, int
    'data_dir': './data',   # directory where the data is stored (can be downloaded)
    'optimizer': 'adam',    # which optimizer to use; a key of settings['opti_dict']
    'exp_no': 5,            # experiment number
    'verbose': 'INFO',      # verbosity: 'INFO' or 'DEBUG'
}
def get_main_args():
    """Return the module-level dict of defaults for the main entry point."""
    return main_args
def get_darts_args():
    """Return the module-level dict of DARTS search hyper-parameters."""
    return darts_args