From ad6e0f3ec1acd8cdf6ea00c1844827cf509191bd Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 14 Jul 2023 09:33:34 +0200 Subject: [PATCH] Run pre-commit Also refactored the formatting of the `-red` option string to a one-liner --- .../workflows/relax/wien2k/generator.py | 60 +++++++------------ .../workflows/relax/wien2k/workchain.py | 2 +- 2 files changed, 23 insertions(+), 39 deletions(-) diff --git a/aiida_common_workflows/workflows/relax/wien2k/generator.py b/aiida_common_workflows/workflows/relax/wien2k/generator.py index 5fe6ac8e..cab405da 100644 --- a/aiida_common_workflows/workflows/relax/wien2k/generator.py +++ b/aiida_common_workflows/workflows/relax/wien2k/generator.py @@ -1,16 +1,13 @@ # -*- coding: utf-8 -*- """Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for Wien2k.""" import os -from typing import Any, Dict, List, Tuple, Union +from aiida import engine, orm, plugins import yaml -from aiida import engine -from aiida import orm -from aiida import plugins - from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType from aiida_common_workflows.generators import ChoiceType, CodeType + from ..generator import CommonRelaxInputGenerator __all__ = ('Wien2kCommonRelaxInputGenerator',) @@ -30,14 +27,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def raise_invalid(message): - raise RuntimeError('invalid protocol registry `{}`: '.format(self.__class__.__name__) + message) - def _initialize_protocols(self): """Initialize the protocols class attribute by parsing them from the configuration file.""" _filepath = os.path.join(os.path.dirname(__file__), 'protocol.yml') - with open(_filepath) as _thefile: + with open(_filepath, encoding='utf-8') as _thefile: self._protocols = yaml.full_load(_thefile) @classmethod @@ -57,28 +51,23 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: The keyword arguments will have been validated against the input generator specification. """ - structure = kwargs['structure'] engines = kwargs['engines'] protocol = kwargs['protocol'] - spin_type = kwargs['spin_type'] - relax_type = kwargs['relax_type'] - magnetization_per_site = kwargs.get('magnetization_per_site', None) - threshold_forces = kwargs.get('threshold_forces', None) - threshold_stress = kwargs.get('threshold_stress', None) reference_workchain = kwargs.get('reference_workchain', None) electronic_type = kwargs['electronic_type'] # Checks if protocol not in self.get_protocol_names(): import warnings - warnings.warn('no protocol implemented with name {}, using default moderate'.format(protocol)) + warnings.warn(f'no protocol implemented with name {protocol}, using default moderate') protocol = self.get_default_protocol_name() if not all(x in engines.keys() for x in ['relax']): raise ValueError('The `engines` dictionary must contain "relax" as outermost key') # construct input for run123_lapw - inpdict = orm.Dict(dict={ + inpdict = orm.Dict( + dict={ '-red': self._protocols[protocol]['parameters']['red'], '-i': self._protocols[protocol]['parameters']['max-scf-iterations'], '-ec': self._protocols[protocol]['parameters']['scf-ene-tol-Ry'], @@ -89,38 +78,33 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: '-numk': self._protocols[protocol]['parameters']['numk'], '-numk2': self._protocols[protocol]['parameters']['numk2'], '-p': self._protocols[protocol]['parameters']['parallel'], - }) # run123_lapw [param] + } + ) if electronic_type == ElectronicType.INSULATOR: inpdict['-nometal'] = True - if reference_workchain: # ref. workchain is passed as input + if reference_workchain: # ref. workchain is passed as input # derive Rmt's from the ref. workchain and pass as input w2k_wchain = reference_workchain.get_outgoing(node_class=orm.WorkChainNode).one().node ref_wrkchn_res_dict = w2k_wchain.outputs.workchain_result.get_dict() rmt = ref_wrkchn_res_dict['Rmt'] atm_lbl = ref_wrkchn_res_dict['atom_labels'] if len(rmt) != len(atm_lbl): - raise # the list with Rmt radii should match the list of elements - red_string = '' - for i in range(len(rmt)): - red_string += atm_lbl[i] + ':' + str(rmt[i]) - if i < len(rmt)-1: # for all, but the last element of the list - red_string += ',' # append comma - inpdict['-red'] = red_string # pass Rmt's as input to subsequent wrk. chains + raise ValueError(f'The list of rmt radii does not match the list of elements: {rmt} and {atm_lbl}') + inpdict['-red'] = ','.join([f'{a}:{r}' for a, r in zip(atm_lbl, rmt)]) # derive k mesh from the ref. workchain and pass as input - if 'kmesh3' in ref_wrkchn_res_dict: # check if kmesh3 is in results dict - if ref_wrkchn_res_dict['kmesh3']: # check if the k mesh is not empty - inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3'] - if 'kmesh3k' in ref_wrkchn_res_dict: # check if kmesh3k is in results dict - if ref_wrkchn_res_dict['kmesh3k']: # check if the k mesh is not empty - inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k'] - if 'fftmesh3k' in ref_wrkchn_res_dict: # check if fftmesh3k is in results dict - if ref_wrkchn_res_dict['fftmesh3k']: # check if the FFT mesh is not empty - inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k'] - - #res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1) # set resources + if 'kmesh3' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3']: # check if kmesh3 is in results dict + inpdict['-numk'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3'] + if 'kmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['kmesh3k' + ]: # check if kmesh3k is in results dict + inpdict['-numk2'] = '0' + ' ' + ref_wrkchn_res_dict['kmesh3k'] + if 'fftmesh3k' in ref_wrkchn_res_dict and ref_wrkchn_res_dict['fftmesh3k' + ]: # check if fftmesh3k is in results dict + inpdict['-fft'] = ref_wrkchn_res_dict['fftmesh3k'] + + # res = NodeNumberJobResource(num_machines=8, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=1) builder = self.process_class.get_builder() builder.aiida_structure = structure - builder.code = engines['relax']['code'] # load wien2k-run123_lapw code + builder.code = engines['relax']['code'] # load wien2k-run123_lapw code builder.options = orm.Dict(dict=engines['relax']['options']) builder.inpdict = inpdict diff --git a/aiida_common_workflows/workflows/relax/wien2k/workchain.py b/aiida_common_workflows/workflows/relax/wien2k/workchain.py index d1410a99..f430a6a0 100644 --- a/aiida_common_workflows/workflows/relax/wien2k/workchain.py +++ b/aiida_common_workflows/workflows/relax/wien2k/workchain.py @@ -13,7 +13,7 @@ @calcfunction def get_energy(pardict): """Extract the energy from the `workchain_result` dictionary (Ry -> eV)""" - return orm.Float(pardict['EtotRyd']*13.605693122994) + return orm.Float(pardict['EtotRyd'] * 13.605693122994) class Wien2kCommonRelaxWorkChain(CommonRelaxWorkChain):