Skip to content

Commit 6f370bb

Browse files
Make command line args conversion a function (#225)
Co-authored-by: Elliott Kasoar <[email protected]>
1 parent 50173d8 commit 6f370bb

File tree

7 files changed

+77
-44
lines changed

7 files changed

+77
-44
lines changed

aiida_mlip/calculations/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from aiida_mlip.data.config import JanusConfigfile
1515
from aiida_mlip.data.model import ModelData
16+
from aiida_mlip.helpers.converters import kwarg_to_param
1617

1718

1819
def validate_inputs(
@@ -261,10 +262,7 @@ def prepare_for_submission(
261262
codeinfo = datastructures.CodeInfo()
262263

263264
# Initialize cmdline_params with a placeholder "calculation" command
264-
codeinfo.cmdline_params = ["calculation"]
265-
266-
for flag, value in cmd_line.items():
267-
codeinfo.cmdline_params += [f"--{flag}", str(value)]
265+
codeinfo.cmdline_params = ["calculation", *kwarg_to_param(cmd_line)]
268266

269267
# Node where the code is saved
270268
codeinfo.code_uuid = self.inputs.code.uuid

aiida_mlip/calculations/descriptors.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from aiida.orm import Bool
1010

1111
from aiida_mlip.calculations.singlepoint import Singlepoint
12+
from aiida_mlip.helpers.converters import kwarg_to_param
1213

1314

1415
class Descriptors(Singlepoint): # numpydoc ignore=PR01
@@ -89,20 +90,17 @@ def prepare_for_submission(
8990

9091
# Adding command line params for when we run janus
9192
# descriptors is overwriting the placeholder "calculation" from the base.py file
92-
codeinfo.cmdline_params[0] = "descriptors"
9393

9494
cmdline_options = {
9595
key.replace("_", "-"): getattr(self.inputs, key).value
9696
for key in ("invariants_only", "calc_per_element", "calc_per_atom")
9797
if key in self.inputs
9898
}
9999

100-
for flag, value in cmdline_options.items():
101-
if isinstance(value, bool):
102-
# Add boolean flags without value if True
103-
if value:
104-
codeinfo.cmdline_params.append(f"--{flag}")
105-
else:
106-
codeinfo.cmdline_params += [f"--{flag}", value]
100+
codeinfo.cmdline_params = [
101+
"descriptors",
102+
*codeinfo.cmdline_params[1:],
103+
*kwarg_to_param(cmdline_options),
104+
]
107105

108106
return calcinfo

aiida_mlip/calculations/geomopt.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from plumpy.utils import AttributesFrozendict
1919

2020
from aiida_mlip.calculations.singlepoint import Singlepoint
21+
from aiida_mlip.helpers.converters import kwarg_to_param
2122

2223

2324
class GeomOpt(Singlepoint): # numpydoc ignore=PR01
@@ -124,15 +125,11 @@ def prepare_for_submission(
124125

125126
# Adding command line params for when we run janus
126127
# 'geomopt' is overwriting the placeholder "calculation" from the base.py file
127-
codeinfo.cmdline_params[0] = "geomopt"
128-
129-
for flag, value in geom_opt_cmdline.items():
130-
if isinstance(value, bool):
131-
# Add boolean flags without value if True
132-
if value:
133-
codeinfo.cmdline_params.append(f"--{flag}")
134-
else:
135-
codeinfo.cmdline_params += [f"--{flag}", value]
128+
codeinfo.cmdline_params = [
129+
"geomopt",
130+
*codeinfo.cmdline_params[1:],
131+
*kwarg_to_param(geom_opt_cmdline),
132+
]
136133

137134
calcinfo.retrieve_list.append(minimize_kwargs["traj_kwargs"]["filename"])
138135

aiida_mlip/calculations/md.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from aiida.orm import Dict, SinglefileData, Str, StructureData, TrajectoryData
1010

1111
from aiida_mlip.calculations.base import BaseJanus
12+
from aiida_mlip.helpers.converters import kwarg_to_param
1213

1314

1415
class MD(BaseJanus): # numpydoc ignore=PR01
@@ -108,16 +109,13 @@ def prepare_for_submission(
108109
raise ValueError("'ensemble' not provided.")
109110

110111
# md is overwriting the placeholder "calculation" from the base.py file
111-
codeinfo.cmdline_params[0] = "md"
112-
113-
codeinfo.cmdline_params += ["--ensemble", ensemble]
114-
115-
for flag, value in md_dictionary.items():
116-
# Add boolean flags without value if True
117-
if isinstance(value, bool) and value:
118-
codeinfo.cmdline_params.append(f"--{flag}")
119-
else:
120-
codeinfo.cmdline_params += [f"--{flag}", value]
112+
codeinfo.cmdline_params = [
113+
"md",
114+
*codeinfo.cmdline_params[1:],
115+
"--ensemble",
116+
ensemble,
117+
*kwarg_to_param(md_dictionary),
118+
]
121119

122120
calcinfo.retrieve_list.append(md_dictionary["traj-file"])
123121
calcinfo.retrieve_list.append(md_dictionary["stats-file"])

aiida_mlip/calculations/singlepoint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,15 @@ def prepare_for_submission(
9797

9898
# Adding command line params for when we run janus
9999
# singlepoint is overwriting the placeholder "calculation" from the base.py file
100-
codeinfo.cmdline_params[0] = "singlepoint"
101100

102101
# The inputs are saved in the node, but we want their value as a string
103102
xyz_filename = (self.inputs.out).value
104-
codeinfo.cmdline_params += ["--out", xyz_filename]
103+
codeinfo.cmdline_params = [
104+
"singlepoint",
105+
*codeinfo.cmdline_params[1:],
106+
"--out",
107+
xyz_filename,
108+
]
105109

106110
if "properties" in self.inputs:
107111
properties = self.inputs.properties.value

aiida_mlip/calculations/train.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from aiida_mlip.data.config import JanusConfigfile
1515
from aiida_mlip.data.model import ModelData
16+
from aiida_mlip.helpers.converters import kwarg_to_param
1617

1718

1819
def validate_inputs(
@@ -190,18 +191,16 @@ def prepare_for_submission(
190191
codeinfo = datastructures.CodeInfo()
191192

192193
# Initialize cmdline_params with train command
193-
codeinfo.cmdline_params = ["train"]
194194
# Create the rest of the command line
195-
cmd_line = {"mlip-config": config_copy}
196-
if self.inputs.fine_tune:
197-
cmd_line["fine-tune"] = None
198-
199-
# Add cmd line params to codeinfo
200-
for flag, value in cmd_line.items():
201-
if value is None:
202-
codeinfo.cmdline_params += [f"--{flag}"]
203-
else:
204-
codeinfo.cmdline_params += [f"--{flag}", str(value)]
195+
cmd_line = {
196+
"mlip-config": config_copy,
197+
"fine-tune": bool(self.inputs.fine_tune),
198+
}
199+
200+
codeinfo.cmdline_params = [
201+
"train",
202+
*kwarg_to_param(cmd_line),
203+
]
205204

206205
# Node where the code is saved
207206
codeinfo.code_uuid = self.inputs.code.uuid

aiida_mlip/helpers/converters.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from pathlib import Path
6+
from typing import Any
67

78
from aiida.orm import Bool, Dict, Str, StructureData, TrajectoryData, load_code
89
from ase.io import read
@@ -99,3 +100,41 @@ def convert_to_nodes(dictionary: dict, convert_all: bool = False) -> dict:
99100
continue
100101
new_dict[key] = value
101102
return new_dict
103+
104+
105+
def kwarg_to_param(params: dict[str, Any]) -> list[str]:
106+
"""
107+
Convert a dictionary of kwargs to a set of commandline flags.
108+
109+
Bools are converted as though ``store_true`` flag keys.
110+
111+
Parameters
112+
----------
113+
params : dict[str, Any]
114+
Dictionary of arguments to convert.
115+
116+
Returns
117+
-------
118+
list[str]
119+
Commandline arguments as flags.
120+
121+
Examples
122+
--------
123+
>>> kwarg_to_param({"name": "Geoff", "key": True})
124+
['--name', 'Geoff', '--key']
125+
>>> kwarg_to_param({"value": 6, "falsey": False})
126+
['--value', '6', '--no-falsey']
127+
"""
128+
cmdline_params = []
129+
130+
for key, val in params.items():
131+
key = key.replace("_", "-")
132+
match val:
133+
case bool() if val:
134+
cmdline_params.append(f"--{key}")
135+
case bool():
136+
cmdline_params.append(f"--no-{key}")
137+
case _:
138+
cmdline_params.extend((f"--{key}", str(val)))
139+
140+
return cmdline_params

0 commit comments

Comments
 (0)