-
Notifications
You must be signed in to change notification settings - Fork 38
Add support for fixed total magnetization and optional features #344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
t-reents
merged 6 commits into
aiidateam:master
from
t-reents:feat/fixed-total-magnetization
Nov 4, 2025
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
302804e
Add support for fixed total magnetization and optional features
t-reents 9d9e167
Fix logig for optional features
t-reents 41d27ff
Fix AiiDA-QE updated protocols
t-reents 3ee6052
Adressing review comments: Improved error handling and fixing of smal…
t-reents 66c3ed2
Add fermi energy to outputs. This is necessary in order to calculate …
t-reents 234333d
Fix ReadtheDocs
t-reents File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,13 @@ | ||
| """Module with resources for input generators for workflows.""" | ||
| from .generator import InputGenerator | ||
| from .ports import ChoiceType, CodeType, InputGeneratorPort | ||
| from .ports import ChoiceType, CodeType, InputGeneratorPort, OptionalFeatureType | ||
| from .spec import InputGeneratorSpec | ||
|
|
||
| __all__ = ('InputGenerator', 'InputGeneratorPort', 'ChoiceType', 'CodeType', 'InputGeneratorSpec') | ||
| __all__ = ( | ||
| 'InputGenerator', | ||
| 'InputGeneratorPort', | ||
| 'ChoiceType', | ||
| 'CodeType', | ||
| 'OptionalFeatureType', | ||
| 'InputGeneratorSpec', | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
src/aiida_common_workflows/generators/optional_features.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| from enum import Enum | ||
| from typing import FrozenSet, Iterable | ||
|
|
||
|
|
||
| class OptionalFeature(str, Enum): | ||
| """Enumeration of optional features that an input generator can support.""" | ||
|
|
||
|
|
||
| class OptionalFeatureMixin: | ||
| """Mixin class for input generators that support optional features.""" | ||
|
|
||
| _optional_features: FrozenSet[OptionalFeature] = frozenset() | ||
| _supported_optional_features: FrozenSet[OptionalFeature] = frozenset() | ||
|
|
||
| @classmethod | ||
| def get_optional_features(cls) -> set[OptionalFeature]: | ||
| """Return the set of optional features for this common workflow.""" | ||
| return set(cls._optional_features) | ||
|
|
||
| @classmethod | ||
| def get_supported_optional_features(cls) -> set[OptionalFeature]: | ||
| """Return the set of optional features supported by this implementation.""" | ||
| return set(cls._supported_optional_features) | ||
|
|
||
| @classmethod | ||
| def supports_feature(cls, feature: OptionalFeature) -> bool: | ||
| """Return whether the given feature is supported by this implementation.""" | ||
| return feature in cls._supported_optional_features | ||
|
|
||
| @classmethod | ||
| def validate_optional_features( | ||
| cls, | ||
| requested_features: Iterable[str], | ||
| ) -> None: | ||
| """Validate that all requested features are supported by this implementation. | ||
|
|
||
| :param requested_features: an iterable of requested features. | ||
| :raises InputValidationError: if any of the requested features is not supported. | ||
| """ | ||
| unsupported_features = set(requested_features) - { | ||
| feature.value for feature in cls.get_supported_optional_features() | ||
| } | ||
| if unsupported_features: | ||
| return ( | ||
| f'the following optional features are not supported by `{cls.__name__}`: ' | ||
| f'{", ".join(unsupported_features)}' | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| """Equation of state workflow that can use any code plugin implementing the common relax workflow.""" | ||
| import inspect | ||
|
|
||
| from aiida import orm | ||
| from aiida.common import exceptions | ||
| from aiida.engine import WorkChain, append_ | ||
| from aiida.plugins import WorkflowFactory | ||
|
|
||
| from aiida_common_workflows.workflows.relax.generator import ElectronicType, RelaxType, SpinType | ||
| from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain | ||
|
|
||
|
|
||
| def validate_inputs(value, _): | ||
t-reents marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Validate the entire input namespace.""" | ||
|
|
||
| # Validate that the provided ``generator_inputs`` are valid for the associated input generator. | ||
| process_class = WorkflowFactory(value['sub_process_class']) | ||
| generator = process_class.get_input_generator() | ||
|
|
||
| try: | ||
| generator.get_builder(structure=value['structure'], **value['generator_inputs']) | ||
| except Exception as exc: | ||
| return f'`{generator.__class__.__name__}.get_builder()` fails for the provided `generator_inputs`: {exc}' | ||
|
|
||
|
|
||
| def validate_sub_process_class(value, _): | ||
| """Validate the sub process class.""" | ||
| try: | ||
| process_class = WorkflowFactory(value) | ||
| except exceptions.EntryPointError: | ||
| return f'`{value}` is not a valid or registered workflow entry point.' | ||
|
|
||
| if not inspect.isclass(process_class) or not issubclass(process_class, CommonRelaxWorkChain): | ||
| return f'`{value}` is not a subclass of the `CommonRelaxWorkChain` common workflow.' | ||
|
|
||
|
|
||
| def validate_total_magnetizations(value, _): | ||
| """Validate the `fixed_total_magnetizations` input.""" | ||
| if value and len(value) < 3: | ||
| return 'need at least 3 total magnetizations.' | ||
| if not all(isinstance(m, (int, float)) for m in value): | ||
| return 'all total magnetizations must be numbers (int or float).' | ||
|
|
||
|
|
||
| def validate_relax_type(value, _): | ||
| """Validate the `generator_inputs.relax_type` input.""" | ||
| if value is not None and isinstance(value, str): | ||
| value = RelaxType(value) | ||
|
|
||
| if value not in [RelaxType.NONE, RelaxType.POSITIONS, RelaxType.SHAPE, RelaxType.POSITIONS_SHAPE]: | ||
| return '`generator_inputs.relax_type`. Equation of state and relaxation with variable volume not compatible.' | ||
|
|
||
|
|
||
| class EnergyMagnetizationWorkChain(WorkChain): | ||
| """Workflow to compute the energy vs magnetization curve for a given crystal structure.""" | ||
|
|
||
| @classmethod | ||
| def define(cls, spec): | ||
| # yapf: disable | ||
| super().define(spec) | ||
| spec.input('structure', valid_type=orm.StructureData, help='The structure at equilibrium volume.') | ||
| spec.input('fixed_total_magnetizations', valid_type=orm.List, required=True, | ||
| validator=validate_total_magnetizations, serializer=orm.to_aiida_type, | ||
| help='The list of fixed total magnetizations to be calculated for the structure.') | ||
| spec.input_namespace('generator_inputs', | ||
| help='The inputs that will be passed to the input generator of the specified `sub_process`.') | ||
| spec.input('generator_inputs.engines', valid_type=dict, non_db=True) | ||
| spec.input('generator_inputs.protocol', valid_type=str, non_db=True, | ||
| help='The protocol to use when determining the workchain inputs.') | ||
| spec.input('generator_inputs.relax_type', | ||
| valid_type=(RelaxType, str), non_db=True, validator=validate_relax_type, | ||
| help='The type of relaxation to perform.') | ||
| spec.input('generator_inputs.spin_type', valid_type=(SpinType, str), required=False, non_db=True, | ||
| help='The type of spin for the calculation.') | ||
| spec.input('generator_inputs.electronic_type', valid_type=(ElectronicType, str), required=False, non_db=True, | ||
| help='The type of electronics (insulator/metal) for the calculation.') | ||
| spec.input( | ||
| 'generator_inputs.fixed_total_cell_magnetization', valid_type=(list, tuple), | ||
t-reents marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| required=False, non_db=True, | ||
| help='List containing the total magnetizations per cell to be calculated.' | ||
| ) | ||
| spec.input('generator_inputs.threshold_forces', valid_type=float, required=False, non_db=True, | ||
| help='Target threshold for the forces in eV/Å.') | ||
| spec.input('generator_inputs.threshold_stress', valid_type=float, required=False, non_db=True, | ||
| help='Target threshold for the stress in eV/Å^3.') | ||
| spec.input_namespace('sub_process', dynamic=True, populate_defaults=False) | ||
| spec.input('sub_process_class', non_db=True, validator=validate_sub_process_class) | ||
| spec.inputs.validator = validate_inputs | ||
| spec.outline( | ||
| cls.run_em, | ||
| cls.inspect_em, | ||
| ) | ||
|
|
||
| spec.output_namespace('total_energies', valid_type=orm.Float, | ||
| help='The computed total energy of the relaxed structures at each scaling factor.') | ||
| spec.output_namespace('total_magnetizations', valid_type=orm.Float, | ||
| help='The fixed total magnetizations that were evaluated.') | ||
| spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED', | ||
| message='At least one of the `{cls}` sub processes did not finish successfully.') | ||
|
|
||
| def get_sub_workchain_builder(self, total_magnetization): | ||
| """Return the builder for the relax workchain.""" | ||
| structure = self.inputs.structure | ||
| process_class = WorkflowFactory(self.inputs.sub_process_class) | ||
|
|
||
| base_inputs = {'structure': structure, 'fixed_total_cell_magnetization': total_magnetization} | ||
|
|
||
| builder = process_class.get_input_generator().get_builder(**base_inputs, **self.inputs.generator_inputs) | ||
| builder._merge(**self.inputs.get('sub_process', {})) | ||
|
|
||
| return builder | ||
|
|
||
| def run_em(self): | ||
| """Run the sub process at each scale factor to compute the structure volume and total energy.""" | ||
| for total_magnetization in self.inputs.fixed_total_magnetizations: | ||
| builder = self.get_sub_workchain_builder(total_magnetization) | ||
| self.report( | ||
| f'submitting `{builder.process_class.__name__}` for total_magnetization `{total_magnetization}`' | ||
| ) | ||
| self.to_context(children=append_(self.submit(builder))) | ||
|
|
||
| def inspect_em(self): | ||
| """Inspect all children workflows to make sure they finished successfully.""" | ||
| if any(not child.is_finished_ok for child in self.ctx.children): | ||
| return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.sub_process_class) | ||
|
|
||
| for index, child in enumerate(self.ctx.children): | ||
| energy = child.outputs.total_energy | ||
| total_magnetization = child.outputs.total_magnetization | ||
|
|
||
| self.report(f'Image {index}: total_magnetization={total_magnetization}, total energy={energy.value}') | ||
|
|
||
| self.out(f'total_energies.{index}', energy) | ||
| self.out(f'total_magnetizations.{index}', total_magnetization) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.