Skip to content

Commit 39ed517

Browse files
committed
Implement common bands work chain for Quantum ESPRESSO
1 parent 69468a7 commit 39ed517

File tree

6 files changed

+114
-1
lines changed

6 files changed

+114
-1
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=undefined-variable
3+
"""Module with the implementations of the common bands workchain for Quantum ESPRESSO."""
4+
from .generator import *
5+
from .workchain import *
6+
7+
__all__ = (generator.__all__ + workchain.__all__)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# -*- coding: utf-8 -*-
2+
"""Implementation of the ``CommonBandsInputGenerator`` for Quantum ESPRESSO."""
3+
4+
from aiida import engine
5+
from aiida import orm
6+
from aiida.common import LinkType
7+
from aiida_common_workflows.generators import CodeType
8+
from ..generator import CommonBandsInputGenerator
9+
10+
__all__ = ('QuantumEspressoCommonBandsInputGenerator',)
11+
12+
13+
class QuantumEspressoCommonBandsInputGenerator(CommonBandsInputGenerator):
14+
"""Input generator for the ``QuantumEspressoCommonBandsWorkChain``"""
15+
16+
@classmethod
17+
def define(cls, spec):
18+
"""Define the specification of the input generator.
19+
20+
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
21+
"""
22+
super().define(spec)
23+
spec.inputs['engines']['bands']['code'].valid_type = CodeType('quantumespresso.pw')
24+
25+
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
26+
"""Construct a process builder based on the provided keyword arguments.
27+
28+
The keyword arguments will have been validated against the input generator specification.
29+
"""
30+
# pylint: disable=too-many-branches,too-many-statements,too-many-locals
31+
engines = kwargs.get('engines', None)
32+
parent_folder = kwargs['parent_folder']
33+
bands_kpoints = kwargs['bands_kpoints']
34+
35+
# Find the `PwCalculation` that created the `parent_folder` and obtain the restart builder.
36+
parent_calc = parent_folder.get_incoming(link_type=LinkType.CREATE).one().node
37+
if parent_calc.process_type != 'aiida.calculations:quantumespresso.pw':
38+
raise ValueError('The `parent_folder` has not been created by a `PwCalculation`.')
39+
builder_calc = parent_calc.get_builder_restart()
40+
41+
builder_common_bands_wc = self.process_class.get_builder()
42+
builder_calc.pop('kpoints')
43+
builder_common_bands_wc.pw = builder_calc
44+
parameters = builder_common_bands_wc.pw.parameters.get_dict()
45+
parameters['CONTROL']['calculation'] = 'bands'
46+
builder_common_bands_wc.pw.parameters = orm.Dict(dict=parameters)
47+
builder_common_bands_wc.kpoints = bands_kpoints
48+
builder_common_bands_wc.pw.parent_folder = parent_folder
49+
50+
# Update the structure in case we have one in output, i.e. the `parent_calc` optimized the structure
51+
if 'output_structure' in parent_calc.outputs:
52+
builder_common_bands_wc.pw.structure = parent_calc.outputs.output_structure
53+
54+
# Update the code and computational options if `engines` is specified
55+
try:
56+
bands_engine = engines['bands']
57+
except KeyError:
58+
raise ValueError('The `engines` dictionary must contain `bands` as a top-level key')
59+
if 'code' in bands_engine:
60+
code = engines['bands']['code']
61+
if isinstance(code, str):
62+
code = orm.load_code(code)
63+
builder_common_bands_wc.pw.code = code
64+
if 'options' in bands_engine:
65+
builder_common_bands_wc.pw.metadata.options = engines['bands']['options']
66+
67+
return builder_common_bands_wc
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -*- coding: utf-8 -*-
2+
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""
3+
from aiida.engine import calcfunction
4+
from aiida.plugins import WorkflowFactory
5+
from aiida.orm import Float
6+
7+
from ..workchain import CommonBandsWorkChain
8+
from .generator import QuantumEspressoCommonBandsInputGenerator
9+
10+
__all__ = ('QuantumEspressoCommonBandsWorkChain',)
11+
12+
13+
@calcfunction
14+
def get_fermi_energy(output_parameters):
15+
"""Extract the Fermi energy from the ``output_parameters`` of a ``PwBaseWorkChain``."""
16+
return Float(output_parameters['fermi_energy'])
17+
18+
class QuantumEspressoCommonBandsWorkChain(CommonBandsWorkChain):
19+
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""
20+
21+
_process_class = WorkflowFactory('quantumespresso.pw.base')
22+
_generator_class = QuantumEspressoCommonBandsInputGenerator
23+
24+
def convert_outputs(self):
25+
"""Convert the outputs of the sub work chain to the common output specification."""
26+
outputs = self.ctx.workchain.outputs
27+
28+
if 'output_band' not in outputs:
29+
self.report('The `bands` PwBaseWorkChain does not have the `output_band` output.')
30+
return self.exit_codes.ERROR_SUB_PROCESS_FAILED
31+
32+
self.out('bands', outputs.output_band)
33+
self.out('fermi_energy', get_fermi_energy(outputs.output_parameters))

aiida_common_workflows/workflows/relax/quantum_espresso/generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def define(cls, spec):
9191
)
9292
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
9393
spec.inputs['engines']['relax']['code'].valid_type = CodeType('quantumespresso.pw')
94+
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
95+
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
9496

9597
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
9698
"""Construct a process builder based on the provided keyword arguments.
@@ -111,6 +113,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
111113
threshold_forces = kwargs.get('threshold_forces', None)
112114
threshold_stress = kwargs.get('threshold_stress', None)
113115
reference_workchain = kwargs.get('reference_workchain', None)
116+
clean_workdir = kwargs.get('clean_workdir')
114117

115118
if isinstance(electronic_type, str):
116119
electronic_type = types.ElectronicType(electronic_type)
@@ -162,6 +165,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
162165
spin_type=spin_type,
163166
initial_magnetic_moments=initial_magnetic_moments,
164167
)
168+
builder.clean_workdir = clean_workdir
165169

166170
if threshold_forces is not None:
167171
threshold = threshold_forces * CONSTANTS.bohr_to_ang / CONSTANTS.ry_to_ev

aiida_common_workflows/workflows/relax/quantum_espresso/workchain.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@ def convert_outputs(self):
6363
self.out('total_energy', total_energy)
6464
self.out('forces', forces)
6565
self.out('stress', stress)
66+
self.out('remote_folder', outputs.remote_folder)

setup.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@
7272
"common_workflows.relax.quantum_espresso = aiida_common_workflows.workflows.relax.quantum_espresso.workchain:QuantumEspressoCommonRelaxWorkChain",
7373
"common_workflows.relax.siesta = aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain",
7474
"common_workflows.relax.vasp = aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain",
75-
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain"
75+
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain",
76+
"common_workflows.bands.quantum_espresso = aiida_common_workflows.workflows.bands.quantum_espresso.workchain:QuantumEspressoCommonBandsWorkChain"
7677
]
7778
},
7879
"license": "MIT License",

0 commit comments

Comments
 (0)