Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable=undefined-variable
"""Module with the implementations of the common bands workchain for Quantum ESPRESSO."""
from .generator import *
from .workchain import *

__all__ = (generator.__all__ + workchain.__all__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
"""Implementation of the ``CommonBandsInputGenerator`` for Quantum ESPRESSO."""

from aiida import engine, orm

from aiida_common_workflows.generators import CodeType

from ..generator import CommonBandsInputGenerator

__all__ = ('QuantumEspressoCommonBandsInputGenerator',)


class QuantumEspressoCommonBandsInputGenerator(CommonBandsInputGenerator):
"""Input generator for the ``QuantumEspressoCommonBandsWorkChain``"""

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.

The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""
super().define(spec)
spec.inputs['engines']['bands']['code'].valid_type = CodeType('quantumespresso.pw')

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.

The keyword arguments will have been validated against the input generator specification.
"""
# pylint: disable=too-many-branches,too-many-statements,too-many-locals
engines = kwargs.get('engines', None)
parent_folder = kwargs['parent_folder']
bands_kpoints = kwargs['bands_kpoints']

builder = self.process_class.get_builder()

# Inputs of the `pw` calcjob are based of the inputs of the `parent_folder` creator's inputs
parent_calc = parent_folder.creator
if parent_calc.process_type != 'aiida.calculations:quantumespresso.pw':
raise ValueError('The `parent_folder` has not been created by a `PwCalculation`.')
pw_builder = parent_calc.get_builder_restart()
pw_builder.pop('kpoints')
builder.pw = pw_builder
builder.pw.parent_folder = parent_folder

# Use the explicit `kpoints` list from the inputs
builder.kpoints = bands_kpoints

# Update the `calculation` type to `bands`
parameters = builder.pw.parameters.get_dict()
parameters['CONTROL']['calculation'] = 'bands'
builder.pw.parameters = orm.Dict(dict=parameters)

# Update the structure in case we have one in output, i.e. the `parent_calc` optimized the structure
if 'output_structure' in parent_calc.outputs:
builder.pw.structure = parent_calc.outputs.output_structure

# Update the code and computational options only if the `engines` input is provided
if engines is None:
return builder

try:
bands_engine = engines['bands']
except KeyError:
raise ValueError('The `engines` dictionary must contain `bands` as a top-level key')
if 'code' in bands_engine:
code = bands_engine['code']
if isinstance(code, str):
code = orm.load_code(code)
builder.pw.code = code
if 'options' in bands_engine:
builder.pw.metadata.options = bands_engine['options']

return builder
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""
from aiida.engine import calcfunction
from aiida.orm import Float
from aiida.plugins import WorkflowFactory

from ..workchain import CommonBandsWorkChain
from .generator import QuantumEspressoCommonBandsInputGenerator

__all__ = ('QuantumEspressoCommonBandsWorkChain',)


@calcfunction
def get_fermi_energy(output_parameters):
"""Extract the Fermi energy from the ``output_parameters`` of a ``PwBaseWorkChain``."""
return Float(output_parameters['fermi_energy'])


class QuantumEspressoCommonBandsWorkChain(CommonBandsWorkChain):
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""

_process_class = WorkflowFactory('quantumespresso.pw.base')
_generator_class = QuantumEspressoCommonBandsInputGenerator

def convert_outputs(self):
"""Convert the outputs of the sub work chain to the common output specification."""
outputs = self.ctx.workchain.outputs

if 'output_band' not in outputs:
self.report('The `bands` PwBaseWorkChain does not have the `output_band` output.')
return self.exit_codes.ERROR_SUB_PROCESS_FAILED

self.out('bands', outputs.output_band)
self.out('fermi_energy', get_fermi_energy(outputs.output_parameters))
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def define(cls, spec):
)
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
spec.inputs['engines']['relax']['code'].valid_type = CodeType('quantumespresso.pw')
spec.input(
'clean_workdir',
valid_type=orm.Bool,
default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.'
)

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.
Expand All @@ -111,6 +117,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
threshold_forces = kwargs.get('threshold_forces', None)
threshold_stress = kwargs.get('threshold_stress', None)
reference_workchain = kwargs.get('reference_workchain', None)
clean_workdir = kwargs.get('clean_workdir')

if isinstance(electronic_type, str):
electronic_type = types.ElectronicType(electronic_type)
Expand Down Expand Up @@ -162,6 +169,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
spin_type=spin_type,
initial_magnetic_moments=initial_magnetic_moments,
)
builder.clean_workdir = clean_workdir

if threshold_forces is not None:
threshold = threshold_forces * CONSTANTS.bohr_to_ang / CONSTANTS.ry_to_ev
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,4 @@ def convert_outputs(self):
self.out('total_energy', total_energy)
self.out('forces', forces)
self.out('stress', stress)
self.out('remote_folder', outputs.remote_folder)
3 changes: 2 additions & 1 deletion setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@
"common_workflows.relax.quantum_espresso = aiida_common_workflows.workflows.relax.quantum_espresso.workchain:QuantumEspressoCommonRelaxWorkChain",
"common_workflows.relax.siesta = aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain",
"common_workflows.relax.vasp = aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain",
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain"
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain",
"common_workflows.bands.quantum_espresso = aiida_common_workflows.workflows.bands.quantum_espresso.workchain:QuantumEspressoCommonBandsWorkChain"
]
},
"license": "MIT License",
Expand Down