Skip to content

Commit 62957f2

Browse files
committed
Adding support for common protocols in VASP
1 parent 0323343 commit 62957f2

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

src/aiida_common_workflows/workflows/relax/vasp/generator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for VASP."""
2+
import copy
23
import pathlib
34
import typing as t
45

@@ -54,7 +55,9 @@ def define(cls, spec):
5455
spec.inputs['relax_type'].valid_type = ChoiceType(tuple(RelaxType))
5556
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
5657
spec.inputs['engines']['relax']['code'].valid_type = CodeType('vasp.vasp')
57-
spec.inputs['protocol'].valid_type = ChoiceType(('fast', 'moderate', 'precise', 'verification-PBE-v1'))
58+
spec.inputs['protocol'].valid_type = ChoiceType(
59+
('fast', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
60+
)
5861

5962
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR0912,PLR0915
6063
"""Construct a process builder based on the provided keyword arguments.
@@ -65,6 +68,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
6568
structure = kwargs['structure']
6669
engines = kwargs['engines']
6770
protocol = kwargs['protocol']
71+
custom_protocol = kwargs.get('custom_protocol', None)
6872
spin_type = kwargs['spin_type']
6973
relax_type = kwargs['relax_type']
7074
magnetization_per_site = kwargs.get('magnetization_per_site', None)
@@ -75,7 +79,15 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
7579
# Get the protocol that we want to use
7680
if protocol is None:
7781
protocol = self._default_protocol
78-
protocol = self.get_protocol(protocol)
82+
83+
if protocol == 'custom':
84+
if custom_protocol is None:
85+
raise ValueError(
86+
'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
87+
)
88+
protocol = copy.deepcopy(custom_protocol)
89+
else:
90+
protocol = copy.deepcopy(self.get_protocol(protocol))
7991

8092
# Set the builder
8193
builder = self.process_class.get_builder()

0 commit comments

Comments
 (0)