From 566efdad57fd2d4d9e5780e4bd1c0398d316ef0b Mon Sep 17 00:00:00 2001 From: Keisuke Yanagisawa Date: Thu, 15 Aug 2024 23:53:01 +0000 Subject: [PATCH 1/2] [WIP] default parameter --- script/utilities/util.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/script/utilities/util.py b/script/utilities/util.py index 66be7f6..d97a9da 100644 --- a/script/utilities/util.py +++ b/script/utilities/util.py @@ -110,6 +110,28 @@ def ensure_compatibility_v1_1(setting: dict): def parse_yaml(yamlpath: Path) -> dict: + setting: dict = { + "general": { + "workdir": Path(""), + }, + "input": { + "protein": { + "pdb": Path(""), + }, + "probe": { + "cid": "", + }, + }, + "exprorer_msmd": { + "general": {}, + }, + "map": { + "snapshot": "", + }, + "probe_profile": { + "resenv": {}, + }, + } YAML_PATH = getabsolutepath(yamlpath) YAML_DIR_PATH = YAML_PATH.parent if not YAML_PATH.exists(): @@ -119,7 +141,9 @@ def parse_yaml(yamlpath: Path) -> dict: if not os.path.splitext(YAML_PATH)[1][1:] == "yaml": raise ValueError("YAML file must have .yaml extension: %s" % YAML_PATH) with YAML_PATH.open() as fin: - setting: dict = yaml.safe_load(fin) # type: ignore + yaml_dict: dict = yaml.safe_load(fin) # type: ignore + if yaml_dict is not None: + setting.update(yaml_dict) ensure_compatibility_v1_1(setting) set_default(setting) @@ -169,7 +193,7 @@ def parse_yaml(yamlpath: Path) -> dict: ) setting["input"]["probe"]["mol2"] = Path(setting["input"]["probe"]["mol2"]) - if setting["input"]["protein"]["ssbond"] is None: + if "ssbond" not in setting["input"]["protein"] or setting["input"]["protein"]["ssbond"] is None: setting["input"]["protein"]["ssbond"] = [] setting["general"]["yaml"] = YAML_PATH From 7fb388de439bd5c59ed92dae1582ea72ae515fe8 Mon Sep 17 00:00:00 2001 From: Keisuke Yanagisawa Date: Fri, 16 Aug 2024 00:41:14 +0000 Subject: [PATCH 2/2] default parameter --- script/alignresenv.py | 2 +- script/test_generate_msmd_system.py | 4 +-- script/utilities/Bio/PDB.py | 4 +-- script/utilities/Bio/test_PDB.py | 2 +- script/utilities/util.py | 50 +++++++++++++---------------- 5 files changed, 28 insertions(+), 34 deletions(-) diff --git a/script/alignresenv.py b/script/alignresenv.py index 59cf2a6..198aa00 100644 --- a/script/alignresenv.py +++ b/script/alignresenv.py @@ -56,7 +56,7 @@ def selector(a: Atom): sup = SuperImposer() tmppdb = Path(tempfile.mkstemp(suffix=".pdb")[1]) - with uPDB.PDBIOhelper(str(tmppdb)) as pdbio: + with uPDB.PDBIOhelper(tmppdb) as pdbio: for model in tqdm(struct, desc="[align res. env.]", disable=not verbose): # print(struct, i) diff --git a/script/test_generate_msmd_system.py b/script/test_generate_msmd_system.py index 50b3f26..6cdc3b5 100644 --- a/script/test_generate_msmd_system.py +++ b/script/test_generate_msmd_system.py @@ -63,11 +63,11 @@ def test_create_frcmod(self): def test_invalid_atomtype(self): with self.assertRaises(ValueError): - _create_frcmod(self.mol2file, "INVALID") # type: ignore + _create_frcmod(self.mol2file, Path("INVALID")) # type: ignore def test_mol2file_does_not_exist(self): with self.assertRaises(FileNotFoundError): - _create_frcmod("INVALID", self.atomtype) # type: ignore + _create_frcmod(Path("INVALID"), self.atomtype) # type: ignore def test_invalid_mol2file(self): with self.assertRaises(ValueError): diff --git a/script/utilities/Bio/PDB.py b/script/utilities/Bio/PDB.py index 28dedad..300a379 100644 --- a/script/utilities/Bio/PDB.py +++ b/script/utilities/Bio/PDB.py @@ -137,7 +137,7 @@ class PDBIOhelper: これは1つずつモデルを保存していきます。 """ - def __init__(self, path: str): + def __init__(self, path: Path): self.path = expandpath(path) self.open() @@ -449,7 +449,7 @@ def concatenate_structures(structs: List[Structure]) -> Structure: with tempfile.NamedTemporaryFile("w") as f: - out_helper = PDBIOhelper(f.name) + out_helper = PDBIOhelper(Path(f.name)) for struct in structs: out_helper.save(struct) out_helper.close() diff --git a/script/utilities/Bio/test_PDB.py b/script/utilities/Bio/test_PDB.py index cce913d..1cee8dc 100644 --- a/script/utilities/Bio/test_PDB.py +++ b/script/utilities/Bio/test_PDB.py @@ -30,7 +30,7 @@ def test_pdb_io_helper(self): reader = PDB.MultiModelPDBReader(str(self.pdbfile)) models = [model for model in reader] tmp_output_pdb = Path(tempfile.mkstemp(suffix=".pdb")[1]) - writer = PDB.PDBIOhelper(str(tmp_output_pdb)) + writer = PDB.PDBIOhelper(tmp_output_pdb) for model in models: writer.save(model) writer.close() # This is important, otherwise the unittest will hang up diff --git a/script/utilities/util.py b/script/utilities/util.py index d97a9da..20b2b67 100644 --- a/script/utilities/util.py +++ b/script/utilities/util.py @@ -7,8 +7,17 @@ import yaml from .logger import logger +import collections.abc +def update_dict(d: dict, u: dict): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = update_dict(d.get(k, {}), v) # type: ignore + else: + d[k] = v + return d + def getabsolutepath(path: Path) -> Path: """ Get absolute path from relative path @@ -67,29 +76,6 @@ def expand_index(ind_info: Union[str, int]): return ret -def set_default(setting: dict) -> None: - """ - Set default values (in place) for some fields in setting - """ - if "multiprocessing" not in setting["general"]: - setting["general"]["multiprocessing"] = -1 - if "valid_dist" not in setting["map"]: - setting["map"]["valid_dist"] = 5 - if "threshold" not in setting["probe_profile"]: - setting["probe_profile"]["threshold"] = 0.001 - if "env_dist" not in setting["probe_profile"]: - setting["probe_profile"]["resenv"]["env_dist"] = 4 - - if "dt" not in setting["exprorer_msmd"]["general"]: - setting["exprorer_msmd"]["general"]["dt"] = 0.002 - if "temperature" not in setting["exprorer_msmd"]["general"]: - setting["exprorer_msmd"]["general"]["temperature"] = 300 - if "pressure" not in setting["exprorer_msmd"]["general"]: - setting["exprorer_msmd"]["general"]["pressure"] = 1.0 - if "num_process_per_gpu" not in setting["general"]: - setting["general"]["num_process_per_gpu"] = 1 - - def ensure_compatibility_v1_1(setting: dict): """ Ensure compatibility with exprorer_msmd v1.1 @@ -113,6 +99,8 @@ def parse_yaml(yamlpath: Path) -> dict: setting: dict = { "general": { "workdir": Path(""), + "multiprocessing": -1, + "num_process_per_gpu": 1, }, "input": { "protein": { @@ -123,13 +111,21 @@ def parse_yaml(yamlpath: Path) -> dict: }, }, "exprorer_msmd": { - "general": {}, + "general": { + "dt": 0.002, + "temperature": 300, + "pressure": 1.0, + }, }, "map": { "snapshot": "", + "valid_dist": 5.0, }, "probe_profile": { - "resenv": {}, + "threshold": 0.001, + "resenv": { + "env_dist": 4.0, + }, }, } YAML_PATH = getabsolutepath(yamlpath) @@ -143,10 +139,9 @@ def parse_yaml(yamlpath: Path) -> dict: with YAML_PATH.open() as fin: yaml_dict: dict = yaml.safe_load(fin) # type: ignore if yaml_dict is not None: - setting.update(yaml_dict) + update_dict(setting, yaml_dict) ensure_compatibility_v1_1(setting) - set_default(setting) if "mol2" not in setting["input"]["probe"] or setting["input"]["probe"]["mol2"] is None: setting["input"]["probe"]["mol2"] = setting["input"]["probe"]["cid"] + ".mol2" @@ -197,5 +192,4 @@ def parse_yaml(yamlpath: Path) -> dict: setting["input"]["protein"]["ssbond"] = [] setting["general"]["yaml"] = YAML_PATH - return setting