Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions src/aiida_common_workflows/workflows/relax/abinit/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def define(cls, spec):
(ElectronicType.METAL, ElectronicType.INSULATOR, ElectronicType.UNKNOWN)
)
spec.inputs['engines']['relax']['code'].valid_type = CodeType('abinit')
spec.inputs['protocol'].valid_type = ChoiceType(('fast', 'moderate', 'precise', 'verification-PBE-v1'))
spec.inputs['protocol'].valid_type = ChoiceType(
('fast', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
)

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR0912,PLR0915
"""Construct a process builder based on the provided keyword arguments.
Expand All @@ -62,6 +64,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
structure = kwargs['structure']
engines = kwargs['engines']
protocol = kwargs['protocol']
custom_protocol = kwargs.get('custom_protocol', None)
spin_type = kwargs['spin_type']
relax_type = kwargs['relax_type']
electronic_type = kwargs['electronic_type']
Expand All @@ -70,7 +73,14 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
threshold_stress = kwargs.get('threshold_stress', None)
reference_workchain = kwargs.get('reference_workchain', None)

protocol = copy.deepcopy(self.get_protocol(protocol))
if protocol == 'custom':
if custom_protocol is None:
raise ValueError(
'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
)
protocol = copy.deepcopy(custom_protocol)
else:
protocol = copy.deepcopy(self.get_protocol(protocol))
code = engines['relax']['code']

pseudo_family_label = protocol.pop('pseudo_family')
Expand All @@ -87,15 +97,31 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
recommended_ecut_wfc, recommended_ecut_rho = pseudo_family.get_recommended_cutoffs(
structure=structure, stringency=cutoff_stringency, unit='Eh'
)

# In both cases, if the protocol "hardcodes" the cutoff(s),
# I use that instead of the one from the pseudopotential family
# since it probably means the user really wanted that cutoff.
# I use try/except since I need to go deep into a dictionary and
# it is easier than using dict.get() a lot of times.
try:
protocol_ecut = protocol['base']['abinit']['parameters']['ecut']
except KeyError:
protocol_ecut = None

try:
protocol_pawecutdg = protocol['base']['abinit']['parameters']['pawecutdg']
except KeyError:
protocol_pawecutdg = None

if pseudo_type == 'pseudo.jthxml':
# JTH XML are PAW; we need `pawecutdg`
cutoff_parameters = {
'ecut': np.ceil(recommended_ecut_wfc),
'pawecutdg': np.ceil(recommended_ecut_rho),
'ecut': protocol_ecut if protocol_ecut is not None else np.ceil(recommended_ecut_wfc),
'pawecutdg': protocol_pawecutdg if protocol_pawecutdg is not None else np.ceil(recommended_ecut_rho),
}
else:
# All others are NC; no need for `pawecutdg`
cutoff_parameters = {'ecut': recommended_ecut_wfc}
cutoff_parameters = {'ecut': protocol_ecut if protocol_ecut is not None else np.ceil(recommended_ecut_wfc)}

override = {
'abinit': {
Expand Down Expand Up @@ -187,31 +213,9 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
warnings.warn(f'input magnetization per site was None, setting it to {magnetization_per_site}')
magnetization_per_site = np.array(magnetization_per_site)

sum_is_zero = np.isclose(sum(magnetization_per_site), 0.0)
all_are_zero = np.all(np.isclose(magnetization_per_site, 0.0))
non_zero_mags = magnetization_per_site[~np.isclose(magnetization_per_site, 0.0)]
all_non_zero_pos = np.all(non_zero_mags > 0.0)
all_non_zero_neg = np.all(non_zero_mags < 0.0)

if all_are_zero: # non-magnetic
warnings.warn(
'all of the initial magnetizations per site are close to zero; doing a non-spin-polarized '
'calculation'
)
elif (sum_is_zero and not all_are_zero) or (
not all_non_zero_pos and not all_non_zero_neg
): # antiferromagnetic
print('Detected antiferromagnetic!')
builder.abinit['parameters']['nsppol'] = 1 # antiferromagnetic system
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
elif not all_are_zero and (all_non_zero_pos or all_non_zero_neg): # ferromagnetic
print('Detected ferromagnetic!')
builder.abinit['parameters']['nsppol'] = 2 # collinear spin-polarization
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
else:
raise ValueError(f'Initial magnetization {magnetization_per_site} is ambiguous')
builder.abinit['parameters']['nsppol'] = 2 # collinear spin-polarization
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
elif spin_type == SpinType.NON_COLLINEAR:
if magnetization_per_site is None:
magnetization_per_site = get_initial_magnetization(structure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_stress(parameters):
def get_forces(parameters):
"""Return the forces array from the given parameters node."""
forces = orm.ArrayData()
forces.set_array(name='forces', array=np.array(parameters.base.attributes.get('forces')))
forces.set_array(name='forces', array=np.array(parameters.base.attributes.get('cart_forces')))
return forces


Expand Down
19 changes: 18 additions & 1 deletion src/aiida_common_workflows/workflows/relax/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ def validate_inputs(value, _):
if value.get('magnetization_per_site') is not None and value.get('fixed_total_cell_magnetization') is not None:
return 'the inputs `magnetization_per_site` and ' '`fixed_total_cell_magnetization` are mutually exclusive.'

if value.get('protocol') == 'custom' and value.get('custom_protocol') is None:
return 'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'

if value.get('protocol') != 'custom' and value.get('custom_protocol') is not None:
return 'the `custom_protocol` input can only be provided when the `protocol` input is set to `custom`.'

# TODO: ensure all plugins actually honor this new custom_protocol input! (only QE implemented for now)


class OptionalRelaxFeatures(OptionalFeature):
FIXED_MAGNETIZATION = 'fixed_total_cell_magnetization'
Expand Down Expand Up @@ -45,7 +53,7 @@ def define(cls, spec):
)
spec.input(
'protocol',
valid_type=ChoiceType(('fast', 'moderate', 'precise')),
valid_type=ChoiceType(('fast', 'moderate', 'precise', 'custom')),
default='moderate',
non_db=True,
help='The protocol to use for the automated input generation. This value indicates the level of precision '
Expand Down Expand Up @@ -82,6 +90,15 @@ def define(cls, spec):
'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
'(μB).',
)
spec.input(
'custom_protocol',
valid_type=(dict, type(None)),
non_db=True,
required=False,
default=None,
help='A custom protocol dictionary that can be provided when the `protocol` input is set to `custom`. '
'In that case, this dictionary will be used to override the default protocol settings.',
)
spec.input(
'fixed_total_cell_magnetization',
valid_type=OptionalFeatureType(float),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def define(cls, spec):
"""
super().define(spec)
spec.inputs['protocol'].valid_type = ChoiceType(
('fast', 'balanced', 'stringent', 'moderate', 'precise', 'verification-PBE-v1')
('fast', 'balanced', 'stringent', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
)
spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE, SpinType.COLLINEAR))
spec.inputs['relax_type'].valid_type = ChoiceType(
Expand Down Expand Up @@ -155,7 +155,15 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
# Currently, the `aiida-quantumespresso` workflows will expect one of the basic protocols to be passed to the
# `get_builder_from_protocol()` method. Here, we switch to using the default protocol for the
# `aiida-quantumespresso` plugin and pass the local protocols as `overrides`.
if (
if protocol == 'custom':
custom_protocol = kwargs.get('custom_protocol', None)
if custom_protocol is None:
raise ValueError(
'The `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
)
overrides = custom_protocol
protocol = self._default_protocol
elif (
protocol not in self.process_class._process_class.get_available_protocols()
and self.process_class._process_class._check_if_alias(protocol)
not in self.process_class._process_class.get_available_protocols()
Expand Down
26 changes: 22 additions & 4 deletions src/aiida_common_workflows/workflows/relax/vasp/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for VASP."""

import copy
import os
import pathlib
import typing as t
Expand Down Expand Up @@ -56,7 +56,9 @@ def define(cls, spec):
spec.inputs['relax_type'].valid_type = ChoiceType(tuple(RelaxType))
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
spec.inputs['engines']['relax']['code'].valid_type = CodeType('vasp.vasp')
spec.inputs['protocol'].valid_type = ChoiceType(('fast', 'moderate', 'precise', 'verification-PBE-v1'))
spec.inputs['protocol'].valid_type = ChoiceType(
('fast', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
)

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR0912,PLR0915
"""Construct a process builder based on the provided keyword arguments.
Expand All @@ -67,6 +69,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
structure = kwargs['structure']
engines = kwargs['engines']
protocol = kwargs['protocol']
custom_protocol = kwargs.get('custom_protocol', None)
spin_type = kwargs['spin_type']
relax_type = kwargs['relax_type']
magnetization_per_site = kwargs.get('magnetization_per_site', None)
Expand All @@ -77,7 +80,15 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
# Get the protocol that we want to use
if protocol is None:
protocol = self._default_protocol
protocol = self.get_protocol(protocol)

if protocol == 'custom':
if custom_protocol is None:
raise ValueError(
'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
)
protocol = copy.deepcopy(custom_protocol)
else:
protocol = copy.deepcopy(self.get_protocol(protocol))

# Set the builder
builder = self.process_class.get_builder()
Expand Down Expand Up @@ -172,7 +183,14 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
previous_kpoints.base.attributes.get('mesh'), previous_kpoints.base.attributes.get('offset')
)
else:
kpoints.set_kpoints_mesh_from_density(protocol['kpoint_distance'])
if 'kpoints' in protocol and 'kpoint_distance' in protocol:
raise ValueError('Protocol cannot define both `kpoints` and `kpoint_distance` in protocol.')
if 'kpoints' not in protocol and 'kpoint_distance' not in protocol:
raise ValueError('Protocol must define either `kpoints` or `kpoint_distance` in protocol.')
if 'kpoints' in protocol:
kpoints.set_kpoints_mesh(protocol['kpoints'])
else:
kpoints.set_kpoints_mesh_from_density(protocol['kpoint_distance'])
builder.vasp.kpoints = kpoints

# Set the relax parameters
Expand Down