Skip to content

Commit 166b8bc

Browse files
authored
Merge branch 'master' into feature/abacus
2 parents ecd5a58 + 8a70baf commit 166b8bc

File tree

13 files changed

+345
-19
lines changed

13 files changed

+345
-19
lines changed

.readthedocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ python:
1717

1818
sphinx:
1919
builder: html
20+
configuration: docs/source/conf.py
2021
fail_on_warning: true

docs/source/workflows/base/relax/index.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ Only ``structure`` and ``engines`` can be specified as a positional argument, al
171171
The default for this input is the Python value None and, in case of calculations with spin, the None value signals that the implementation should automatically decide an appropriate default initial magnetization.
172172
The implementation of such choice is code-dependent and described in the supplementary material of the `S. P. Huber et al., npj Comput. Mater. 7, 136 (2021)`_.
173173

174+
175+
* ``fixed_total_cell_magnetization``. (Type: Python None or a Python float).
176+
The total magnetization of the system for fixed spin moment calculations.
177+
Should be a float representing the total magnetization in Bohr magnetons (μB).
178+
174179
.. _relax-ref-wc:
175180

176181
* ``reference_workchain.`` (Type: a previously completed ``RelaxWorkChain``, performed with the same code as the ``RelaxWorkChain`` created by ``get_builder``).

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ requires-python = '>=3.9'
3333
[project.entry-points.'aiida.workflows']
3434
'common_workflows.bands.siesta' = 'aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain'
3535
'common_workflows.dissociation_curve' = 'aiida_common_workflows.workflows.dissociation:DissociationCurveWorkChain'
36+
'common_workflows.em' = 'aiida_common_workflows.workflows.em:EnergyMagnetizationWorkChain'
3637
'common_workflows.eos' = 'aiida_common_workflows.workflows.eos:EquationOfStateWorkChain'
3738
'common_workflows.relax.abacus' = 'aiida_common_workflows.workflows.relax.abacus.workchain:AbacusCommonRelaxWorkChain'
3839
'common_workflows.relax.abinit' = 'aiida_common_workflows.workflows.relax.abinit.workchain:AbinitCommonRelaxWorkChain'
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
"""Module with resources for input generators for workflows."""
22
from .generator import InputGenerator
3-
from .ports import ChoiceType, CodeType, InputGeneratorPort
3+
from .ports import ChoiceType, CodeType, InputGeneratorPort, OptionalFeatureType
44
from .spec import InputGeneratorSpec
55

6-
__all__ = ('InputGenerator', 'InputGeneratorPort', 'ChoiceType', 'CodeType', 'InputGeneratorSpec')
6+
__all__ = (
7+
'InputGenerator',
8+
'InputGeneratorPort',
9+
'ChoiceType',
10+
'CodeType',
11+
'OptionalFeatureType',
12+
'InputGeneratorSpec',
13+
)

src/aiida_common_workflows/generators/generator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from aiida import engine, orm
66

7+
from .optional_features import OptionalFeatureMixin
78
from .spec import InputGeneratorSpec
89

910
__all__ = ('InputGenerator',)
@@ -22,7 +23,7 @@ def recursively_check_stored_nodes(obj):
2223
return copy.deepcopy(obj)
2324

2425

25-
class InputGenerator(metaclass=abc.ABCMeta):
26+
class InputGenerator(OptionalFeatureMixin, metaclass=abc.ABCMeta):
2627
"""Base class for an input generator for a common workflow."""
2728

2829
_spec_cls: InputGeneratorSpec = InputGeneratorSpec
@@ -79,8 +80,14 @@ def get_builder(self, **kwargs) -> engine.ProcessBuilder:
7980

8081
processed_kwargs = self.spec().inputs.pre_process(copied_kwargs)
8182
serialized_kwargs = self.spec().inputs.serialize(processed_kwargs)
83+
84+
optional_features = {k for k, v in self.spec().inputs.ports.items() if getattr(v, 'optional', None)}
85+
optional_features_requested = {k for k, v in processed_kwargs.items() if k in optional_features}
86+
87+
validate_optional_features_error = self.validate_optional_features(optional_features_requested)
8288
validation_error = self.spec().inputs.validate(serialized_kwargs)
8389

90+
validation_error = validate_optional_features_error or validation_error
8491
if validation_error is not None:
8592
raise ValueError(validation_error)
8693

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from enum import Enum
2+
from typing import FrozenSet, Iterable
3+
4+
5+
class OptionalFeature(str, Enum):
6+
"""Enumeration of optional features that an input generator can support."""
7+
8+
9+
class OptionalFeatureMixin:
10+
"""Mixin class for input generators that support optional features."""
11+
12+
_optional_features: FrozenSet[OptionalFeature] = frozenset()
13+
_supported_optional_features: FrozenSet[OptionalFeature] = frozenset()
14+
15+
@classmethod
16+
def get_optional_features(cls) -> set[OptionalFeature]:
17+
"""Return the set of optional features for this common workflow."""
18+
return set(cls._optional_features)
19+
20+
@classmethod
21+
def get_supported_optional_features(cls) -> set[OptionalFeature]:
22+
"""Return the set of optional features supported by this implementation."""
23+
return set(cls._supported_optional_features)
24+
25+
@classmethod
26+
def supports_feature(cls, feature: OptionalFeature) -> bool:
27+
"""Return whether the given feature is supported by this implementation."""
28+
return feature in cls._supported_optional_features
29+
30+
@classmethod
31+
def validate_optional_features(
32+
cls,
33+
requested_features: Iterable[str],
34+
) -> None:
35+
"""Validate that all requested features are supported by this implementation.
36+
37+
:param requested_features: an iterable of requested features.
38+
:raises InputValidationError: if any of the requested features is not supported.
39+
"""
40+
unsupported_features = set(requested_features) - {
41+
feature.value for feature in cls.get_supported_optional_features()
42+
}
43+
if unsupported_features:
44+
return (
45+
f'the following optional features are not supported by `{cls.__name__}`: '
46+
f'{", ".join(unsupported_features)}'
47+
)

src/aiida_common_workflows/generators/ports.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from aiida.orm import Code
88
from plumpy.ports import UNSPECIFIED, Port, PortValidationError, breadcrumbs_to_port
99

10-
__all__ = ('ChoiceType', 'CodeType', 'InputGeneratorPort')
10+
__all__ = ('ChoiceType', 'CodeType', 'OptionalFeatureType', 'InputGeneratorPort')
1111

1212

1313
class CodeType:
@@ -34,11 +34,23 @@ def __init__(self, choices: t.Sequence[t.Any]):
3434
self.valid_type: tuple[t.Any] = valid_types if len(valid_types) > 1 else valid_types[0]
3535

3636

37+
class OptionalFeatureType:
38+
"""Class that can be used for the ``valid_type`` of a ``InputPort`` that can define optional features."""
39+
40+
def __init__(self, valid_type: t.Any):
41+
"""Construct a new instance.
42+
43+
:param valid_type: the valid type for the port.
44+
"""
45+
self.valid_type = valid_type
46+
47+
3748
class InputGeneratorPort(InputPort):
3849
"""Subclass of :class:`aiida.engine.InputPort` with support for choice types and value serializers."""
3950

4051
code_entry_point = None
4152
choices = None
53+
optional = None
4254

4355
def __init__(self, *args, valid_type=None, **kwargs) -> None:
4456
"""Construct a new instance and process the ``valid_type`` keyword if it is an instance of ``ChoiceType``."""
@@ -59,6 +71,10 @@ def valid_type(self, valid_type: t.Any | None) -> None:
5971
self.code_entry_point = valid_type.entry_point
6072
valid_type = valid_type.valid_type
6173

74+
if isinstance(valid_type, OptionalFeatureType):
75+
self.optional = True
76+
valid_type = valid_type.valid_type
77+
6278
self._valid_type = valid_type
6379

6480
def validate(self, value: t.Any, breadcrumbs: t.Sequence[str] = ()) -> PortValidationError | None:
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Workflow to calculate the energy as a function of fixed total magnetization that can use any code
2+
plugin implementing the common relax workflow and supports fixed spin-moment calulations."""
3+
import inspect
4+
5+
from aiida import orm
6+
from aiida.common import exceptions
7+
from aiida.engine import WorkChain, append_
8+
from aiida.plugins import WorkflowFactory
9+
10+
from aiida_common_workflows.workflows.relax.generator import ElectronicType, OptionalRelaxFeatures, RelaxType, SpinType
11+
from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain
12+
13+
14+
def validate_inputs(value, _):
15+
"""Validate the entire input namespace."""
16+
17+
# Validate that the provided ``generator_inputs`` are valid for the associated input generator.
18+
process_class = WorkflowFactory(value['sub_process_class'])
19+
generator = process_class.get_input_generator()
20+
21+
if not generator.supports_feature(OptionalRelaxFeatures.FIXED_MAGNETIZATION):
22+
return (
23+
f'The `{value["sub_process_class"]}` plugin does not support the '
24+
f'`{OptionalRelaxFeatures.FIXED_MAGNETIZATION}` optional feature required for this workflow.'
25+
)
26+
27+
try:
28+
generator.get_builder(structure=value['structure'], **value['generator_inputs'])
29+
except Exception as exc:
30+
return f'`{generator.__class__.__name__}.get_builder()` fails for the provided `generator_inputs`: {exc}'
31+
32+
33+
def validate_sub_process_class(value, _):
34+
"""Validate the sub process class."""
35+
try:
36+
process_class = WorkflowFactory(value)
37+
except exceptions.EntryPointError:
38+
return f'`{value}` is not a valid or registered workflow entry point.'
39+
40+
if not inspect.isclass(process_class) or not issubclass(process_class, CommonRelaxWorkChain):
41+
return f'`{value}` is not a subclass of the `CommonRelaxWorkChain` common workflow.'
42+
43+
44+
def validate_total_magnetizations(value, _):
45+
"""Validate the `fixed_total_magnetizations` input."""
46+
if value and len(value) < 3:
47+
return 'need at least 3 total magnetizations.'
48+
if not all(isinstance(m, (int, float)) for m in value):
49+
return 'all total magnetizations must be numbers (int or float).'
50+
51+
52+
def validate_relax_type(value, _):
53+
"""Validate the `generator_inputs.relax_type` input."""
54+
if value is not None and isinstance(value, str):
55+
value = RelaxType(value)
56+
57+
if value not in [RelaxType.NONE, RelaxType.POSITIONS, RelaxType.SHAPE, RelaxType.POSITIONS_SHAPE]:
58+
return '`generator_inputs.relax_type`. Equation of state and relaxation with variable volume not compatible.'
59+
60+
61+
class EnergyMagnetizationWorkChain(WorkChain):
62+
"""Workflow to compute the energy vs magnetization curve for a given crystal structure."""
63+
64+
@classmethod
65+
def define(cls, spec):
66+
# yapf: disable
67+
super().define(spec)
68+
spec.input('structure', valid_type=orm.StructureData, help='The structure at equilibrium volume.')
69+
spec.input('fixed_total_magnetizations', valid_type=orm.List, required=True,
70+
validator=validate_total_magnetizations, serializer=orm.to_aiida_type,
71+
help='The list of fixed total magnetizations to be calculated for the structure.')
72+
spec.input_namespace('generator_inputs',
73+
help='The inputs that will be passed to the input generator of the specified `sub_process`.')
74+
spec.input('generator_inputs.engines', valid_type=dict, non_db=True)
75+
spec.input('generator_inputs.protocol', valid_type=str, non_db=True,
76+
help='The protocol to use when determining the workchain inputs.')
77+
spec.input('generator_inputs.relax_type',
78+
valid_type=(RelaxType, str), non_db=True, validator=validate_relax_type,
79+
help='The type of relaxation to perform.')
80+
spec.input('generator_inputs.spin_type', valid_type=(SpinType, str), required=False, non_db=True,
81+
help='The type of spin for the calculation.')
82+
spec.input('generator_inputs.electronic_type', valid_type=(ElectronicType, str), required=False, non_db=True,
83+
help='The type of electronics (insulator/metal) for the calculation.')
84+
spec.input('generator_inputs.threshold_forces', valid_type=float, required=False, non_db=True,
85+
help='Target threshold for the forces in eV/Å.')
86+
spec.input('generator_inputs.threshold_stress', valid_type=float, required=False, non_db=True,
87+
help='Target threshold for the stress in eV/Å^3.')
88+
spec.input_namespace('sub_process', dynamic=True, populate_defaults=False)
89+
spec.input('sub_process_class', non_db=True, validator=validate_sub_process_class)
90+
spec.inputs.validator = validate_inputs
91+
92+
spec.outline(
93+
cls.run_em,
94+
cls.inspect_em,
95+
)
96+
97+
spec.output_namespace('total_energies', valid_type=orm.Float,
98+
help='The computed total energy of the relaxed structures at each scaling factor.')
99+
spec.output_namespace('total_magnetizations', valid_type=orm.Float,
100+
help='The fixed total magnetizations that were evaluated in mu_B.')
101+
spec.output_namespace('fermi_energies_up', valid_type=orm.Float,
102+
help=(
103+
'The fermi energies of the spin-up channel (at each fixed total magnetization). '
104+
'Can be used to compute an effective magnetic field. Otherwise, only meaningful '
105+
'in combination with the BandsData that will be added in the future.'
106+
)
107+
)
108+
spec.output_namespace('fermi_energies_down', valid_type=orm.Float,
109+
help=(
110+
'The fermi energies of the spin-down channel (at each fixed total magnetization). '
111+
'Can be used to compute an effective magnetic field. Otherwise, only meaningful '
112+
'in combination with the BandsData that will be added in the future.'
113+
)
114+
)
115+
spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED',
116+
message='At least one of the `{cls}` sub processes did not finish successfully.')
117+
118+
def get_sub_workchain_builder(self, total_magnetization):
119+
"""Return the builder for the relax workchain."""
120+
structure = self.inputs.structure
121+
process_class = WorkflowFactory(self.inputs.sub_process_class)
122+
123+
base_inputs = {'structure': structure, 'fixed_total_cell_magnetization': total_magnetization}
124+
125+
builder = process_class.get_input_generator().get_builder(**base_inputs, **self.inputs.generator_inputs)
126+
builder._merge(**self.inputs.get('sub_process', {}))
127+
128+
return builder
129+
130+
def run_em(self):
131+
"""Run the sub process at each scale factor to compute the structure volume and total energy."""
132+
for total_magnetization in self.inputs.fixed_total_magnetizations:
133+
builder = self.get_sub_workchain_builder(total_magnetization)
134+
self.report(
135+
f'submitting `{builder.process_class.__name__}` for total_magnetization `{total_magnetization}`'
136+
)
137+
self.to_context(children=append_(self.submit(builder)))
138+
139+
def inspect_em(self):
140+
"""Inspect all children workflows to make sure they finished successfully."""
141+
if any(not child.is_finished_ok for child in self.ctx.children):
142+
return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.sub_process_class)
143+
144+
for index, child in enumerate(self.ctx.children):
145+
energy = child.outputs.total_energy
146+
total_magnetization = child.outputs.total_magnetization
147+
148+
fermi_energy_up = child.outputs.fermi_energy_up
149+
fermi_energy_down = child.outputs.fermi_energy_down
150+
151+
self.report(f'Image {index}: total_magnetization={total_magnetization}, total energy={energy.value}')
152+
153+
self.out(f'total_energies.{index}', energy)
154+
self.out(f'total_magnetizations.{index}', total_magnetization)
155+
self.out(f'fermi_energies_up.{index}', fermi_energy_up)
156+
self.out(f'fermi_energies_down.{index}', fermi_energy_down)

src/aiida_common_workflows/workflows/eos.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def define(cls, spec):
103103
help='The type of electronics (insulator/metal) for the calculation.')
104104
spec.input('generator_inputs.magnetization_per_site', valid_type=(list, tuple), required=False, non_db=True,
105105
help='List containing the initial magnetization per atomic site.')
106+
spec.input('generator_inputs.fixed_total_cell_magnetization', valid_type=float, required=False, non_db=True,
107+
help='The fixed total magnetization of the cell in Bohr magnetons (μB).')
106108
spec.input('generator_inputs.threshold_forces', valid_type=float, required=False, non_db=True,
107109
help='Target threshold for the forces in eV/Å.')
108110
spec.input('generator_inputs.threshold_stress', valid_type=float, required=False, non_db=True,

src/aiida_common_workflows/workflows/relax/generator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,33 @@
44
from aiida import orm, plugins
55

66
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
7-
from aiida_common_workflows.generators import ChoiceType, InputGenerator
7+
from aiida_common_workflows.generators import ChoiceType, InputGenerator, OptionalFeatureType
8+
from aiida_common_workflows.generators.optional_features import OptionalFeature
89
from aiida_common_workflows.protocol import ProtocolRegistry
910

1011
__all__ = ('CommonRelaxInputGenerator',)
1112

1213

14+
def validate_inputs(value, _):
15+
"""Validate the entire input namespace."""
16+
# Validate mutual exclusivity of magnetization inputs.
17+
if value.get('magnetization_per_site') is not None and value.get('fixed_total_cell_magnetization') is not None:
18+
return 'the inputs `magnetization_per_site` and ' '`fixed_total_cell_magnetization` are mutually exclusive.'
19+
20+
21+
class OptionalRelaxFeatures(OptionalFeature):
22+
FIXED_MAGNETIZATION = 'fixed_total_cell_magnetization'
23+
24+
1325
class CommonRelaxInputGenerator(InputGenerator, ProtocolRegistry, metaclass=abc.ABCMeta):
1426
"""Input generator for the common relax workflow.
1527
1628
This class should be subclassed by implementations for specific quantum engines. After calling the super, they can
1729
modify the ports defined here in the base class as well as add additional custom ports.
1830
"""
1931

32+
_optional_features = frozenset(OptionalRelaxFeatures)
33+
2034
@classmethod
2135
def define(cls, spec):
2236
"""Define the specification of the input generator.
@@ -68,6 +82,14 @@ def define(cls, spec):
6882
'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
6983
'(μB).',
7084
)
85+
spec.input(
86+
'fixed_total_cell_magnetization',
87+
valid_type=OptionalFeatureType(float),
88+
required=False,
89+
non_db=True,
90+
help='The total magnetization of the system for fixed spin moment calculations. Should be '
91+
'a float representing the total magnetization in Bohr magnetons (μB).',
92+
)
7193
spec.input(
7294
'threshold_forces',
7395
valid_type=float,
@@ -114,3 +136,5 @@ def define(cls, spec):
114136
non_db=True,
115137
help='Options for the geometry optimization calculation jobs.',
116138
)
139+
140+
spec.inputs.validator = validate_inputs

0 commit comments

Comments
 (0)