Skip to content

Commit

Permalink
first implementation wien2k
Browse files Browse the repository at this point in the history
  • Loading branch information
rubel75 authored and sphuber committed Jul 14, 2023
1 parent 7d7cc9f commit cc4d902
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 0 deletions.
7 changes: 7 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable=undefined-variable
"""Module with the implementations of the common structure relaxation workchain for Siesta."""
from .generator import *
from .workchain import *

__all__ = (generator.__all__ + workchain.__all__)
127 changes: 127 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for Wien2k."""
import os
from typing import Any, Dict, List, Tuple, Union

import yaml

from aiida import engine
from aiida import orm
from aiida import plugins

from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, CodeType
from ..generator import CommonRelaxInputGenerator

__all__ = ('Wien2kCommonRelaxInputGenerator',)

StructureData = plugins.DataFactory('structure')


class Wien2kCommonRelaxInputGenerator(CommonRelaxInputGenerator):
"""Generator of inputs for the Wien2kCommonRelaxWorkChain"""

_default_protocol = 'moderate'

def __init__(self, *args, **kwargs):
"""Construct an instance of the input generator, validating the class attributes."""

self._initialize_protocols()

super().__init__(*args, **kwargs)

def raise_invalid(message):
raise RuntimeError('invalid protocol registry `{}`: '.format(self.__class__.__name__) + message)

def _initialize_protocols(self):
"""Initialize the protocols class attribute by parsing them from the configuration file."""
_filepath = os.path.join(os.path.dirname(__file__), 'protocol.yml')

with open(_filepath) as _thefile:
self._protocols = yaml.full_load(_thefile)

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""
super().define(spec)
spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE,))
spec.inputs['relax_type'].valid_type = ChoiceType((RelaxType.NONE,))
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
spec.inputs['engines']['relax']['code'].valid_type = CodeType('wien2k-run123_lapw')

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.
The keyword arguments will have been validated against the input generator specification.
"""

structure = kwargs['structure']
engines = kwargs['engines']
protocol = kwargs['protocol']
spin_type = kwargs['spin_type']
relax_type = kwargs['relax_type']
magnetization_per_site = kwargs.get('magnetization_per_site', None)
threshold_forces = kwargs.get('threshold_forces', None)
threshold_stress = kwargs.get('threshold_stress', None)
reference_workchain = kwargs.get('reference_workchain', None)
electronic_type = kwargs['electronic_type']

# Checks
if protocol not in self.get_protocol_names():
import warnings
warnings.warn('no protocol implemented with name {}, using default moderate'.format(protocol))
protocol = self.get_default_protocol_name()
if not all(x in engines.keys() for x in ['relax']):
raise ValueError('The `engines` dictionary must contain "relax" as outermost key')

# construct input for run123_lapw
inpdict = orm.Dict(dict={
'-red': self._protocols[protocol]['parameters']['red'],
'-i': self._protocols[protocol]['parameters']['max-scf-iterations'],
'-ec': self._protocols[protocol]['parameters']['scf-ene-tol-Ry'],
'-cc': self._protocols[protocol]['parameters']['scf-charge-tol'],
'-fermits': self._protocols[protocol]['parameters']['fermi-temp-Ry'],
'-nokshift': self._protocols[protocol]['parameters']['nokshift'],
'-noprec': self._protocols[protocol]['parameters']['noprec'],
'-numk': self._protocols[protocol]['parameters']['numk'],
'-numk2': self._protocols[protocol]['parameters']['numk2'],
'-p': self._protocols[protocol]['parameters']['parallel'],
}) # run123_lapw [param]
if electronic_type == ElectronicType.INSULATOR:
inpdict['-nometal'] = True
if reference_workchain: # ref. workchain is passed as input
# derive Rmt's from the ref. workchain and pass as input
w2k_wchain = reference_workchain.get_outgoing(node_class=orm.WorkChainNode).one().node
ref_wrkchn_res_dict = w2k_wchain.outputs.workchain_result.get_dict()
rmt = ref_wrkchn_res_dict['Rmt']
atm_lbl = ref_wrkchn_res_dict['atom_labels']
if len(rmt) != len(atm_lbl):
raise # the list with Rmt radii should match the list of elements
red_string = ''
for i in range(len(rmt)):
red_string += atm_lbl[i] + ':' + str(rmt[i])
if i < len(rmt)-1: # for all, but the last element of the list
red_string += ',' # append comma
inpdict['-red'] = red_string # pass Rmt's as input to subsequent wrk. chains
# derive k mesh from the ref. workchain and pass as input
if 'kmesh3' in ref_wrkchn_res_dict: # check if kmesh3 is in results dict
if ref_wrkchn_res_dict['kmesh3']: # check if the k mesh is not empty
inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3']
if 'kmesh3k' in ref_wrkchn_res_dict: # check if kmesh3k is in results dict
if ref_wrkchn_res_dict['kmesh3k']: # check if the k mesh is not empty
inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k']
if 'fftmesh3k' in ref_wrkchn_res_dict: # check if fftmesh3k is in results dict
if ref_wrkchn_res_dict['fftmesh3k']: # check if the FFT mesh is not empty
inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k']

#res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1) # set resources
builder = self.process_class.get_builder()
builder.aiida_structure = structure
builder.code = engines['relax']['code'] # load wien2k-run123_lapw code
builder.options = orm.Dict(dict=engines['relax']['options'])
builder.inpdict = inpdict

return builder
13 changes: 13 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/protocol.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
moderate:
description: 'A standard list of inputs for Wienk.'
parameters:
red: '3'
max-scf-iterations: '100'
scf-ene-tol-Ry: '0.000001'
scf-charge-tol: '0.000001'
fermi-temp-Ry: '0.0045'
nokshift: True
noprec: '2'
numk: '-1 0.0317506'
numk2: '-1 0.023812976204734406'
parallel: False
29 changes: 29 additions & 0 deletions aiida_common_workflows/workflows/relax/wien2k/workchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for Wien2k."""
from aiida import orm
from aiida.engine import calcfunction
from aiida.plugins import WorkflowFactory

from ..workchain import CommonRelaxWorkChain
from .generator import Wien2kCommonRelaxInputGenerator

__all__ = ('Wien2kCommonRelaxWorkChain',)


@calcfunction
def get_energy(pardict):
"""Extract the energy from the `workchain_result` dictionary (Ry -> eV)"""
return orm.Float(pardict['EtotRyd']*13.605693122994)


class Wien2kCommonRelaxWorkChain(CommonRelaxWorkChain):
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for WIEN2k."""

_process_class = WorkflowFactory('wien2k.scf123_wf')
_generator_class = Wien2kCommonRelaxInputGenerator

def convert_outputs(self):
"""Convert the outputs of the sub workchain to the common output specification."""
self.report('Relaxation task concluded sucessfully, converting outputs')
self.out('total_energy', get_energy(self.ctx.workchain.outputs.workchain_result))
self.out('relaxed_structure', self.ctx.workchain.outputs.aiida_structure_out)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ acwf = 'aiida_common_workflows.cli:cmd_root'
'common_workflows.relax.siesta' = 'aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain'
'common_workflows.relax.vasp' = 'aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain'
'common_workflows.relax.gpaw' = 'aiida_common_workflows.workflows.relax.gpaw.workchain:GpawCommonRelaxWorkChain'
'common_workflows.relax.wien2k' = 'aiida_common_workflows.workflows.relax.wien2k.workchain:Wien2kCommonRelaxWorkChain'
'common_workflows.bands.siesta' = 'aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain'

[tool.flit.module]
Expand Down

0 comments on commit cc4d902

Please sign in to comment.