-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate_config_parameters.py
64 lines (52 loc) · 2.84 KB
/
update_config_parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Updates model parameters in a model-specific config file from the global config file (the model config file is overwritten in place).
Usage: python update_config_parameters.py GLOBAL_CONFIG_FILE MODEL_CONFIG_FILE MODEL_NAME
GLOBAL_CONFIG_FILE: Path to global config file (e.g., epiearth_darwin.yaml)
MODEL_CONFIG_FILE: Path to model-specific config file (e.g., mosq_config.yaml)
MODEL_NAME: Name of the model (e.g., MOSQUITO_POP_MODEL or EPI_MODEL)
"""
import sys
import yaml
from pathlib import Path
def update_model_config_params_based_on_global_config_params(param_key, global_params, model_params):
    """
    Copy ``global_params[param_key]`` into ``model_params[param_key]``.

    Non-dict values overwrite the model value, but only when both values have
    exactly the same type. Dict values are merged recursively, key by key, so
    keys that exist only in the model config are kept unchanged.

    Arguments:
    1. param_key: Key to update; it must exist in both mappings.
    2. global_params: Mapping holding the new value for ``param_key``.
    3. model_params: Mapping to update; it is modified in place.

    Returns:
    ``model_params`` (the same object, after the update).

    Raises:
    KeyError: ``param_key`` (or a nested key) is missing from ``model_params``.
    TypeError: the global and model values have different types, or the
        global value is a dict while the model value is not.
    """
    # Explicit checks instead of ``assert`` so validation also runs under ``python -O``.
    if param_key not in model_params:
        raise KeyError('Parameter {!r} from the global config is missing in the model config.'.format(param_key))
    global_value = global_params[param_key]
    model_value = model_params[param_key]

    if isinstance(global_value, dict):
        if not isinstance(model_value, dict):
            raise TypeError(
                'Parameter {!r} is a mapping in the global config but a {} in the model config.'.format(
                    param_key, type(model_value).__name__))
        # Merge nested parameters one by one; model_value is updated in place,
        # so model_params sees the changes without reassignment.
        for sub_key in global_value:
            update_model_config_params_based_on_global_config_params(sub_key, global_value, model_value)
        return model_params

    # Exact type comparison (not isinstance) on purpose: e.g. a bool must not
    # replace an int, and an int must not replace a float.
    if type(global_value) is not type(model_value):
        raise TypeError(
            'Parameter {!r} has type {} in the global config but {} in the model config.'.format(
                param_key, type(global_value).__name__, type(model_value).__name__))
    model_params[param_key] = global_value
    return model_params
def update_config_file(global_config_file, model_config_file, model_name):
    """
    Update model parameters in a model-specific config file from the global config file.

    The parameters found under ``<MODEL_NAME>: CONFIG_FILE_PARAMETERS`` in the
    global config are merged into the model config (nested mappings key by key),
    and the model config file is then overwritten with the result.

    Arguments:
    1. GLOBAL_CONFIG_FILE: Path to global config file (e.g., epiearth_darwin.yaml)
    2. MODEL_CONFIG_FILE: Path to model-specific config file (e.g., mosq_config.yaml)
    3. MODEL_NAME: Name of the model (e.g., MOSQUITO_POP_MODEL or EPI_MODEL)

    Raises:
    ValueError: MODEL_NAME is not one of the allowed model names.
    KeyError: the global config has no ``MODEL_NAME.CONFIG_FILE_PARAMETERS``
        section, or it names a parameter missing from the model config file.
    TypeError: a global parameter's type differs from the model's.
    """
    # Validate the model name first so we fail before touching any file.
    allowed_model_names = ('MOSQUITO_POP_MODEL', 'EPI_MODEL')
    if model_name not in allowed_model_names:
        raise ValueError('ERROR!! Found unknown model name: {}. Following model names are allowed: MOSQUITO_POP_MODEL and EPI_MODEL.'.format(model_name))

    # Read global and model-specific config files. An empty YAML file loads as
    # None, so fall back to an empty mapping to get clear errors below.
    global_config = yaml.safe_load(Path(global_config_file).read_text(encoding='utf-8')) or {}
    model_params = yaml.safe_load(Path(model_config_file).read_text(encoding='utf-8')) or {}

    # Extract this model's parameters from the global config.
    try:
        global_params = global_config[model_name]['CONFIG_FILE_PARAMETERS']
    except KeyError as err:
        raise KeyError('Global config file {} has no {}.CONFIG_FILE_PARAMETERS section.'.format(
            global_config_file, model_name)) from err
    # An empty CONFIG_FILE_PARAMETERS section loads as None: nothing to update.
    global_params = global_params or {}

    for param_key in global_params:
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if param_key not in model_params:
            raise KeyError('Parameter {!r} from {} is missing in model config file {}.'.format(
                param_key, global_config_file, model_config_file))
        model_params = update_model_config_params_based_on_global_config_params(param_key, global_params, model_params)

    # Write updated model parameters back over the model config file.
    # NOTE: the file is regenerated from the parsed data, so comments and
    # original formatting in the model config file are not preserved.
    with open(model_config_file, 'w', encoding='utf-8') as file:
        yaml.dump(model_params, file)
if __name__ == '__main__':
    # Expect exactly three positional arguments; report usage and exit with a
    # non-zero status instead of crashing with an IndexError.
    if len(sys.argv) != 4:
        sys.exit('Usage: python update_config_parameters.py GLOBAL_CONFIG_FILE MODEL_CONFIG_FILE MODEL_NAME')
    update_config_file(sys.argv[1], sys.argv[2], sys.argv[3])