Skip to content

Commit

Permalink
Run pre-commit
Browse files Browse the repository at this point in the history
Also refactored the formatting of the `-red` option string to a
one-liner
  • Loading branch information
sphuber committed Jul 14, 2023
1 parent cc4d902 commit ad6e0f3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 39 deletions.
60 changes: 22 additions & 38 deletions aiida_common_workflows/workflows/relax/wien2k/generator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# -*- coding: utf-8 -*-
"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for Wien2k."""
import os
from typing import Any, Dict, List, Tuple, Union

from aiida import engine, orm, plugins
import yaml

from aiida import engine
from aiida import orm
from aiida import plugins

from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, CodeType

from ..generator import CommonRelaxInputGenerator

__all__ = ('Wien2kCommonRelaxInputGenerator',)
Expand All @@ -30,14 +27,11 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

def raise_invalid(message):
raise RuntimeError('invalid protocol registry `{}`: '.format(self.__class__.__name__) + message)

def _initialize_protocols(self):
"""Initialize the protocols class attribute by parsing them from the configuration file."""
_filepath = os.path.join(os.path.dirname(__file__), 'protocol.yml')

with open(_filepath) as _thefile:
with open(_filepath, encoding='utf-8') as _thefile:
self._protocols = yaml.full_load(_thefile)

@classmethod
Expand All @@ -57,28 +51,23 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
The keyword arguments will have been validated against the input generator specification.
"""

structure = kwargs['structure']
engines = kwargs['engines']
protocol = kwargs['protocol']
spin_type = kwargs['spin_type']
relax_type = kwargs['relax_type']
magnetization_per_site = kwargs.get('magnetization_per_site', None)
threshold_forces = kwargs.get('threshold_forces', None)
threshold_stress = kwargs.get('threshold_stress', None)
reference_workchain = kwargs.get('reference_workchain', None)
electronic_type = kwargs['electronic_type']

# Checks
if protocol not in self.get_protocol_names():
import warnings
warnings.warn('no protocol implemented with name {}, using default moderate'.format(protocol))
warnings.warn(f'no protocol implemented with name {protocol}, using default moderate')
protocol = self.get_default_protocol_name()
if not all(x in engines.keys() for x in ['relax']):
raise ValueError('The `engines` dictionary must contain "relax" as outermost key')

# construct input for run123_lapw
inpdict = orm.Dict(dict={
inpdict = orm.Dict(
dict={
'-red': self._protocols[protocol]['parameters']['red'],
'-i': self._protocols[protocol]['parameters']['max-scf-iterations'],
'-ec': self._protocols[protocol]['parameters']['scf-ene-tol-Ry'],
Expand All @@ -89,38 +78,33 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
'-numk': self._protocols[protocol]['parameters']['numk'],
'-numk2': self._protocols[protocol]['parameters']['numk2'],
'-p': self._protocols[protocol]['parameters']['parallel'],
}) # run123_lapw [param]
}
)
if electronic_type == ElectronicType.INSULATOR:
inpdict['-nometal'] = True
if reference_workchain: # ref. workchain is passed as input
if reference_workchain: # ref. workchain is passed as input
# derive Rmt's from the ref. workchain and pass as input
w2k_wchain = reference_workchain.get_outgoing(node_class=orm.WorkChainNode).one().node
ref_wrkchn_res_dict = w2k_wchain.outputs.workchain_result.get_dict()
rmt = ref_wrkchn_res_dict['Rmt']
atm_lbl = ref_wrkchn_res_dict['atom_labels']
if len(rmt) != len(atm_lbl):
raise # the list with Rmt radii should match the list of elements
red_string = ''
for i in range(len(rmt)):
red_string += atm_lbl[i] + ':' + str(rmt[i])
if i < len(rmt)-1: # for all, but the last element of the list
red_string += ',' # append comma
inpdict['-red'] = red_string # pass Rmt's as input to subsequent wrk. chains
raise ValueError(f'The list of rmt radii does not match the list of elements: {rmt} and {atm_lbl}')
inpdict['-red'] = ','.join([f'{a}:{r}' for a, r in zip(atm_lbl, rmt)])
# derive k mesh from the ref. workchain and pass as input
if 'kmesh3' in ref_wrkchn_res_dict: # check if kmesh3 is in results dict
if ref_wrkchn_res_dict['kmesh3']: # check if the k mesh is not empty
inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3']
if 'kmesh3k' in ref_wrkchn_res_dict: # check if kmesh3k is in results dict
if ref_wrkchn_res_dict['kmesh3k']: # check if the k mesh is not empty
inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k']
if 'fftmesh3k' in ref_wrkchn_res_dict: # check if fftmesh3k is in results dict
if ref_wrkchn_res_dict['fftmesh3k']: # check if the FFT mesh is not empty
inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k']

#res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1) # set resources
if 'kmesh3' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3']: # check if kmesh3 is in results dict
inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3']
if 'kmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3k'
]: # check if kmesh3k is in results dict
inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k']
if 'fftmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['fftmesh3k'
]: # check if fftmesh3k is in results dict
inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k']

# res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1)
builder = self.process_class.get_builder()
builder.aiida_structure = structure
builder.code = engines['relax']['code'] # load wien2k-run123_lapw code
builder.code = engines['relax']['code'] # load wien2k-run123_lapw code
builder.options = orm.Dict(dict=engines['relax']['options'])
builder.inpdict = inpdict

Expand Down
2 changes: 1 addition & 1 deletion aiida_common_workflows/workflows/relax/wien2k/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@calcfunction
def get_energy(pardict):
"""Extract the energy from the `workchain_result` dictionary (Ry -> eV)"""
return orm.Float(pardict['EtotRyd']*13.605693122994)
return orm.Float(pardict['EtotRyd'] * 13.605693122994)


class Wien2kCommonRelaxWorkChain(CommonRelaxWorkChain):
Expand Down

0 comments on commit ad6e0f3

Please sign in to comment.