Skip to content

Commit 8fb0496

Browse files
committed
Implement common bands work chain for Quantum ESPRESSO
1 parent 69468a7 commit 8fb0496

File tree

6 files changed

+120
-1
lines changed

6 files changed

+120
-1
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=undefined-variable
3+
"""Module with the implementations of the common bands workchain for Quantum ESPRESSO."""
4+
from .generator import *
5+
from .workchain import *
6+
7+
__all__ = (generator.__all__ + workchain.__all__)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# -*- coding: utf-8 -*-
2+
"""Implementation of the ``CommonBandsInputGenerator`` for Quantum ESPRESSO."""
3+
4+
from aiida import engine, orm
5+
from aiida.common import LinkType
6+
7+
from aiida_common_workflows.generators import CodeType
8+
9+
from ..generator import CommonBandsInputGenerator
10+
11+
__all__ = ('QuantumEspressoCommonBandsInputGenerator',)
12+
13+
14+
class QuantumEspressoCommonBandsInputGenerator(CommonBandsInputGenerator):
15+
"""Input generator for the ``QuantumEspressoCommonBandsWorkChain``"""
16+
17+
@classmethod
18+
def define(cls, spec):
19+
"""Define the specification of the input generator.
20+
21+
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
22+
"""
23+
super().define(spec)
24+
spec.inputs['engines']['bands']['code'].valid_type = CodeType('quantumespresso.pw')
25+
26+
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
27+
"""Construct a process builder based on the provided keyword arguments.
28+
29+
The keyword arguments will have been validated against the input generator specification.
30+
"""
31+
# pylint: disable=too-many-branches,too-many-statements,too-many-locals
32+
engines = kwargs.get('engines', None)
33+
parent_folder = kwargs['parent_folder']
34+
bands_kpoints = kwargs['bands_kpoints']
35+
36+
# Find the `PwCalculation` that created the `parent_folder` and obtain the restart builder.
37+
parent_calc = parent_folder.get_incoming(link_type=LinkType.CREATE).one().node
38+
if parent_calc.process_type != 'aiida.calculations:quantumespresso.pw':
39+
raise ValueError('The `parent_folder` has not been created by a `PwCalculation`.')
40+
builder_calc = parent_calc.get_builder_restart()
41+
42+
builder_common_bands_wc = self.process_class.get_builder()
43+
builder_calc.pop('kpoints')
44+
builder_common_bands_wc.pw = builder_calc
45+
parameters = builder_common_bands_wc.pw.parameters.get_dict()
46+
parameters['CONTROL']['calculation'] = 'bands'
47+
builder_common_bands_wc.pw.parameters = orm.Dict(dict=parameters)
48+
builder_common_bands_wc.kpoints = bands_kpoints
49+
builder_common_bands_wc.pw.parent_folder = parent_folder
50+
51+
# Update the structure in case we have one in output, i.e. the `parent_calc` optimized the structure
52+
if 'output_structure' in parent_calc.outputs:
53+
builder_common_bands_wc.pw.structure = parent_calc.outputs.output_structure
54+
55+
# Update the code and computational options if `engines` is specified
56+
try:
57+
bands_engine = engines['bands']
58+
except KeyError:
59+
raise ValueError('The `engines` dictionary must contain `bands` as a top-level key')
60+
if 'code' in bands_engine:
61+
code = engines['bands']['code']
62+
if isinstance(code, str):
63+
code = orm.load_code(code)
64+
builder_common_bands_wc.pw.code = code
65+
if 'options' in bands_engine:
66+
builder_common_bands_wc.pw.metadata.options = engines['bands']['options']
67+
68+
return builder_common_bands_wc
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# -*- coding: utf-8 -*-
2+
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""
3+
from aiida.engine import calcfunction
4+
from aiida.orm import Float
5+
from aiida.plugins import WorkflowFactory
6+
7+
from ..workchain import CommonBandsWorkChain
8+
from .generator import QuantumEspressoCommonBandsInputGenerator
9+
10+
__all__ = ('QuantumEspressoCommonBandsWorkChain',)
11+
12+
13+
@calcfunction
14+
def get_fermi_energy(output_parameters):
15+
"""Extract the Fermi energy from the ``output_parameters`` of a ``PwBaseWorkChain``."""
16+
return Float(output_parameters['fermi_energy'])
17+
18+
19+
class QuantumEspressoCommonBandsWorkChain(CommonBandsWorkChain):
20+
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""
21+
22+
_process_class = WorkflowFactory('quantumespresso.pw.base')
23+
_generator_class = QuantumEspressoCommonBandsInputGenerator
24+
25+
def convert_outputs(self):
26+
"""Convert the outputs of the sub work chain to the common output specification."""
27+
outputs = self.ctx.workchain.outputs
28+
29+
if 'output_band' not in outputs:
30+
self.report('The `bands` PwBaseWorkChain does not have the `output_band` output.')
31+
return self.exit_codes.ERROR_SUB_PROCESS_FAILED
32+
33+
self.out('bands', outputs.output_band)
34+
self.out('fermi_energy', get_fermi_energy(outputs.output_parameters))

aiida_common_workflows/workflows/relax/quantum_espresso/generator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ def define(cls, spec):
9191
)
9292
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
9393
spec.inputs['engines']['relax']['code'].valid_type = CodeType('quantumespresso.pw')
94+
spec.input(
95+
'clean_workdir',
96+
valid_type=orm.Bool,
97+
default=lambda: orm.Bool(False),
98+
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.'
99+
)
94100

95101
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
96102
"""Construct a process builder based on the provided keyword arguments.
@@ -111,6 +117,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
111117
threshold_forces = kwargs.get('threshold_forces', None)
112118
threshold_stress = kwargs.get('threshold_stress', None)
113119
reference_workchain = kwargs.get('reference_workchain', None)
120+
clean_workdir = kwargs.get('clean_workdir')
114121

115122
if isinstance(electronic_type, str):
116123
electronic_type = types.ElectronicType(electronic_type)
@@ -162,6 +169,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
162169
spin_type=spin_type,
163170
initial_magnetic_moments=initial_magnetic_moments,
164171
)
172+
builder.clean_workdir = clean_workdir
165173

166174
if threshold_forces is not None:
167175
threshold = threshold_forces * CONSTANTS.bohr_to_ang / CONSTANTS.ry_to_ev

aiida_common_workflows/workflows/relax/quantum_espresso/workchain.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@ def convert_outputs(self):
6363
self.out('total_energy', total_energy)
6464
self.out('forces', forces)
6565
self.out('stress', stress)
66+
self.out('remote_folder', outputs.remote_folder)

setup.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@
7272
"common_workflows.relax.quantum_espresso = aiida_common_workflows.workflows.relax.quantum_espresso.workchain:QuantumEspressoCommonRelaxWorkChain",
7373
"common_workflows.relax.siesta = aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain",
7474
"common_workflows.relax.vasp = aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain",
75-
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain"
75+
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain",
76+
"common_workflows.bands.quantum_espresso = aiida_common_workflows.workflows.bands.quantum_espresso.workchain:QuantumEspressoCommonBandsWorkChain"
7677
]
7778
},
7879
"license": "MIT License",

0 commit comments

Comments
 (0)