Skip to content

Commit 3120939

Browse files
Replace opt_kwargs with minimize_kwargs (#180)
Co-authored-by: ElliottKasoar <[email protected]>
1 parent 2633ef0 commit 3120939

File tree

9 files changed

+238
-25
lines changed

9 files changed

+238
-25
lines changed

aiida_mlip/calculations/geomopt.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
from aiida.common import datastructures
66
import aiida.common.folders
77
from aiida.engine import CalcJobProcessSpec
8-
import aiida.engine.processes
98
from aiida.orm import (
109
Bool,
1110
Dict,
1211
Float,
1312
Int,
1413
SinglefileData,
15-
Str,
1614
StructureData,
1715
TrajectoryData,
1816
)
17+
from aiida.orm.utils.managers import NodeLinksManager
18+
from plumpy.utils import AttributesFrozendict
1919

2020
from aiida_mlip.calculations.singlepoint import Singlepoint
2121

@@ -48,13 +48,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
4848
super().define(spec)
4949

5050
# Additional inputs for geometry optimisation
51-
spec.input(
52-
"traj",
53-
valid_type=Str,
54-
required=False,
55-
default=lambda: Str(cls.DEFAULT_TRAJ_FILE),
56-
help="Path to save optimisation frames to",
57-
)
5851
spec.input(
5952
"opt_cell_fully",
6053
valid_type=Bool,
@@ -82,10 +75,10 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
8275
)
8376

8477
spec.input(
85-
"opt_kwargs",
78+
"minimize_kwargs",
8679
valid_type=Dict,
8780
required=False,
88-
help="Other optimisation keywords",
81+
help="All other keyword arguments to pass to geometry optimizer",
8982
)
9083

9184
spec.inputs["metadata"]["options"]["parser_name"].default = "mlip.opt_parser"
@@ -114,17 +107,12 @@ def prepare_for_submission(
114107
calcinfo = super().prepare_for_submission(folder)
115108
codeinfo = calcinfo.codes_info[0]
116109

117-
minimize_kwargs = (
118-
f"{{'traj_kwargs': {{'filename': '{self.inputs.traj.value}'}}}}"
119-
)
110+
minimize_kwargs = self.set_minimize_kwargs(self.inputs)
120111

121112
geom_opt_cmdline = {
122113
"minimize-kwargs": minimize_kwargs,
123114
"write-traj": True,
124115
}
125-
if "opt_kwargs" in self.inputs:
126-
opt_kwargs = self.inputs.opt_kwargs.get_dict()
127-
geom_opt_cmdline["opt-kwargs"] = opt_kwargs
128116
if "opt_cell_fully" in self.inputs:
129117
geom_opt_cmdline["opt-cell-fully"] = self.inputs.opt_cell_fully.value
130118
if "opt_cell_lengths" in self.inputs:
@@ -146,6 +134,36 @@ def prepare_for_submission(
146134
else:
147135
codeinfo.cmdline_params += [f"--{flag}", value]
148136

149-
calcinfo.retrieve_list.append(self.inputs.traj.value)
137+
calcinfo.retrieve_list.append(minimize_kwargs["traj_kwargs"]["filename"])
150138

151139
return calcinfo
140+
141+
@classmethod
142+
def set_minimize_kwargs(
143+
cls, inputs: AttributesFrozendict | NodeLinksManager
144+
) -> dict[str, dict[str, str]]:
145+
"""
146+
Set minimize kwargs from CalcJob inputs.
147+
148+
Parameters
149+
----------
150+
inputs : x
151+
CalcJob inputs.
152+
153+
Returns
154+
-------
155+
dict[str, dict[str, str]]
156+
Set minimize_kwargs dict with trajectory filename extracted from `traj`,
157+
the config file, or set as the default.
158+
"""
159+
if "minimize_kwargs" in inputs:
160+
minimize_kwargs = inputs.minimize_kwargs.get_dict()
161+
elif "config" in inputs:
162+
minimize_kwargs = inputs.config.as_dictionary.get("minimize_kwargs", {})
163+
else:
164+
minimize_kwargs = {}
165+
166+
minimize_kwargs.setdefault("traj_kwargs", {})
167+
minimize_kwargs["traj_kwargs"].setdefault("filename", cls.DEFAULT_TRAJ_FILE)
168+
169+
return minimize_kwargs

aiida_mlip/parsers/opt_parser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aiida.orm.nodes.process.process import ProcessNode
1111
from aiida.plugins import CalculationFactory
1212

13+
from aiida_mlip.calculations.geomopt import GeomOpt
1314
from aiida_mlip.helpers.converters import xyz_to_aiida_traj
1415
from aiida_mlip.parsers.sp_parser import SPParser
1516

@@ -75,7 +76,9 @@ def parse(self, **kwargs) -> ExitCode:
7576
exit_code = super().parse(**kwargs)
7677

7778
if exit_code == ExitCode(0):
78-
traj_file = (self.node.inputs.traj).value
79+
traj_file = GeomOpt.set_minimize_kwargs(self.node.inputs)["traj_kwargs"][
80+
"filename"
81+
]
7982

8083
# Parse the trajectory file and save it as `SingleFileData`
8184
with self.retrieved.open(traj_file, "rb") as handle:

docs/source/user_guide/tutorial.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ The other inputs can be set up as AiiDA ``Str``. There is a default for every in
6565
"fmax": Float(0.1),
6666
"opt_cell_lengths": Bool(False),
6767
"opt_cell_fully": Bool(True),
68+
"minimize_kwargs": Dict({"filter_kwargs": {"constant_volume": True}}),
6869
"metadata": {"options": {"resources": {"num_machines": 1}}},
6970
}
7071
@@ -168,7 +169,7 @@ The calculation can also be interacted with through verdi cli. Use ``verdi proce
168169
max_force 1124 Float
169170
model 1119 ModelData
170171
structure 1120 StructureData
171-
traj 1129 Str
172+
minimize_kwargs 1129 Dict
172173
opt_cell_lengths 1125 Bool
173174
xyz_output_name 1127 Str
174175

examples/calculations/submit_geomopt.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,17 @@ def geomopt(params: dict) -> None:
4545
"fmax": Float(params["fmax"]),
4646
"opt_cell_lengths": Bool(params["opt_cell_lengths"]),
4747
"opt_cell_fully": Bool(params["opt_cell_fully"]),
48-
# "opt_kwargs": Dict({"restart": "rest.pkl"}),
4948
"steps": Int(params["steps"]),
5049
}
5150

5251
# Only calc_kwargs add if set
5352
if params["calc_kwargs"]:
5453
inputs["calc_kwargs"] = Dict(params["calc_kwargs"])
5554

55+
# Only minimize_kwargs add if set
56+
if params["minimize_kwargs"]:
57+
inputs["minimize_kwargs"] = Dict(params["minimize_kwargs"])
58+
5659
# Run calculation
5760
result, node = run_get_node(GeomoptCalc, **inputs)
5861
print(f"Printing results from calculation: {result}")
@@ -105,6 +108,15 @@ def geomopt(params: dict) -> None:
105108
@click.option(
106109
"--steps", default=1000, type=int, help="Maximum number of optimisation steps."
107110
)
111+
@click.option(
112+
"--minimize-kwargs",
113+
default="{}",
114+
type=str,
115+
help=(
116+
"Keyword arguments to pass to geometry optimizer, including 'opt_kwargs', "
117+
"'filter_kwargs', and 'traj_kwargs'."
118+
),
119+
)
108120
def cli(
109121
codelabel,
110122
struct,
@@ -116,9 +128,11 @@ def cli(
116128
opt_cell_lengths,
117129
opt_cell_fully,
118130
steps,
131+
minimize_kwargs,
119132
) -> None:
120133
"""Click interface."""
121134
calc_kwargs = ast.literal_eval(calc_kwargs)
135+
minimize_kwargs = ast.literal_eval(minimize_kwargs)
122136

123137
try:
124138
code = load_code(codelabel)
@@ -137,6 +151,7 @@ def cli(
137151
"opt_cell_lengths": opt_cell_lengths,
138152
"opt_cell_fully": opt_cell_fully,
139153
"steps": steps,
154+
"minimize_kwargs": minimize_kwargs,
140155
}
141156

142157
# Submit single point

examples/tutorials/calculations/geometry-optimisation.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
" \"fmax\": Float(0.1), \n",
154154
" \"opt_cell_lengths\": Bool(False), \n",
155155
" \"opt_cell_fully\": Bool(True),\n",
156+
" \"minimize_kwargs\": Dict({\"filter_kwargs\": {\"constant_volume\": True}}),\n",
156157
" \"metadata\": {\"options\": {\"resources\": {\"num_machines\": 1}}},\n",
157158
" }"
158159
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
arch: mace_mp
2+
struct: "NaCl.cif"
3+
calc_kwargs:
4+
dispersion: True
5+
minimize_kwargs:
6+
opt_kwargs:
7+
alpha: 100
8+
fmax: 0.1
9+
opt_cell_lengths: True
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
arch: mace_mp
2+
struct: "NaCl.cif"
3+
calc_kwargs:
4+
dispersion: True
5+
minimize_kwargs:
6+
filter_kwargs:
7+
constant_volume: True
8+
traj_kwargs:
9+
filename: "test-traj.xyz"
10+
fmax: 0.1
11+
opt_cell_lengths: True

0 commit comments

Comments
 (0)