Skip to content

Commit 445ff0d

Browse files
committed
Adding support for custom protocols
For now, only implemented for QE and abinit, needs to be implemented for all Also, fixing a couple of bugs in the ACWF relax for ABINIT
1 parent 8a70baf commit 445ff0d

File tree

4 files changed

+39
-31
lines changed

4 files changed

+39
-31
lines changed

src/aiida_common_workflows/workflows/relax/abinit/generator.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define(cls, spec):
5151
(ElectronicType.METAL, ElectronicType.INSULATOR, ElectronicType.UNKNOWN)
5252
)
5353
spec.inputs['engines']['relax']['code'].valid_type = CodeType('abinit')
54-
spec.inputs['protocol'].valid_type = ChoiceType(('fast', 'moderate', 'precise', 'verification-PBE-v1'))
54+
spec.inputs['protocol'].valid_type = ChoiceType(('fast', 'moderate', 'precise', 'verification-PBE-v1', 'custom'))
5555

5656
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR0912,PLR0915
5757
"""Construct a process builder based on the provided keyword arguments.
@@ -62,6 +62,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
6262
structure = kwargs['structure']
6363
engines = kwargs['engines']
6464
protocol = kwargs['protocol']
65+
custom_protocol = kwargs.get('custom_protocol', None)
6566
spin_type = kwargs['spin_type']
6667
relax_type = kwargs['relax_type']
6768
electronic_type = kwargs['electronic_type']
@@ -70,7 +71,12 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
7071
threshold_stress = kwargs.get('threshold_stress', None)
7172
reference_workchain = kwargs.get('reference_workchain', None)
7273

73-
protocol = copy.deepcopy(self.get_protocol(protocol))
74+
if protocol == 'custom':
75+
if custom_protocol is None:
76+
raise ValueError('the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.')
77+
protocol = copy.deepcopy(custom_protocol)
78+
else:
79+
protocol = copy.deepcopy(self.get_protocol(protocol))
7480
code = engines['relax']['code']
7581

7682
pseudo_family_label = protocol.pop('pseudo_family')
@@ -187,31 +193,9 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
187193
warnings.warn(f'input magnetization per site was None, setting it to {magnetization_per_site}')
188194
magnetization_per_site = np.array(magnetization_per_site)
189195

190-
sum_is_zero = np.isclose(sum(magnetization_per_site), 0.0)
191-
all_are_zero = np.all(np.isclose(magnetization_per_site, 0.0))
192-
non_zero_mags = magnetization_per_site[~np.isclose(magnetization_per_site, 0.0)]
193-
all_non_zero_pos = np.all(non_zero_mags > 0.0)
194-
all_non_zero_neg = np.all(non_zero_mags < 0.0)
195-
196-
if all_are_zero: # non-magnetic
197-
warnings.warn(
198-
'all of the initial magnetizations per site are close to zero; doing a non-spin-polarized '
199-
'calculation'
200-
)
201-
elif (sum_is_zero and not all_are_zero) or (
202-
not all_non_zero_pos and not all_non_zero_neg
203-
): # antiferromagnetic
204-
print('Detected antiferromagnetic!')
205-
builder.abinit['parameters']['nsppol'] = 1 # antiferromagnetic system
206-
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
207-
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
208-
elif not all_are_zero and (all_non_zero_pos or all_non_zero_neg): # ferromagnetic
209-
print('Detected ferromagnetic!')
210-
builder.abinit['parameters']['nsppol'] = 2 # collinear spin-polarization
211-
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
212-
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
213-
else:
214-
raise ValueError(f'Initial magnetization {magnetization_per_site} is ambiguous')
196+
builder.abinit['parameters']['nsppol'] = 2 # collinear spin-polarization
197+
builder.abinit['parameters']['nspden'] = 2 # scalar spin-magnetization in the z-axis
198+
builder.abinit['parameters']['spinat'] = [[0.0, 0.0, mag] for mag in magnetization_per_site]
215199
elif spin_type == SpinType.NON_COLLINEAR:
216200
if magnetization_per_site is None:
217201
magnetization_per_site = get_initial_magnetization(structure)

src/aiida_common_workflows/workflows/relax/abinit/workchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_stress(parameters):
2525
def get_forces(parameters):
2626
"""Return the forces array from the given parameters node."""
2727
forces = orm.ArrayData()
28-
forces.set_array(name='forces', array=np.array(parameters.base.attributes.get('forces')))
28+
forces.set_array(name='forces', array=np.array(parameters.base.attributes.get('cart_forces')))
2929
return forces
3030

3131

src/aiida_common_workflows/workflows/relax/generator.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ def validate_inputs(value, _):
1717
if value.get('magnetization_per_site') is not None and value.get('fixed_total_cell_magnetization') is not None:
1818
return 'the inputs `magnetization_per_site` and ' '`fixed_total_cell_magnetization` are mutually exclusive.'
1919

20+
if value.get('protocol') == 'custom' and value.get('custom_protocol') is None:
21+
return 'the `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
22+
23+
if value.get('protocol') != 'custom' and value.get('custom_protocol') is not None:
24+
return 'the `custom_protocol` input can only be provided when the `protocol` input is set to `custom`.'
25+
26+
# TODO: ensure all plugins actually honor this new custom_protocol input! (only QE implemented for now)
2027

2128
class OptionalRelaxFeatures(OptionalFeature):
2229
FIXED_MAGNETIZATION = 'fixed_total_cell_magnetization'
@@ -45,7 +52,7 @@ def define(cls, spec):
4552
)
4653
spec.input(
4754
'protocol',
48-
valid_type=ChoiceType(('fast', 'moderate', 'precise')),
55+
valid_type=ChoiceType(('fast', 'moderate', 'precise', 'custom')),
4956
default='moderate',
5057
non_db=True,
5158
help='The protocol to use for the automated input generation. This value indicates the level of precision '
@@ -82,6 +89,15 @@ def define(cls, spec):
8289
'electrons, for the site. This also corresponds to the magnetization of the site in Bohr magnetons '
8390
'(μB).',
8491
)
92+
spec.input(
93+
'custom_protocol',
94+
valid_type=dict,
95+
non_db=True,
96+
required=False,
97+
default=None,
98+
help='A custom protocol dictionary that can be provided when the `protocol` input is set to `custom`. '
99+
'In that case, this dictionary will be used to override the default protocol settings.',
100+
)
85101
spec.input(
86102
'fixed_total_cell_magnetization',
87103
valid_type=OptionalFeatureType(float),

src/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define(cls, spec):
9797
"""
9898
super().define(spec)
9999
spec.inputs['protocol'].valid_type = ChoiceType(
100-
('fast', 'balanced', 'stringent', 'moderate', 'precise', 'verification-PBE-v1')
100+
('fast', 'balanced', 'stringent', 'moderate', 'precise', 'verification-PBE-v1', 'custom')
101101
)
102102
spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE, SpinType.COLLINEAR))
103103
spec.inputs['relax_type'].valid_type = ChoiceType(
@@ -155,7 +155,15 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
155155
# Currently, the `aiida-quantumespresso` workflows will expect one of the basic protocols to be passed to the
156156
# `get_builder_from_protocol()` method. Here, we switch to using the default protocol for the
157157
# `aiida-quantumespresso` plugin and pass the local protocols as `overrides`.
158-
if (
158+
if protocol == 'custom':
159+
custom_protocol = kwargs.get('custom_protocol', None)
160+
if custom_protocol is None:
161+
raise ValueError(
162+
'The `custom_protocol` input must be provided when the `protocol` input is set to `custom`.'
163+
)
164+
overrides = custom_protocol
165+
protocol = self._default_protocol
166+
elif (
159167
protocol not in self.process_class._process_class.get_available_protocols()
160168
and self.process_class._process_class._check_if_alias(protocol)
161169
not in self.process_class._process_class.get_available_protocols()

0 commit comments

Comments
 (0)