Skip to content

Commit 3e4014d

Browse files
sphuberrubel75mbercxbosonie
authored
Add the Wien2kCommonRelaxWorkChain
First implementation of the common relax work chain for WIEN2k. This also includes the `moderate` protocol, used for the oxides verification project. Co-authored-by: Oleg Rubel <[email protected]> Co-authored-by: Marnik Bercx <[email protected]> Co-authored-by: Emanuele Bosoni <[email protected]>
1 parent ec219f6 commit 3e4014d

File tree

5 files changed

+162
-0
lines changed

5 files changed

+162
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=undefined-variable
3+
"""Module with the implementations of the common structure relaxation workchain for Siesta."""
4+
from .generator import *
5+
from .workchain import *
6+
7+
__all__ = (generator.__all__ + workchain.__all__)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# -*- coding: utf-8 -*-
2+
"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for Wien2k."""
3+
import os
4+
5+
from aiida import engine, orm, plugins
6+
import yaml
7+
8+
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
9+
from aiida_common_workflows.generators import ChoiceType, CodeType
10+
11+
from ..generator import CommonRelaxInputGenerator
12+
13+
__all__ = ('Wien2kCommonRelaxInputGenerator',)
14+
15+
StructureData = plugins.DataFactory('structure')
16+
17+
18+
class Wien2kCommonRelaxInputGenerator(CommonRelaxInputGenerator):
19+
"""Generator of inputs for the Wien2kCommonRelaxWorkChain"""
20+
21+
_default_protocol = 'moderate'
22+
23+
def __init__(self, *args, **kwargs):
24+
"""Construct an instance of the input generator, validating the class attributes."""
25+
26+
self._initialize_protocols()
27+
28+
super().__init__(*args, **kwargs)
29+
30+
def _initialize_protocols(self):
31+
"""Initialize the protocols class attribute by parsing them from the configuration file."""
32+
_filepath = os.path.join(os.path.dirname(__file__), 'protocol.yml')
33+
34+
with open(_filepath, encoding='utf-8') as _thefile:
35+
self._protocols = yaml.full_load(_thefile)
36+
37+
@classmethod
38+
def define(cls, spec):
39+
"""Define the specification of the input generator.
40+
41+
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
42+
"""
43+
super().define(spec)
44+
spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE,))
45+
spec.inputs['relax_type'].valid_type = ChoiceType((RelaxType.NONE,))
46+
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
47+
spec.inputs['engines']['relax']['code'].valid_type = CodeType('wien2k-run123_lapw')
48+
49+
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
50+
"""Construct a process builder based on the provided keyword arguments.
51+
52+
The keyword arguments will have been validated against the input generator specification.
53+
"""
54+
structure = kwargs['structure']
55+
engines = kwargs['engines']
56+
protocol = kwargs['protocol']
57+
reference_workchain = kwargs.get('reference_workchain', None)
58+
electronic_type = kwargs['electronic_type']
59+
60+
# Checks
61+
if protocol not in self.get_protocol_names():
62+
import warnings
63+
warnings.warn(f'no protocol implemented with name {protocol}, using default moderate')
64+
protocol = self.get_default_protocol_name()
65+
if not all(x in engines.keys() for x in ['relax']):
66+
raise ValueError('The `engines` dictionary must contain "relax" as outermost key')
67+
68+
# construct input for run123_lapw
69+
inpdict = orm.Dict(
70+
dict={
71+
'-red': self._protocols[protocol]['parameters']['red'],
72+
'-i': self._protocols[protocol]['parameters']['max-scf-iterations'],
73+
'-ec': self._protocols[protocol]['parameters']['scf-ene-tol-Ry'],
74+
'-cc': self._protocols[protocol]['parameters']['scf-charge-tol'],
75+
'-fermits': self._protocols[protocol]['parameters']['fermi-temp-Ry'],
76+
'-nokshift': self._protocols[protocol]['parameters']['nokshift'],
77+
'-noprec': self._protocols[protocol]['parameters']['noprec'],
78+
'-numk': self._protocols[protocol]['parameters']['numk'],
79+
'-numk2': self._protocols[protocol]['parameters']['numk2'],
80+
'-p': self._protocols[protocol]['parameters']['parallel'],
81+
}
82+
)
83+
if electronic_type == ElectronicType.INSULATOR:
84+
inpdict['-nometal'] = True
85+
if reference_workchain: # ref. workchain is passed as input
86+
# derive Rmt's from the ref. workchain and pass as input
87+
w2k_wchain = reference_workchain.get_outgoing(node_class=orm.WorkChainNode).one().node
88+
ref_wrkchn_res_dict = w2k_wchain.outputs.workchain_result.get_dict()
89+
rmt = ref_wrkchn_res_dict['Rmt']
90+
atm_lbl = ref_wrkchn_res_dict['atom_labels']
91+
if len(rmt) != len(atm_lbl):
92+
raise ValueError(f'The list of rmt radii does not match the list of elements: {rmt} and {atm_lbl}')
93+
inpdict['-red'] = ','.join([f'{a}:{r}' for a, r in zip(atm_lbl, rmt)])
94+
# derive k mesh from the ref. workchain and pass as input
95+
if 'kmesh3' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3']: # check if kmesh3 is in results dict
96+
inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3']
97+
if 'kmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3k'
98+
]: # check if kmesh3k is in results dict
99+
inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k']
100+
if 'fftmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['fftmesh3k'
101+
]: # check if fftmesh3k is in results dict
102+
inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k']
103+
104+
# res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1)
105+
builder = self.process_class.get_builder()
106+
builder.aiida_structure = structure
107+
builder.code = engines['relax']['code'] # load wien2k-run123_lapw code
108+
builder.options = orm.Dict(dict=engines['relax']['options'])
109+
builder.inpdict = inpdict
110+
111+
return builder
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
moderate:
2+
description: 'A standard list of inputs for Wienk.'
3+
parameters:
4+
red: '3'
5+
max-scf-iterations: '100'
6+
scf-ene-tol-Ry: '0.000001'
7+
scf-charge-tol: '0.000001'
8+
fermi-temp-Ry: '0.0045'
9+
nokshift: True
10+
noprec: '2'
11+
numk: '-1 0.0317506'
12+
numk2: '-1 0.023812976204734406'
13+
parallel: False
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# -*- coding: utf-8 -*-
2+
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for Wien2k."""
3+
from aiida import orm
4+
from aiida.engine import calcfunction
5+
from aiida.plugins import WorkflowFactory
6+
7+
from ..workchain import CommonRelaxWorkChain
8+
from .generator import Wien2kCommonRelaxInputGenerator
9+
10+
__all__ = ('Wien2kCommonRelaxWorkChain',)
11+
12+
13+
@calcfunction
14+
def get_energy(pardict):
15+
"""Extract the energy from the `workchain_result` dictionary (Ry -> eV)"""
16+
return orm.Float(pardict['EtotRyd'] * 13.605693122994)
17+
18+
19+
class Wien2kCommonRelaxWorkChain(CommonRelaxWorkChain):
20+
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for WIEN2k."""
21+
22+
_process_class = WorkflowFactory('wien2k.scf123_wf')
23+
_generator_class = Wien2kCommonRelaxInputGenerator
24+
25+
def convert_outputs(self):
26+
"""Convert the outputs of the sub workchain to the common output specification."""
27+
self.report('Relaxation task concluded sucessfully, converting outputs')
28+
self.out('total_energy', get_energy(self.ctx.workchain.outputs.workchain_result))
29+
self.out('relaxed_structure', self.ctx.workchain.outputs.aiida_structure_out)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
'aiida-quantumespresso~=3.4,>=3.4.1',
3636
'aiida-siesta>=1.2.0',
3737
'aiida-vasp~=2.2',
38+
'aiida-wien2k~=0.1.1',
3839
'aiida-ase',
3940
'pymatgen>=2022.1.20',
4041
'numpy<1.24.0',
@@ -88,6 +89,7 @@ acwf = 'aiida_common_workflows.cli:cmd_root'
8889
'common_workflows.relax.siesta' = 'aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain'
8990
'common_workflows.relax.vasp' = 'aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain'
9091
'common_workflows.relax.gpaw' = 'aiida_common_workflows.workflows.relax.gpaw.workchain:GpawCommonRelaxWorkChain'
92+
'common_workflows.relax.wien2k' = 'aiida_common_workflows.workflows.relax.wien2k.workchain:Wien2kCommonRelaxWorkChain'
9193
'common_workflows.bands.siesta' = 'aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain'
9294

9395
[tool.flit.module]

0 commit comments

Comments
 (0)