Skip to content

Commit 77841bf

Browse files
Add calc kwargs input (#220)
1 parent e7c25f2 commit 77841bf

File tree

20 files changed

+316
-181
lines changed

20 files changed

+316
-181
lines changed

aiida_mlip/calculations/base.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import aiida.common.folders
99
from aiida.engine import CalcJob, CalcJobProcessSpec
1010
import aiida.engine.processes
11-
from aiida.orm import SinglefileData, Str, StructureData
11+
from aiida.orm import Dict, SinglefileData, Str, StructureData
1212
from ase.io import read, write
1313

1414
from aiida_mlip.data.config import JanusConfigfile
@@ -134,12 +134,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
134134
required=False,
135135
help="The input structure.",
136136
)
137-
spec.input(
138-
"precision",
139-
valid_type=Str,
140-
required=False,
141-
help="Precision level for calculation",
142-
)
143137
spec.input(
144138
"device",
145139
valid_type=Str,
@@ -171,6 +165,13 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
171165
help="Filename to which the content of stdout of the scheduler is written.",
172166
)
173167

168+
spec.input(
169+
"calc_kwargs",
170+
valid_type=Dict,
171+
required=False,
172+
help="Keyword arguments to pass to selected calculator.",
173+
)
174+
174175
spec.input(
175176
"config",
176177
valid_type=JanusConfigfile,
@@ -223,13 +224,14 @@ def prepare_for_submission(
223224
}
224225

225226
# The inputs are saved in the node, but we want their value as a string
226-
if "precision" in self.inputs:
227-
precision = (self.inputs.precision).value
228-
cmd_line["calc-kwargs"] = {"default_dtype": precision}
229227
if "device" in self.inputs:
230-
device = (self.inputs.device).value
228+
device = self.inputs.device.value
231229
cmd_line["device"] = device
232230

231+
# Set calc_kwargs from dict and specific stored inputs
232+
if "calc_kwargs" in self.inputs:
233+
cmd_line["calc-kwargs"] = self.inputs.calc_kwargs.get_dict()
234+
233235
# Define architecture from model if model is given,
234236
# otherwise get architecture from inputs and download default model
235237
self._add_arch_to_cmdline(cmd_line)
@@ -310,7 +312,6 @@ def _add_model_to_cmdline(
310312
dict
311313
Dictionary containing the cmd line keys updated with the model.
312314
"""
313-
model_path = None
314315
if "model" in self.inputs:
315316
# Raise error if model is None (different than model not given as input)
316317
if self.inputs.model is None:
@@ -322,5 +323,4 @@ def _add_model_to_cmdline(
322323
):
323324
shutil.copyfileobj(source, target)
324325

325-
model_path = "mlff.model"
326-
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
326+
cmd_line["model"] = "mlff.model"

docs/source/user_guide/calculations.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ The inputs can be grouped into a dictionary:
4040
"architecture": Str,
4141
"structure": StructureData,
4242
"model": ModelData,
43-
"precision": Str,
4443
"device": Str,
44+
"calc_kwargs": Dict,
4545
}
4646
SinglePointCalculation = CalculationFactory("mlip.sp")
4747
submit(SinglePointCalculation, **inputs)
@@ -58,6 +58,9 @@ The config file contains the parameters in yaml format:
5858
device: "cpu"
5959
struct: "path/to/structure.cif"
6060
model: "path/to/model.model"
61+
calc_kwargs:
62+
dispersion: True
63+
6164
6265
And it is used as shown below. Note that some parameters, which are specific to AiiDA, need to be given individually.
6366

@@ -101,8 +104,6 @@ In this case the structure used is going to be "path/to/structure2.xyz" rather
101104
Refer to the API documentation for additional parameters that can be passed.
102105
Some parameters are not required and don't have a default value set in aiida-mlip. In that case the default values will be the same as `janus <https://stfc.github.io/janus-core/>`_
103106
The only default parameters defined in aiida-mlip are the names of the input and output files, as they do not affect the results of the calculation itself, and are needed in AiiDA to parse the results.
104-
For example in the code above the parameter "precision" is never defined, neither in the config nor in the run_get_node function.
105-
The parameter will default to the janus default, which is "float64"
106107

107108

108109
Submission
@@ -120,7 +121,7 @@ They will be converted to AiiDA data types by the script itself.
120121

121122
.. code-block:: python
122123
123-
verdi run submit_singlepoint.py "janus@localhost" --structure "path/to/structure" --model "path/to/model" --precision "float64" --device "cpu"
124+
verdi run submit_singlepoint.py "janus@localhost" --structure "path/to/structure" --model "path/to/model" --device "cpu"
124125
125126
The submit_using_config.py script can be used to facilitate submission using a config file.
126127

@@ -157,7 +158,7 @@ They will be converted to AiiDA data types by the script itself.
157158

158159
.. code-block:: python
159160
160-
verdi run submit_geomopt.py "janus@localhost" --structure "path/to/structure" --model "path/to/model" --precision "float64" --device "cpu"
161+
verdi run submit_geomopt.py "janus@localhost" --structure "path/to/structure" --model "path/to/model" --device "cpu"
161162
162163
163164

docs/source/user_guide/tutorial.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,15 @@ The other inputs can be set up as AiiDA Str. There is a default for every input
5353

5454
.. code-block:: python
5555
56-
from aiida.orm import Bool, Float, Str
56+
from aiida.orm import Bool, Dict, Float, Str
57+
5758
inputs = {
5859
"code": code,
5960
"model": model,
6061
"structure": structure,
6162
"architecture": Str(model.architecture),
62-
"precision": Str("float64"),
6363
"device": Str("cpu"),
64+
"calc_kwargs": Dict({}),
6465
"max_force": Float(0.1), # Specific to geometry optimisation: convergence criteria
6566
"opt_cell_lengths": Bool(False), # Specific to geometry optimisation
6667
"opt_cell_fully": Bool(True), # Specific to geometry optimisation: to optimise the cell
@@ -159,13 +160,13 @@ The calculation can also be interacted with through verdi cli. Use `verdi proces
159160
Inputs PK Type
160161
--------------- ---- -------------
161162
architecture 1121 Str
163+
calc_kwargs 1122 Dict
162164
code 2 InstalledCode
163165
device 1123 Str
164166
opt_cell_fully 1126 Bool
165167
log_filename 1128 Str
166168
max_force 1124 Float
167169
model 1119 ModelData
168-
precision 1122 Str
169170
structure 1120 StructureData
170171
traj 1129 Str
171172
opt_cell_lengths 1125 Bool

examples/calculations/submit_descriptors.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from __future__ import annotations
44

5+
import ast
6+
57
from aiida.common import NotExistent
68
from aiida.engine import run_get_node
7-
from aiida.orm import Bool, Str, load_code
9+
from aiida.orm import Bool, Dict, Str, load_code
810
from aiida.plugins import CalculationFactory
911
import click
1012

@@ -39,13 +41,16 @@ def descriptors(params: dict) -> None:
3941
"arch": Str(params["arch"]),
4042
"struct": structure,
4143
"model": model,
42-
"precision": Str(params["precision"]),
4344
"device": Str(params["device"]),
4445
"invariants_only": Bool(params["invariants_only"]),
4546
"calc_per_element": Bool(params["calc_per_element"]),
4647
"calc_per_atom": Bool(params["calc_per_atom"]),
4748
}
4849

50+
# Only calc_kwargs add if set
51+
if params["calc_kwargs"]:
52+
inputs["calc_kwargs"] = Dict(params["calc_kwargs"])
53+
4954
# Run calculation
5055
result, node = run_get_node(DescriptorsCalc, **inputs)
5156
print(f"Printing results from calculation: {result}")
@@ -77,7 +82,10 @@ def descriptors(params: dict) -> None:
7782
"--device", default="cpu", type=str, help="Device to run calculations on."
7883
)
7984
@click.option(
80-
"--precision", default="float64", type=str, help="Chosen level of precision."
85+
"--calc-kwargs",
86+
default="{}",
87+
type=str,
88+
help="Keyword arguments to pass to calculator.",
8189
)
8290
@click.option(
8391
"--invariants-only",
@@ -103,12 +111,14 @@ def cli(
103111
model,
104112
arch,
105113
device,
106-
precision,
114+
calc_kwargs,
107115
invariants_only,
108116
calc_per_element,
109117
calc_per_atom,
110118
) -> None:
111119
"""Click interface."""
120+
calc_kwargs = ast.literal_eval(calc_kwargs)
121+
112122
try:
113123
code = load_code(codelabel)
114124
except NotExistent as exc:
@@ -121,7 +131,7 @@ def cli(
121131
"model": model,
122132
"arch": arch,
123133
"device": device,
124-
"precision": precision,
134+
"calc_kwargs": calc_kwargs,
125135
"invariants_only": invariants_only,
126136
"calc_per_element": calc_per_element,
127137
"calc_per_atom": calc_per_atom,

examples/calculations/submit_geomopt.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from __future__ import annotations
44

5+
import ast
6+
57
from aiida.common import NotExistent
68
from aiida.engine import run_get_node
7-
from aiida.orm import Bool, Float, Int, Str, load_code
9+
from aiida.orm import Bool, Dict, Float, Int, Str, load_code
810
from aiida.plugins import CalculationFactory
911
import click
1012

@@ -39,7 +41,6 @@ def geomopt(params: dict) -> None:
3941
"arch": Str(params["arch"]),
4042
"struct": structure,
4143
"model": model,
42-
"precision": Str(params["precision"]),
4344
"device": Str(params["device"]),
4445
"fmax": Float(params["fmax"]),
4546
"opt_cell_lengths": Bool(params["opt_cell_lengths"]),
@@ -48,6 +49,10 @@ def geomopt(params: dict) -> None:
4849
"steps": Int(params["steps"]),
4950
}
5051

52+
# Only calc_kwargs add if set
53+
if params["calc_kwargs"]:
54+
inputs["calc_kwargs"] = Dict(params["calc_kwargs"])
55+
5156
# Run calculation
5257
result, node = run_get_node(GeomoptCalc, **inputs)
5358
print(f"Printing results from calculation: {result}")
@@ -79,7 +84,10 @@ def geomopt(params: dict) -> None:
7984
"--device", default="cpu", type=str, help="Device to run calculations on."
8085
)
8186
@click.option(
82-
"--precision", default="float64", type=str, help="Chosen level of precision."
87+
"--calc-kwargs",
88+
default="{}",
89+
type=str,
90+
help="Keyword arguments to pass to calculator.",
8391
)
8492
@click.option("--fmax", default=0.1, type=float, help="Maximum force for convergence.")
8593
@click.option(
@@ -103,13 +111,15 @@ def cli(
103111
model,
104112
arch,
105113
device,
106-
precision,
114+
calc_kwargs,
107115
fmax,
108116
opt_cell_lengths,
109117
opt_cell_fully,
110118
steps,
111119
) -> None:
112120
"""Click interface."""
121+
calc_kwargs = ast.literal_eval(calc_kwargs)
122+
113123
try:
114124
code = load_code(codelabel)
115125
except NotExistent as exc:
@@ -122,7 +132,7 @@ def cli(
122132
"model": model,
123133
"arch": arch,
124134
"device": device,
125-
"precision": precision,
135+
"calc_kwargs": calc_kwargs,
126136
"fmax": fmax,
127137
"opt_cell_lengths": opt_cell_lengths,
128138
"opt_cell_fully": opt_cell_fully,

examples/calculations/submit_md.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,15 @@ def md(params: dict) -> None:
4141
"arch": Str(params["arch"]),
4242
"struct": structure,
4343
"model": model,
44-
"precision": Str(params["precision"]),
4544
"device": Str(params["device"]),
4645
"ensemble": Str(params["ensemble"]),
4746
"md_kwargs": Dict(params["md_dict"]),
4847
}
4948

49+
# Only calc_kwargs add if set
50+
if params["calc_kwargs"]:
51+
inputs["calc_kwargs"] = Dict(params["calc_kwargs"])
52+
5053
# Run calculation
5154
result, node = run_get_node(MDCalc, **inputs)
5255
print(f"Printing results from calculation: {result}")
@@ -78,7 +81,10 @@ def md(params: dict) -> None:
7881
"--device", default="cpu", type=str, help="Device to run calculations on."
7982
)
8083
@click.option(
81-
"--precision", default="float64", type=str, help="Chosen level of precision."
84+
"--calc-kwargs",
85+
default="{}",
86+
type=str,
87+
help="Keyword arguments to pass to calculator.",
8288
)
8389
@click.option(
8490
"--ensemble", default="nve", type=str, help="Name of thermodynamic ensemble."
@@ -90,10 +96,12 @@ def md(params: dict) -> None:
9096
help="String containing a dictionary with other md parameters",
9197
)
9298
def cli(
93-
codelabel, struct, model, arch, device, precision, ensemble, md_dict_str
99+
codelabel, struct, model, arch, device, calc_kwargs, ensemble, md_dict_str
94100
) -> None:
95101
"""Click interface."""
96102
md_dict = ast.literal_eval(md_dict_str)
103+
calc_kwargs = ast.literal_eval(calc_kwargs)
104+
97105
try:
98106
code = load_code(codelabel)
99107
except NotExistent as exc:
@@ -106,7 +114,7 @@ def cli(
106114
"model": model,
107115
"arch": arch,
108116
"device": device,
109-
"precision": precision,
117+
"calc_kwargs": calc_kwargs,
110118
"ensemble": ensemble,
111119
"md_dict": md_dict,
112120
}

examples/calculations/submit_singlepoint.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from __future__ import annotations
44

5+
import ast
6+
57
from aiida.common import NotExistent
68
from aiida.engine import run_get_node
7-
from aiida.orm import Str, load_code
9+
from aiida.orm import Dict, Str, load_code
810
from aiida.plugins import CalculationFactory
911
import click
1012

@@ -39,10 +41,13 @@ def singlepoint(params: dict) -> None:
3941
"arch": Str(params["arch"]),
4042
"struct": structure,
4143
"model": model,
42-
"precision": Str(params["precision"]),
4344
"device": Str(params["device"]),
4445
}
4546

47+
# Only calc_kwargs add if set
48+
if params["calc_kwargs"]:
49+
inputs["calc_kwargs"] = Dict(params["calc_kwargs"])
50+
4651
# Run calculation
4752
result, node = run_get_node(SinglepointCalc, **inputs)
4853
print(f"Printing results from calculation: {result}")
@@ -74,10 +79,15 @@ def singlepoint(params: dict) -> None:
7479
"--device", default="cpu", type=str, help="Device to run calculations on."
7580
)
7681
@click.option(
77-
"--precision", default="float64", type=str, help="Chosen level of precision."
82+
"--calc-kwargs",
83+
default="{}",
84+
type=str,
85+
help="Keyword arguments to pass to calculator.",
7886
)
79-
def cli(codelabel, struct, model, arch, device, precision) -> None:
87+
def cli(codelabel, struct, model, arch, device, calc_kwargs) -> None:
8088
"""Click interface."""
89+
calc_kwargs = ast.literal_eval(calc_kwargs)
90+
8191
try:
8292
code = load_code(codelabel)
8393
except NotExistent as exc:
@@ -90,7 +100,7 @@ def cli(codelabel, struct, model, arch, device, precision) -> None:
90100
"model": model,
91101
"arch": arch,
92102
"device": device,
93-
"precision": precision,
103+
"calc_kwargs": calc_kwargs,
94104
}
95105

96106
# Submit single point

0 commit comments

Comments
 (0)