Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions aiida_common_workflows/generators/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from aiida import engine
from aiida import orm
from ..protocol import ProtocolRegistry
from .spec import InputGeneratorSpec

__all__ = ('InputGenerator',)
Expand All @@ -24,7 +23,7 @@ def recursively_check_stored_nodes(obj):
return copy.deepcopy(obj)


class InputGenerator(ProtocolRegistry, metaclass=abc.ABCMeta):
class InputGenerator(metaclass=abc.ABCMeta):
"""Base class for an input generator for a common workflow."""

_spec_cls: InputGeneratorSpec = InputGeneratorSpec
Expand All @@ -50,15 +49,14 @@ def define(cls, spec):
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""

def __init__(self, *args, **kwargs):
def __init__(self, process_class, **kwargs): #pylint: disable=unused-argument
"""Construct an instance of the input generator, validating the class attributes."""
super().__init__(*args, **kwargs)

def raise_invalid(message):
raise RuntimeError('invalid input generator `{}`: {}'.format(self.__class__.__name__, message))

try:
self.process_class = kwargs.pop('process_class')
self.process_class = process_class
except KeyError:
raise_invalid('required keyword argument `process_class` was not defined.')

Expand Down
20 changes: 17 additions & 3 deletions aiida_common_workflows/generators/ports.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
"""Modules with resources to define specific port types for input generator specifications."""
import typing as t

from plumpy.ports import Port, PortValidationError, UNSPECIFIED, breadcrumbs_to_port
from aiida.engine import InputPort
from aiida.orm import Code
from plumpy.ports import Port, PortValidationError, UNSPECIFIED, breadcrumbs_to_port

__all__ = ('ChoiceType', 'CodeType', 'InputGeneratorPort')

Expand Down Expand Up @@ -39,10 +38,11 @@ class InputGeneratorPort(InputPort):
code_entry_point = None
choices = None

def __init__(self, *args, valid_type=None, **kwargs) -> None:
def __init__(self, *args, valid_type_in_wc=None, valid_type=None, **kwargs) -> None:
"""Construct a new instance and process the ``valid_type`` keyword if it is an instance of ``ChoiceType``."""
super().__init__(*args, **kwargs)
self.valid_type = valid_type
self.valid_type_in_wc = valid_type_in_wc

@Port.valid_type.setter
def valid_type(self, valid_type: t.Optional[t.Any]) -> None:
Expand Down Expand Up @@ -75,3 +75,17 @@ def validate(self, value: t.Any, breadcrumbs: t.Sequence[str] = ()) -> t.Optiona
message = f'`{value}` is not a valid choice. Valid choices are: {", ".join(choices)}'
breadcrumbs = (breadcrumb for breadcrumb in (*breadcrumbs, self.name) if breadcrumb)
return PortValidationError(message, breadcrumbs_to_port(breadcrumbs))

@property
def valid_type_in_wc(self) -> t.Optional[t.Type[t.Any]]:
"""Get the valid value type for this port if one is specified
:return: the value value type
"""
return self._valid_type_in_wc

@valid_type_in_wc.setter
def valid_type_in_wc(self, valid_type_in_wc: t.Optional[t.Type[t.Any]]) -> None:
"""Set the valid value type for this port
:param valid_type: the value valid type
"""
self._valid_type_in_wc = valid_type_in_wc
2 changes: 1 addition & 1 deletion aiida_common_workflows/protocol/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ProtocolRegistry:
_protocols = None
_default_protocol = None

def __init__(self, *_, **__):
def __init__(self):
"""Construct an instance of the protocol registry, validating the class attributes set by the sub class."""

def raise_invalid(message):
Expand Down
6 changes: 6 additions & 0 deletions aiida_common_workflows/workflows/bands/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
# pylint: disable=undefined-variable
"""Module with the base classes for the common bands workchains."""
from .generator import *

__all__ = (generator.__all__,)
62 changes: 62 additions & 0 deletions aiida_common_workflows/workflows/bands/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""Module with base input generator for the common bands workchains."""
import abc

from aiida import orm
from aiida import plugins

from aiida_common_workflows.generators import InputGenerator

__all__ = ('CommonBandsInputGenerator',)


class CommonBandsInputGenerator(InputGenerator, metaclass=abc.ABCMeta):
"""Input generator for the common bands workflow.

This class should be subclassed by implementations for specific quantum engines. After calling the super, they can
modify the ports defined here in the base class as well as add additional custom ports.
"""

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.

The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""
super().define(spec)
spec.input(
'bands_kpoints',
valid_type=plugins.DataFactory('array.kpoints'),
required=True,
help='The full list of kpoints where to calculate bands, in (direct) coordinates of the reciprocal space.'
)
spec.input(
'parent_folder',
valid_type=orm.RemoteData,
required=True,
help='Parent folder that contains file to restart from (density matrix, wave-functions..). What is used '
'is plugin dependent.'
)
spec.input_namespace(
'engines',
required=False,
help='Inputs for the quantum engines',
)
spec.input_namespace(
'engines.bands',
required=False,
help='Inputs for the quantum engine performing the calculation of bands.',
)
spec.input(
'engines.bands.code',
valid_type=orm.Code,
serializer=orm.load_code,
required=False,
help='The code instance to use for the bands calculation.',
)
spec.input(
'engines.bands.options',
valid_type=dict,
required=False,
help='Options for the bands calculations jobs.',
)
7 changes: 7 additions & 0 deletions aiida_common_workflows/workflows/bands/siesta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable=undefined-variable
"""Module with the implementations of the common bands workchain for Siesta."""
from .generator import *
from .workchain import *

__all__ = (generator.__all__ + workchain.__all__)
70 changes: 70 additions & 0 deletions aiida_common_workflows/workflows/bands/siesta/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.bands.generator.CommonBandsInputGenerator` for SIESTA."""

from aiida import engine
from aiida import orm
from aiida.common import LinkType
from aiida_common_workflows.generators import CodeType
from ..generator import CommonBandsInputGenerator

__all__ = ('SiestaCommonBandsInputGenerator',)


class SiestaCommonBandsInputGenerator(CommonBandsInputGenerator):
"""Generator of inputs for the SiestaCommonBandsWorkChain"""

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.

The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""
super().define(spec)
spec.inputs['engines']['bands']['code'].valid_type = CodeType('siesta.siesta')

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.

The keyword arguments will have been validated against the input generator specification.
"""
# pylint: disable=too-many-branches,too-many-statements,too-many-locals
engines = kwargs.get('engines', None)
parent_folder = kwargs['parent_folder']
bands_kpoints = kwargs['bands_kpoints']

# From the parent folder, we retrieve the calculation that created it. Note
# that we are sure it exists (it wouldn't be the same for WorkChains). We then check
# that it is a SiestaCalculation and create the builder.
parent_siesta_calc = parent_folder.get_incoming(link_type=LinkType.CREATE).one().node
if parent_siesta_calc.process_type != 'aiida.calculations:siesta.siesta':
raise ValueError('The `parent_folder` has not been created by a SiestaCalculation')
builder_siesta_calc = parent_siesta_calc.get_builder_restart()

# Construct the builder of the `common_bands_wc` from the builder of a SiestaCalculation.
# Siesta specific: we have to eampty the metadata and put the resources in `options`.
builder_common_bands_wc = self.process_class.get_builder()
builder_common_bands_wc.options = orm.Dict(dict=builder_siesta_calc._data['metadata']['options']) # pylint: disable=protected-access
builder_siesta_calc._data['metadata'] = {} # pylint: disable=protected-access
for key, value in builder_siesta_calc._data.items(): # pylint: disable=protected-access
if value:
builder_common_bands_wc[key] = value

# Updated the structure (in case we have one in output)
if 'output_structure' in parent_siesta_calc.outputs:
builder_common_bands_wc.structure = parent_siesta_calc.outputs.output_structure

# Update the code and computational options if `engines` is specified
try:
engb = engines['bands']
except KeyError:
raise ValueError('The `engines` dictionaly must contain "bands" as outermost key')
if 'code' in engb:
builder_common_bands_wc.code = orm.load_code(engines['bands']['code'])
if 'options' in engb:
builder_common_bands_wc.options = orm.Dict(dict=engines['bands']['options'])

# Set the `bandskpoints` and the `parent_calc_folder` for restart
builder_common_bands_wc.bandskpoints = bands_kpoints
builder_common_bands_wc.parent_calc_folder = parent_folder

return builder_common_bands_wc
23 changes: 23 additions & 0 deletions aiida_common_workflows/workflows/bands/siesta/workchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for SIESTA."""
from aiida.plugins import WorkflowFactory

from ..workchain import CommonBandsWorkChain
from .generator import SiestaCommonBandsInputGenerator

__all__ = ('SiestaCommonBandsWorkChain',)


class SiestaCommonBandsWorkChain(CommonBandsWorkChain):
"""Implementation of `aiida_common_workflows.common.bands.workchain.CommonBandsWorkChain` for SIESTA."""

_process_class = WorkflowFactory('siesta.base')
_generator_class = SiestaCommonBandsInputGenerator

def convert_outputs(self):
"""Convert the outputs of the sub workchain to the common output specification."""
self.report('Bands calculation concluded sucessfully, converting outputs')
if 'bands' not in self.ctx.workchain.outputs:
self.report('SiestaBaseWorkChain concluded without returning bands!')
return self.exit_codes.ERROR_SUB_PROCESS_FAILED
self.out('bands', self.ctx.workchain.outputs['bands'])
61 changes: 61 additions & 0 deletions aiida_common_workflows/workflows/bands/workchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""Module with base wrapper workchain for bands workchains."""
from abc import ABCMeta, abstractmethod

from aiida.engine import WorkChain, ToContext
from aiida.orm import BandsData

from .generator import CommonBandsInputGenerator

__all__ = ('CommonBandsWorkChain',)


class CommonBandsWorkChain(WorkChain, metaclass=ABCMeta):
"""Base workchain implementation that serves as a wrapper for bands workchains.

Subclasses should simply define the concrete plugin-specific bands workchain for the `_process_class` attribute
and implement the `convert_outputs` class method to map the plugin specific outputs to the output spec of this
common wrapper workchain.
"""

_process_class = None
_generator_class = None

@classmethod
def get_input_generator(cls) -> CommonBandsInputGenerator:
"""Return an instance of the input generator for this work chain.

:return: input generator
"""
return cls._generator_class(process_class=cls) # pylint: disable=not-callable

@classmethod
def define(cls, spec):
# yapf: disable
super().define(spec)
spec.expose_inputs(cls._process_class)
spec.outline(
cls.run_workchain,
cls.inspect_workchain,
cls.convert_outputs,
)
spec.output('bands', valid_type=BandsData, required=False, help='Energies in eV.')
spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED',
message='The `{cls}` workchain failed with exit status {exit_status}.')

def run_workchain(self):
"""Run the wrapped workchain."""
inputs = self.exposed_inputs(self._process_class)
return ToContext(workchain=self.submit(self._process_class, **inputs))

def inspect_workchain(self):
"""Inspect the terminated workchain."""
if not self.ctx.workchain.is_finished_ok:
cls = self._process_class.__name__
exit_status = self.ctx.workchain.exit_status
self.report('the `{}` failed with exit status {}'.format(cls, exit_status))
return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=cls, exit_status=exit_status)

@abstractmethod
def convert_outputs(self):
"""Convert the outputs of the sub workchain to the common output specification."""
8 changes: 6 additions & 2 deletions aiida_common_workflows/workflows/relax/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from aiida import orm
from aiida import plugins

from aiida_common_workflows.protocol import ProtocolRegistry
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, InputGenerator

__all__ = ('CommonRelaxInputGenerator',)


class CommonRelaxInputGenerator(InputGenerator, metaclass=abc.ABCMeta):
class CommonRelaxInputGenerator(InputGenerator, ProtocolRegistry, metaclass=abc.ABCMeta):
"""Input generator for the common relax workflow.

This class should be subclassed by implementations for specific quantum engines. After calling the super, they can
Expand All @@ -34,6 +34,7 @@ def define(cls, spec):
'protocol',
valid_type=ChoiceType(('fast', 'moderate', 'precise')),
default='moderate',
valid_type_in_wc=orm.Str,
help='The protocol to use for the automated input generation. This value indicates the level of precision '
'of the results and computational cost that the input parameters will be selected for.',
)
Expand Down Expand Up @@ -62,6 +63,7 @@ def define(cls, spec):
'magnetization_per_site',
valid_type=list,
required=False,
valid_type_in_wc=orm.List,
help='The initial magnetization of the system. Should be a list of floats, where each float represents the '
'spin polarization in units of electrons, meaning the difference between spin up and spin down '
'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
Expand All @@ -71,13 +73,15 @@ def define(cls, spec):
'threshold_forces',
valid_type=float,
required=False,
valid_type_in_wc=orm.Float,
help='A real positive number indicating the target threshold for the forces in eV/Å. If not specified, '
'the protocol specification will select an appropriate value.',
)
spec.input(
'threshold_stress',
valid_type=float,
required=False,
valid_type_in_wc=orm.Float,
help='A real positive number indicating the target threshold for the stress in eV/Å^3. If not specified, '
'the protocol specification will select an appropriate value.',
)
Expand Down
1 change: 1 addition & 0 deletions aiida_common_workflows/workflows/relax/siesta/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ def convert_outputs(self):
self.out('stress', res_dict['stress'])
if 'stot' in self.ctx.workchain.outputs.output_parameters.attributes:
self.out('total_magnetization', get_magn(self.ctx.workchain.outputs.output_parameters))
self.out('remote_folder', self.ctx.workchain.outputs.remote_folder)
4 changes: 3 additions & 1 deletion aiida_common_workflows/workflows/relax/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABCMeta, abstractmethod

from aiida.engine import WorkChain, ToContext
from aiida.orm import StructureData, ArrayData, TrajectoryData, Float
from aiida.orm import StructureData, ArrayData, TrajectoryData, Float, RemoteData

from .generator import CommonRelaxInputGenerator

Expand Down Expand Up @@ -51,6 +51,8 @@ def define(cls, spec):
help='Total energy in eV.')
spec.output('total_magnetization', valid_type=Float, required=False,
help='Total magnetization in Bohr magnetons.')
spec.output('remote_folder', valid_type=RemoteData, required=False,
help='Folder of the last run calculation.')
spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED',
message='The `{cls}` workchain failed with exit status {exit_status}.')

Expand Down
Loading