Skip to content

Commit

Permalink
Merge pull request #39 from keisuke-yanagisawa/yaml_default
Browse files Browse the repository at this point in the history
Refactor Default Value Handling in Configuration Files
  • Loading branch information
keisuke-yanagisawa authored Aug 16, 2024
2 parents c4aaba5 + 7fb388d commit 33cb3c2
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 33 deletions.
2 changes: 1 addition & 1 deletion script/alignresenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def selector(a: Atom):

sup = SuperImposer()
tmppdb = Path(tempfile.mkstemp(suffix=".pdb")[1])
with uPDB.PDBIOhelper(str(tmppdb)) as pdbio:
with uPDB.PDBIOhelper(tmppdb) as pdbio:
for model in tqdm(struct, desc="[align res. env.]", disable=not verbose):

# print(struct, i)
Expand Down
4 changes: 2 additions & 2 deletions script/test_generate_msmd_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def test_create_frcmod(self):

def test_invalid_atomtype(self):
with self.assertRaises(ValueError):
_create_frcmod(self.mol2file, "INVALID") # type: ignore
_create_frcmod(self.mol2file, Path("INVALID")) # type: ignore

def test_mol2file_does_not_exist(self):
with self.assertRaises(FileNotFoundError):
_create_frcmod("INVALID", self.atomtype) # type: ignore
_create_frcmod(Path("INVALID"), self.atomtype) # type: ignore

def test_invalid_mol2file(self):
with self.assertRaises(ValueError):
Expand Down
4 changes: 2 additions & 2 deletions script/utilities/Bio/PDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class PDBIOhelper:
これは1つずつモデルを保存していきます。
"""

def __init__(self, path: str):
def __init__(self, path: Path):
self.path = expandpath(path)
self.open()

Expand Down Expand Up @@ -449,7 +449,7 @@ def concatenate_structures(structs: List[Structure]) -> Structure:

with tempfile.NamedTemporaryFile("w") as f:

out_helper = PDBIOhelper(f.name)
out_helper = PDBIOhelper(Path(f.name))
for struct in structs:
out_helper.save(struct)
out_helper.close()
Expand Down
2 changes: 1 addition & 1 deletion script/utilities/Bio/test_PDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_pdb_io_helper(self):
reader = PDB.MultiModelPDBReader(str(self.pdbfile))
models = [model for model in reader]
tmp_output_pdb = Path(tempfile.mkstemp(suffix=".pdb")[1])
writer = PDB.PDBIOhelper(str(tmp_output_pdb))
writer = PDB.PDBIOhelper(tmp_output_pdb)
for model in models:
writer.save(model)
writer.close() # This is important, otherwise the unittest will hang up
Expand Down
72 changes: 45 additions & 27 deletions script/utilities/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,17 @@
import yaml

from .logger import logger
import collections.abc


def update_dict(d: dict, u: dict):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = update_dict(d.get(k, {}), v) # type: ignore
else:
d[k] = v
return d

def getabsolutepath(path: Path) -> Path:
"""
Get absolute path from relative path
Expand Down Expand Up @@ -67,29 +76,6 @@ def expand_index(ind_info: Union[str, int]):
return ret


def set_default(setting: dict) -> None:
"""
Set default values (in place) for some fields in setting
"""
if "multiprocessing" not in setting["general"]:
setting["general"]["multiprocessing"] = -1
if "valid_dist" not in setting["map"]:
setting["map"]["valid_dist"] = 5
if "threshold" not in setting["probe_profile"]:
setting["probe_profile"]["threshold"] = 0.001
if "env_dist" not in setting["probe_profile"]:
setting["probe_profile"]["resenv"]["env_dist"] = 4

if "dt" not in setting["exprorer_msmd"]["general"]:
setting["exprorer_msmd"]["general"]["dt"] = 0.002
if "temperature" not in setting["exprorer_msmd"]["general"]:
setting["exprorer_msmd"]["general"]["temperature"] = 300
if "pressure" not in setting["exprorer_msmd"]["general"]:
setting["exprorer_msmd"]["general"]["pressure"] = 1.0
if "num_process_per_gpu" not in setting["general"]:
setting["general"]["num_process_per_gpu"] = 1


def ensure_compatibility_v1_1(setting: dict):
"""
Ensure compatibility with exprorer_msmd v1.1
Expand All @@ -110,6 +96,38 @@ def ensure_compatibility_v1_1(setting: dict):


def parse_yaml(yamlpath: Path) -> dict:
setting: dict = {
"general": {
"workdir": Path(""),
"multiprocessing": -1,
"num_process_per_gpu": 1,
},
"input": {
"protein": {
"pdb": Path(""),
},
"probe": {
"cid": "",
},
},
"exprorer_msmd": {
"general": {
"dt": 0.002,
"temperature": 300,
"pressure": 1.0,
},
},
"map": {
"snapshot": "",
"valid_dist": 5.0,
},
"probe_profile": {
"threshold": 0.001,
"resenv": {
"env_dist": 4.0,
},
},
}
YAML_PATH = getabsolutepath(yamlpath)
YAML_DIR_PATH = YAML_PATH.parent
if not YAML_PATH.exists():
Expand All @@ -119,10 +137,11 @@ def parse_yaml(yamlpath: Path) -> dict:
if not os.path.splitext(YAML_PATH)[1][1:] == "yaml":
raise ValueError("YAML file must have .yaml extension: %s" % YAML_PATH)
with YAML_PATH.open() as fin:
setting: dict = yaml.safe_load(fin) # type: ignore
yaml_dict: dict = yaml.safe_load(fin) # type: ignore
if yaml_dict is not None:
update_dict(setting, yaml_dict)

ensure_compatibility_v1_1(setting)
set_default(setting)

if "mol2" not in setting["input"]["probe"] or setting["input"]["probe"]["mol2"] is None:
setting["input"]["probe"]["mol2"] = setting["input"]["probe"]["cid"] + ".mol2"
Expand Down Expand Up @@ -169,9 +188,8 @@ def parse_yaml(yamlpath: Path) -> dict:
)
setting["input"]["probe"]["mol2"] = Path(setting["input"]["probe"]["mol2"])

if setting["input"]["protein"]["ssbond"] is None:
if "ssbond" not in setting["input"]["protein"] or setting["input"]["protein"]["ssbond"] is None:
setting["input"]["protein"]["ssbond"] = []

setting["general"]["yaml"] = YAML_PATH

return setting

0 comments on commit 33cb3c2

Please sign in to comment.