7
7
import yaml
8
8
9
9
from .logger import logger
10
+ import collections .abc
10
11
11
12
13
+ def update_dict (d : dict , u : dict ):
14
+ for k , v in u .items ():
15
+ if isinstance (v , collections .abc .Mapping ):
16
+ d [k ] = update_dict (d .get (k , {}), v ) # type: ignore
17
+ else :
18
+ d [k ] = v
19
+ return d
20
+
12
21
def getabsolutepath (path : Path ) -> Path :
13
22
"""
14
23
Get absolute path from relative path
@@ -67,29 +76,6 @@ def expand_index(ind_info: Union[str, int]):
67
76
return ret
68
77
69
78
70
- def set_default (setting : dict ) -> None :
71
- """
72
- Set default values (in place) for some fields in setting
73
- """
74
- if "multiprocessing" not in setting ["general" ]:
75
- setting ["general" ]["multiprocessing" ] = - 1
76
- if "valid_dist" not in setting ["map" ]:
77
- setting ["map" ]["valid_dist" ] = 5
78
- if "threshold" not in setting ["probe_profile" ]:
79
- setting ["probe_profile" ]["threshold" ] = 0.001
80
- if "env_dist" not in setting ["probe_profile" ]:
81
- setting ["probe_profile" ]["resenv" ]["env_dist" ] = 4
82
-
83
- if "dt" not in setting ["exprorer_msmd" ]["general" ]:
84
- setting ["exprorer_msmd" ]["general" ]["dt" ] = 0.002
85
- if "temperature" not in setting ["exprorer_msmd" ]["general" ]:
86
- setting ["exprorer_msmd" ]["general" ]["temperature" ] = 300
87
- if "pressure" not in setting ["exprorer_msmd" ]["general" ]:
88
- setting ["exprorer_msmd" ]["general" ]["pressure" ] = 1.0
89
- if "num_process_per_gpu" not in setting ["general" ]:
90
- setting ["general" ]["num_process_per_gpu" ] = 1
91
-
92
-
93
79
def ensure_compatibility_v1_1 (setting : dict ):
94
80
"""
95
81
Ensure compatibility with exprorer_msmd v1.1
@@ -110,6 +96,38 @@ def ensure_compatibility_v1_1(setting: dict):
110
96
111
97
112
98
def parse_yaml (yamlpath : Path ) -> dict :
99
+ setting : dict = {
100
+ "general" : {
101
+ "workdir" : Path ("" ),
102
+ "multiprocessing" : - 1 ,
103
+ "num_process_per_gpu" : 1 ,
104
+ },
105
+ "input" : {
106
+ "protein" : {
107
+ "pdb" : Path ("" ),
108
+ },
109
+ "probe" : {
110
+ "cid" : "" ,
111
+ },
112
+ },
113
+ "exprorer_msmd" : {
114
+ "general" : {
115
+ "dt" : 0.002 ,
116
+ "temperature" : 300 ,
117
+ "pressure" : 1.0 ,
118
+ },
119
+ },
120
+ "map" : {
121
+ "snapshot" : "" ,
122
+ "valid_dist" : 5.0 ,
123
+ },
124
+ "probe_profile" : {
125
+ "threshold" : 0.001 ,
126
+ "resenv" : {
127
+ "env_dist" : 4.0 ,
128
+ },
129
+ },
130
+ }
113
131
YAML_PATH = getabsolutepath (yamlpath )
114
132
YAML_DIR_PATH = YAML_PATH .parent
115
133
if not YAML_PATH .exists ():
@@ -119,10 +137,11 @@ def parse_yaml(yamlpath: Path) -> dict:
119
137
if not os .path .splitext (YAML_PATH )[1 ][1 :] == "yaml" :
120
138
raise ValueError ("YAML file must have .yaml extension: %s" % YAML_PATH )
121
139
with YAML_PATH .open () as fin :
122
- setting : dict = yaml .safe_load (fin ) # type: ignore
140
+ yaml_dict : dict = yaml .safe_load (fin ) # type: ignore
141
+ if yaml_dict is not None :
142
+ update_dict (setting , yaml_dict )
123
143
124
144
ensure_compatibility_v1_1 (setting )
125
- set_default (setting )
126
145
127
146
if "mol2" not in setting ["input" ]["probe" ] or setting ["input" ]["probe" ]["mol2" ] is None :
128
147
setting ["input" ]["probe" ]["mol2" ] = setting ["input" ]["probe" ]["cid" ] + ".mol2"
@@ -169,9 +188,8 @@ def parse_yaml(yamlpath: Path) -> dict:
169
188
)
170
189
setting ["input" ]["probe" ]["mol2" ] = Path (setting ["input" ]["probe" ]["mol2" ])
171
190
172
- if setting ["input" ]["protein" ]["ssbond" ] is None :
191
+ if "ssbond" not in setting [ "input" ][ "protein" ] or setting ["input" ]["protein" ]["ssbond" ] is None :
173
192
setting ["input" ]["protein" ]["ssbond" ] = []
174
193
175
194
setting ["general" ]["yaml" ] = YAML_PATH
176
-
177
195
return setting
0 commit comments