Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wien2k implementation #314

Merged
merged 8 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable=undefined-variable
"""Module with the implementations of the common structure relaxation workchain for Siesta."""
from .generator import *
from .workchain import *

__all__ = (generator.__all__ + workchain.__all__)
111 changes: 111 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for Wien2k."""
import os

from aiida import engine, orm, plugins
import yaml

from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, CodeType

from ..generator import CommonRelaxInputGenerator

__all__ = ('Wien2kCommonRelaxInputGenerator',)

StructureData = plugins.DataFactory('structure')


class Wien2kCommonRelaxInputGenerator(CommonRelaxInputGenerator):
"""Generator of inputs for the Wien2kCommonRelaxWorkChain"""

_default_protocol = 'moderate'

def __init__(self, *args, **kwargs):
"""Construct an instance of the input generator, validating the class attributes."""

self._initialize_protocols()

super().__init__(*args, **kwargs)

def _initialize_protocols(self):
"""Initialize the protocols class attribute by parsing them from the configuration file."""
_filepath = os.path.join(os.path.dirname(__file__), 'protocol.yml')

with open(_filepath, encoding='utf-8') as _thefile:
self._protocols = yaml.full_load(_thefile)

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.

The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""
super().define(spec)
spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE,))
spec.inputs['relax_type'].valid_type = ChoiceType((RelaxType.NONE,))
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
spec.inputs['engines']['relax']['code'].valid_type = CodeType('wien2k-run123_lapw')

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.

The keyword arguments will have been validated against the input generator specification.
"""
structure = kwargs['structure']
engines = kwargs['engines']
protocol = kwargs['protocol']
reference_workchain = kwargs.get('reference_workchain', None)
electronic_type = kwargs['electronic_type']

# Checks
if protocol not in self.get_protocol_names():
import warnings
warnings.warn(f'no protocol implemented with name {protocol}, using default moderate')
protocol = self.get_default_protocol_name()
if not all(x in engines.keys() for x in ['relax']):
raise ValueError('The `engines` dictionary must contain "relax" as outermost key')

# construct input for run123_lapw
inpdict = orm.Dict(
dict={
'-red': self._protocols[protocol]['parameters']['red'],
'-i': self._protocols[protocol]['parameters']['max-scf-iterations'],
'-ec': self._protocols[protocol]['parameters']['scf-ene-tol-Ry'],
'-cc': self._protocols[protocol]['parameters']['scf-charge-tol'],
'-fermits': self._protocols[protocol]['parameters']['fermi-temp-Ry'],
'-nokshift': self._protocols[protocol]['parameters']['nokshift'],
'-noprec': self._protocols[protocol]['parameters']['noprec'],
'-numk': self._protocols[protocol]['parameters']['numk'],
'-numk2': self._protocols[protocol]['parameters']['numk2'],
'-p': self._protocols[protocol]['parameters']['parallel'],
}
)
if electronic_type == ElectronicType.INSULATOR:
inpdict['-nometal'] = True
if reference_workchain: # ref. workchain is passed as input
# derive Rmt's from the ref. workchain and pass as input
w2k_wchain = reference_workchain.get_outgoing(node_class=orm.WorkChainNode).one().node
ref_wrkchn_res_dict = w2k_wchain.outputs.workchain_result.get_dict()
rmt = ref_wrkchn_res_dict['Rmt']
atm_lbl = ref_wrkchn_res_dict['atom_labels']
if len(rmt) != len(atm_lbl):
raise ValueError(f'The list of rmt radii does not match the list of elements: {rmt} and {atm_lbl}')
inpdict['-red'] = ','.join([f'{a}:{r}' for a, r in zip(atm_lbl, rmt)])
# derive k mesh from the ref. workchain and pass as input
if 'kmesh3' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3']: # check if kmesh3 is in results dict
inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3']
if 'kmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3k'
]: # check if kmesh3k is in results dict
inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k']
if 'fftmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['fftmesh3k'
]: # check if fftmesh3k is in results dict
inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k']

# res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1)
builder = self.process_class.get_builder()
builder.aiida_structure = structure
builder.code = engines['relax']['code'] # load wien2k-run123_lapw code
builder.options = orm.Dict(dict=engines['relax']['options'])
builder.inpdict = inpdict

return builder
13 changes: 13 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/protocol.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
moderate:
description: 'A standard list of inputs for Wienk.'
parameters:
red: '3'
max-scf-iterations: '100'
scf-ene-tol-Ry: '0.000001'
scf-charge-tol: '0.000001'
fermi-temp-Ry: '0.0045'
nokshift: True
noprec: '2'
numk: '-1 0.0317506'
numk2: '-1 0.023812976204734406'
parallel: False
29 changes: 29 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/workchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for Wien2k."""
from aiida import orm
from aiida.engine import calcfunction
from aiida.plugins import WorkflowFactory

from ..workchain import CommonRelaxWorkChain
from .generator import Wien2kCommonRelaxInputGenerator

__all__ = ('Wien2kCommonRelaxWorkChain',)


@calcfunction
def get_energy(pardict):
"""Extract the energy from the `workchain_result` dictionary (Ry -> eV)"""
return orm.Float(pardict['EtotRyd'] * 13.605693122994)


class Wien2kCommonRelaxWorkChain(CommonRelaxWorkChain):
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for WIEN2k."""

_process_class = WorkflowFactory('wien2k.scf123_wf')
_generator_class = Wien2kCommonRelaxInputGenerator

def convert_outputs(self):
"""Convert the outputs of the sub workchain to the common output specification."""
self.report('Relaxation task concluded sucessfully, converting outputs')
self.out('total_energy', get_energy(self.ctx.workchain.outputs.workchain_result))
self.out('relaxed_structure', self.ctx.workchain.outputs.aiida_structure_out)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
'aiida-quantumespresso~=3.4,>=3.4.1',
'aiida-siesta>=1.2.0',
'aiida-vasp~=2.2',
'aiida-wien2k~=0.1.1',
'aiida-ase',
'pymatgen>=2022.1.20',
'numpy<1.24.0',
Expand Down Expand Up @@ -88,6 +89,7 @@ acwf = 'aiida_common_workflows.cli:cmd_root'
'common_workflows.relax.siesta' = 'aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain'
'common_workflows.relax.vasp' = 'aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain'
'common_workflows.relax.gpaw' = 'aiida_common_workflows.workflows.relax.gpaw.workchain:GpawCommonRelaxWorkChain'
'common_workflows.relax.wien2k' = 'aiida_common_workflows.workflows.relax.wien2k.workchain:Wien2kCommonRelaxWorkChain'
'common_workflows.bands.siesta' = 'aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain'

[tool.flit.module]
Expand Down
Loading