Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ python:

sphinx:
builder: html
configuration: docs/source/conf.py
fail_on_warning: true
5 changes: 5 additions & 0 deletions docs/source/workflows/base/relax/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ Only ``structure`` and ``engines`` can be specified as a positional argument, al
The default for this input is the Python value None and, in case of calculations with spin, the None value signals that the implementation should automatically decide an appropriate default initial magnetization.
The implementation of such choice is code-dependent and described in the supplementary material of the `S. P. Huber et al., npj Comput. Mater. 7, 136 (2021)`_.


* ``fixed_total_cell_magnetization``. (Type: Python None or a Python float).
The total magnetization of the system for fixed spin moment calculations.
Should be a float representing the total magnetization in Bohr magnetons (μB).

.. _relax-ref-wc:

* ``reference_workchain.`` (Type: a previously completed ``RelaxWorkChain``, performed with the same code as the ``RelaxWorkChain`` created by ``get_builder``).
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ requires-python = '>=3.9'
[project.entry-points.'aiida.workflows']
'common_workflows.bands.siesta' = 'aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain'
'common_workflows.dissociation_curve' = 'aiida_common_workflows.workflows.dissociation:DissociationCurveWorkChain'
'common_workflows.em' = 'aiida_common_workflows.workflows.em:EnergyMagnetizationWorkChain'
'common_workflows.eos' = 'aiida_common_workflows.workflows.eos:EquationOfStateWorkChain'
'common_workflows.relax.abinit' = 'aiida_common_workflows.workflows.relax.abinit.workchain:AbinitCommonRelaxWorkChain'
'common_workflows.relax.bigdft' = 'aiida_common_workflows.workflows.relax.bigdft.workchain:BigDftCommonRelaxWorkChain'
Expand Down
11 changes: 9 additions & 2 deletions src/aiida_common_workflows/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Module with resources for input generators for workflows."""
from .generator import InputGenerator
from .ports import ChoiceType, CodeType, InputGeneratorPort
from .ports import ChoiceType, CodeType, InputGeneratorPort, OptionalFeatureType
from .spec import InputGeneratorSpec

__all__ = ('InputGenerator', 'InputGeneratorPort', 'ChoiceType', 'CodeType', 'InputGeneratorSpec')
__all__ = (
'InputGenerator',
'InputGeneratorPort',
'ChoiceType',
'CodeType',
'OptionalFeatureType',
'InputGeneratorSpec',
)
9 changes: 8 additions & 1 deletion src/aiida_common_workflows/generators/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from aiida import engine, orm

from .optional_features import OptionalFeatureMixin
from .spec import InputGeneratorSpec

__all__ = ('InputGenerator',)
Expand All @@ -22,7 +23,7 @@ def recursively_check_stored_nodes(obj):
return copy.deepcopy(obj)


class InputGenerator(metaclass=abc.ABCMeta):
class InputGenerator(OptionalFeatureMixin, metaclass=abc.ABCMeta):
"""Base class for an input generator for a common workflow."""

_spec_cls: InputGeneratorSpec = InputGeneratorSpec
Expand Down Expand Up @@ -79,8 +80,14 @@ def get_builder(self, **kwargs) -> engine.ProcessBuilder:

processed_kwargs = self.spec().inputs.pre_process(copied_kwargs)
serialized_kwargs = self.spec().inputs.serialize(processed_kwargs)

optional_features = {k for k, v in self.spec().inputs.ports.items() if getattr(v, 'optional', None)}
optional_features_requested = {k for k, v in processed_kwargs.items() if k in optional_features}

validate_optional_features_error = self.validate_optional_features(optional_features_requested)
validation_error = self.spec().inputs.validate(serialized_kwargs)

validation_error = validate_optional_features_error or validation_error
if validation_error is not None:
raise ValueError(validation_error)

Expand Down
47 changes: 47 additions & 0 deletions src/aiida_common_workflows/generators/optional_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from enum import Enum
from typing import FrozenSet, Iterable


class OptionalFeature(str, Enum):
"""Enumeration of optional features that an input generator can support."""


class OptionalFeatureMixin:
"""Mixin class for input generators that support optional features."""

_optional_features: FrozenSet[OptionalFeature] = frozenset()
_supported_optional_features: FrozenSet[OptionalFeature] = frozenset()

@classmethod
def get_optional_features(cls) -> set[OptionalFeature]:
"""Return the set of optional features for this common workflow."""
return set(cls._optional_features)

@classmethod
def get_supported_optional_features(cls) -> set[OptionalFeature]:
"""Return the set of optional features supported by this implementation."""
return set(cls._supported_optional_features)

@classmethod
def supports_feature(cls, feature: OptionalFeature) -> bool:
"""Return whether the given feature is supported by this implementation."""
return feature in cls._supported_optional_features

@classmethod
def validate_optional_features(
cls,
requested_features: Iterable[str],
) -> None:
"""Validate that all requested features are supported by this implementation.

:param requested_features: an iterable of requested features.
:raises InputValidationError: if any of the requested features is not supported.
"""
unsupported_features = set(requested_features) - {
feature.value for feature in cls.get_supported_optional_features()
}
if unsupported_features:
return (
f'the following optional features are not supported by `{cls.__name__}`: '
f'{", ".join(unsupported_features)}'
)
18 changes: 17 additions & 1 deletion src/aiida_common_workflows/generators/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from aiida.orm import Code
from plumpy.ports import UNSPECIFIED, Port, PortValidationError, breadcrumbs_to_port

__all__ = ('ChoiceType', 'CodeType', 'InputGeneratorPort')
__all__ = ('ChoiceType', 'CodeType', 'OptionalFeatureType', 'InputGeneratorPort')


class CodeType:
Expand All @@ -34,11 +34,23 @@ def __init__(self, choices: t.Sequence[t.Any]):
self.valid_type: tuple[t.Any] = valid_types if len(valid_types) > 1 else valid_types[0]


class OptionalFeatureType:
"""Class that can be used for the ``valid_type`` of a ``InputPort`` that can define optional features."""

def __init__(self, valid_type: t.Any):
"""Construct a new instance.

:param valid_type: the valid type for the port.
"""
self.valid_type = valid_type


class InputGeneratorPort(InputPort):
"""Subclass of :class:`aiida.engine.InputPort` with support for choice types and value serializers."""

code_entry_point = None
choices = None
optional = None

def __init__(self, *args, valid_type=None, **kwargs) -> None:
"""Construct a new instance and process the ``valid_type`` keyword if it is an instance of ``ChoiceType``."""
Expand All @@ -59,6 +71,10 @@ def valid_type(self, valid_type: t.Any | None) -> None:
self.code_entry_point = valid_type.entry_point
valid_type = valid_type.valid_type

if isinstance(valid_type, OptionalFeatureType):
self.optional = True
valid_type = valid_type.valid_type

self._valid_type = valid_type

def validate(self, value: t.Any, breadcrumbs: t.Sequence[str] = ()) -> PortValidationError | None:
Expand Down
156 changes: 156 additions & 0 deletions src/aiida_common_workflows/workflows/em.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Workflow to calculate the energy as a function of fixed total magnetization that can use any code
plugin implementing the common relax workflow and supports fixed spin-moment calulations."""
import inspect

from aiida import orm
from aiida.common import exceptions
from aiida.engine import WorkChain, append_
from aiida.plugins import WorkflowFactory

from aiida_common_workflows.workflows.relax.generator import ElectronicType, OptionalRelaxFeatures, RelaxType, SpinType
from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain


def validate_inputs(value, _):
"""Validate the entire input namespace."""

# Validate that the provided ``generator_inputs`` are valid for the associated input generator.
process_class = WorkflowFactory(value['sub_process_class'])
generator = process_class.get_input_generator()

if not generator.supports_feature(OptionalRelaxFeatures.FIXED_MAGNETIZATION):
return (
f'The `{value["sub_process_class"]}` plugin does not support the '
f'`{OptionalRelaxFeatures.FIXED_MAGNETIZATION}` optional feature required for this workflow.'
)

try:
generator.get_builder(structure=value['structure'], **value['generator_inputs'])
except Exception as exc:
return f'`{generator.__class__.__name__}.get_builder()` fails for the provided `generator_inputs`: {exc}'


def validate_sub_process_class(value, _):
"""Validate the sub process class."""
try:
process_class = WorkflowFactory(value)
except exceptions.EntryPointError:
return f'`{value}` is not a valid or registered workflow entry point.'

if not inspect.isclass(process_class) or not issubclass(process_class, CommonRelaxWorkChain):
return f'`{value}` is not a subclass of the `CommonRelaxWorkChain` common workflow.'


def validate_total_magnetizations(value, _):
"""Validate the `fixed_total_magnetizations` input."""
if value and len(value) < 3:
return 'need at least 3 total magnetizations.'
if not all(isinstance(m, (int, float)) for m in value):
return 'all total magnetizations must be numbers (int or float).'


def validate_relax_type(value, _):
"""Validate the `generator_inputs.relax_type` input."""
if value is not None and isinstance(value, str):
value = RelaxType(value)

if value not in [RelaxType.NONE, RelaxType.POSITIONS, RelaxType.SHAPE, RelaxType.POSITIONS_SHAPE]:
return '`generator_inputs.relax_type`. Equation of state and relaxation with variable volume not compatible.'


class EnergyMagnetizationWorkChain(WorkChain):
"""Workflow to compute the energy vs magnetization curve for a given crystal structure."""

@classmethod
def define(cls, spec):
# yapf: disable
super().define(spec)
spec.input('structure', valid_type=orm.StructureData, help='The structure at equilibrium volume.')
spec.input('fixed_total_magnetizations', valid_type=orm.List, required=True,
validator=validate_total_magnetizations, serializer=orm.to_aiida_type,
help='The list of fixed total magnetizations to be calculated for the structure.')
spec.input_namespace('generator_inputs',
help='The inputs that will be passed to the input generator of the specified `sub_process`.')
spec.input('generator_inputs.engines', valid_type=dict, non_db=True)
spec.input('generator_inputs.protocol', valid_type=str, non_db=True,
help='The protocol to use when determining the workchain inputs.')
spec.input('generator_inputs.relax_type',
valid_type=(RelaxType, str), non_db=True, validator=validate_relax_type,
help='The type of relaxation to perform.')
spec.input('generator_inputs.spin_type', valid_type=(SpinType, str), required=False, non_db=True,
help='The type of spin for the calculation.')
spec.input('generator_inputs.electronic_type', valid_type=(ElectronicType, str), required=False, non_db=True,
help='The type of electronics (insulator/metal) for the calculation.')
spec.input('generator_inputs.threshold_forces', valid_type=float, required=False, non_db=True,
help='Target threshold for the forces in eV/Å.')
spec.input('generator_inputs.threshold_stress', valid_type=float, required=False, non_db=True,
help='Target threshold for the stress in eV/Å^3.')
spec.input_namespace('sub_process', dynamic=True, populate_defaults=False)
spec.input('sub_process_class', non_db=True, validator=validate_sub_process_class)
spec.inputs.validator = validate_inputs

spec.outline(
cls.run_em,
cls.inspect_em,
)

spec.output_namespace('total_energies', valid_type=orm.Float,
help='The computed total energy of the relaxed structures at each scaling factor.')
spec.output_namespace('total_magnetizations', valid_type=orm.Float,
help='The fixed total magnetizations that were evaluated in mu_B.')
spec.output_namespace('fermi_energies_up', valid_type=orm.Float,
help=(
'The fermi energies of the spin-up channel (at each fixed total magnetization). '
'Can be used to compute an effective magnetic field. Otherwise, only meaningful '
'in combination with the BandsData that will be added in the future.'
)
)
spec.output_namespace('fermi_energies_down', valid_type=orm.Float,
help=(
'The fermi energies of the spin-down channel (at each fixed total magnetization). '
'Can be used to compute an effective magnetic field. Otherwise, only meaningful '
'in combination with the BandsData that will be added in the future.'
)
)
spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED',
message='At least one of the `{cls}` sub processes did not finish successfully.')

def get_sub_workchain_builder(self, total_magnetization):
"""Return the builder for the relax workchain."""
structure = self.inputs.structure
process_class = WorkflowFactory(self.inputs.sub_process_class)

base_inputs = {'structure': structure, 'fixed_total_cell_magnetization': total_magnetization}

builder = process_class.get_input_generator().get_builder(**base_inputs, **self.inputs.generator_inputs)
builder._merge(**self.inputs.get('sub_process', {}))

return builder

def run_em(self):
"""Run the sub process at each scale factor to compute the structure volume and total energy."""
for total_magnetization in self.inputs.fixed_total_magnetizations:
builder = self.get_sub_workchain_builder(total_magnetization)
self.report(
f'submitting `{builder.process_class.__name__}` for total_magnetization `{total_magnetization}`'
)
self.to_context(children=append_(self.submit(builder)))

def inspect_em(self):
"""Inspect all children workflows to make sure they finished successfully."""
if any(not child.is_finished_ok for child in self.ctx.children):
return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.sub_process_class)

for index, child in enumerate(self.ctx.children):
energy = child.outputs.total_energy
total_magnetization = child.outputs.total_magnetization

fermi_energy_up = child.outputs.fermi_energy_up
fermi_energy_down = child.outputs.fermi_energy_down

self.report(f'Image {index}: total_magnetization={total_magnetization}, total energy={energy.value}')

self.out(f'total_energies.{index}', energy)
self.out(f'total_magnetizations.{index}', total_magnetization)
self.out(f'fermi_energies_up.{index}', fermi_energy_up)
self.out(f'fermi_energies_down.{index}', fermi_energy_down)
2 changes: 2 additions & 0 deletions src/aiida_common_workflows/workflows/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def define(cls, spec):
help='The type of electronics (insulator/metal) for the calculation.')
spec.input('generator_inputs.magnetization_per_site', valid_type=(list, tuple), required=False, non_db=True,
help='List containing the initial magnetization per atomic site.')
spec.input('generator_inputs.fixed_total_cell_magnetization', valid_type=float, required=False, non_db=True,
help='The fixed total magnetization of the cell in Bohr magnetons (μB).')
spec.input('generator_inputs.threshold_forces', valid_type=float, required=False, non_db=True,
help='Target threshold for the forces in eV/Å.')
spec.input('generator_inputs.threshold_stress', valid_type=float, required=False, non_db=True,
Expand Down
26 changes: 25 additions & 1 deletion src/aiida_common_workflows/workflows/relax/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,33 @@
from aiida import orm, plugins

from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, InputGenerator
from aiida_common_workflows.generators import ChoiceType, InputGenerator, OptionalFeatureType
from aiida_common_workflows.generators.optional_features import OptionalFeature
from aiida_common_workflows.protocol import ProtocolRegistry

__all__ = ('CommonRelaxInputGenerator',)


def validate_inputs(value, _):
"""Validate the entire input namespace."""
# Validate mutual exclusivity of magnetization inputs.
if value.get('magnetization_per_site') is not None and value.get('fixed_total_cell_magnetization') is not None:
return 'the inputs `magnetization_per_site` and ' '`fixed_total_cell_magnetization` are mutually exclusive.'


class OptionalRelaxFeatures(OptionalFeature):
FIXED_MAGNETIZATION = 'fixed_total_cell_magnetization'


class CommonRelaxInputGenerator(InputGenerator, ProtocolRegistry, metaclass=abc.ABCMeta):
"""Input generator for the common relax workflow.

This class should be subclassed by implementations for specific quantum engines. After calling the super, they can
modify the ports defined here in the base class as well as add additional custom ports.
"""

_optional_features = frozenset(OptionalRelaxFeatures)

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.
Expand Down Expand Up @@ -68,6 +82,14 @@ def define(cls, spec):
'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
'(μB).',
)
spec.input(
'fixed_total_cell_magnetization',
valid_type=OptionalFeatureType(float),
required=False,
non_db=True,
help='The total magnetization of the system for fixed spin moment calculations. Should be '
'a float representing the total magnetization in Bohr magnetons (μB).',
)
spec.input(
'threshold_forces',
valid_type=float,
Expand Down Expand Up @@ -114,3 +136,5 @@ def define(cls, spec):
non_db=True,
help='Options for the geometry optimization calculation jobs.',
)

spec.inputs.validator = validate_inputs
Loading