|
| 1 | +"""Workflow to calculate the energy as a function of fixed total magnetization that can use any code |
| 2 | +plugin implementing the common relax workflow and supports fixed spin-moment calulations.""" |
| 3 | +import inspect |
| 4 | + |
| 5 | +from aiida import orm |
| 6 | +from aiida.common import exceptions |
| 7 | +from aiida.engine import WorkChain, append_ |
| 8 | +from aiida.plugins import WorkflowFactory |
| 9 | + |
| 10 | +from aiida_common_workflows.workflows.relax.generator import ElectronicType, OptionalRelaxFeatures, RelaxType, SpinType |
| 11 | +from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain |
| 12 | + |
| 13 | + |
| 14 | +def validate_inputs(value, _): |
| 15 | + """Validate the entire input namespace.""" |
| 16 | + |
| 17 | + # Validate that the provided ``generator_inputs`` are valid for the associated input generator. |
| 18 | + process_class = WorkflowFactory(value['sub_process_class']) |
| 19 | + generator = process_class.get_input_generator() |
| 20 | + |
| 21 | + if not generator.supports_feature(OptionalRelaxFeatures.FIXED_MAGNETIZATION): |
| 22 | + return ( |
| 23 | + f'The `{value["sub_process_class"]}` plugin does not support the ' |
| 24 | + f'`{OptionalRelaxFeatures.FIXED_MAGNETIZATION}` optional feature required for this workflow.' |
| 25 | + ) |
| 26 | + |
| 27 | + try: |
| 28 | + generator.get_builder(structure=value['structure'], **value['generator_inputs']) |
| 29 | + except Exception as exc: |
| 30 | + return f'`{generator.__class__.__name__}.get_builder()` fails for the provided `generator_inputs`: {exc}' |
| 31 | + |
| 32 | + |
| 33 | +def validate_sub_process_class(value, _): |
| 34 | + """Validate the sub process class.""" |
| 35 | + try: |
| 36 | + process_class = WorkflowFactory(value) |
| 37 | + except exceptions.EntryPointError: |
| 38 | + return f'`{value}` is not a valid or registered workflow entry point.' |
| 39 | + |
| 40 | + if not inspect.isclass(process_class) or not issubclass(process_class, CommonRelaxWorkChain): |
| 41 | + return f'`{value}` is not a subclass of the `CommonRelaxWorkChain` common workflow.' |
| 42 | + |
| 43 | + |
| 44 | +def validate_total_magnetizations(value, _): |
| 45 | + """Validate the `fixed_total_magnetizations` input.""" |
| 46 | + if value and len(value) < 3: |
| 47 | + return 'need at least 3 total magnetizations.' |
| 48 | + if not all(isinstance(m, (int, float)) for m in value): |
| 49 | + return 'all total magnetizations must be numbers (int or float).' |
| 50 | + |
| 51 | + |
| 52 | +def validate_relax_type(value, _): |
| 53 | + """Validate the `generator_inputs.relax_type` input.""" |
| 54 | + if value is not None and isinstance(value, str): |
| 55 | + value = RelaxType(value) |
| 56 | + |
| 57 | + if value not in [RelaxType.NONE, RelaxType.POSITIONS, RelaxType.SHAPE, RelaxType.POSITIONS_SHAPE]: |
| 58 | + return '`generator_inputs.relax_type`. Equation of state and relaxation with variable volume not compatible.' |
| 59 | + |
| 60 | + |
| 61 | +class EnergyMagnetizationWorkChain(WorkChain): |
| 62 | + """Workflow to compute the energy vs magnetization curve for a given crystal structure.""" |
| 63 | + |
| 64 | + @classmethod |
| 65 | + def define(cls, spec): |
| 66 | + # yapf: disable |
| 67 | + super().define(spec) |
| 68 | + spec.input('structure', valid_type=orm.StructureData, help='The structure at equilibrium volume.') |
| 69 | + spec.input('fixed_total_magnetizations', valid_type=orm.List, required=True, |
| 70 | + validator=validate_total_magnetizations, serializer=orm.to_aiida_type, |
| 71 | + help='The list of fixed total magnetizations to be calculated for the structure.') |
| 72 | + spec.input_namespace('generator_inputs', |
| 73 | + help='The inputs that will be passed to the input generator of the specified `sub_process`.') |
| 74 | + spec.input('generator_inputs.engines', valid_type=dict, non_db=True) |
| 75 | + spec.input('generator_inputs.protocol', valid_type=str, non_db=True, |
| 76 | + help='The protocol to use when determining the workchain inputs.') |
| 77 | + spec.input('generator_inputs.relax_type', |
| 78 | + valid_type=(RelaxType, str), non_db=True, validator=validate_relax_type, |
| 79 | + help='The type of relaxation to perform.') |
| 80 | + spec.input('generator_inputs.spin_type', valid_type=(SpinType, str), required=False, non_db=True, |
| 81 | + help='The type of spin for the calculation.') |
| 82 | + spec.input('generator_inputs.electronic_type', valid_type=(ElectronicType, str), required=False, non_db=True, |
| 83 | + help='The type of electronics (insulator/metal) for the calculation.') |
| 84 | + spec.input('generator_inputs.threshold_forces', valid_type=float, required=False, non_db=True, |
| 85 | + help='Target threshold for the forces in eV/Å.') |
| 86 | + spec.input('generator_inputs.threshold_stress', valid_type=float, required=False, non_db=True, |
| 87 | + help='Target threshold for the stress in eV/Å^3.') |
| 88 | + spec.input_namespace('sub_process', dynamic=True, populate_defaults=False) |
| 89 | + spec.input('sub_process_class', non_db=True, validator=validate_sub_process_class) |
| 90 | + spec.inputs.validator = validate_inputs |
| 91 | + |
| 92 | + spec.outline( |
| 93 | + cls.run_em, |
| 94 | + cls.inspect_em, |
| 95 | + ) |
| 96 | + |
| 97 | + spec.output_namespace('total_energies', valid_type=orm.Float, |
| 98 | + help='The computed total energy of the relaxed structures at each scaling factor.') |
| 99 | + spec.output_namespace('total_magnetizations', valid_type=orm.Float, |
| 100 | + help='The fixed total magnetizations that were evaluated in mu_B.') |
| 101 | + spec.output_namespace('fermi_energies_up', valid_type=orm.Float, |
| 102 | + help=( |
| 103 | + 'The fermi energies of the spin-up channel (at each fixed total magnetization). ' |
| 104 | + 'Can be used to compute an effective magnetic field. Otherwise, only meaningful ' |
| 105 | + 'in combination with the BandsData that will be added in the future.' |
| 106 | + ) |
| 107 | + ) |
| 108 | + spec.output_namespace('fermi_energies_down', valid_type=orm.Float, |
| 109 | + help=( |
| 110 | + 'The fermi energies of the spin-down channel (at each fixed total magnetization). ' |
| 111 | + 'Can be used to compute an effective magnetic field. Otherwise, only meaningful ' |
| 112 | + 'in combination with the BandsData that will be added in the future.' |
| 113 | + ) |
| 114 | + ) |
| 115 | + spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED', |
| 116 | + message='At least one of the `{cls}` sub processes did not finish successfully.') |
| 117 | + |
| 118 | + def get_sub_workchain_builder(self, total_magnetization): |
| 119 | + """Return the builder for the relax workchain.""" |
| 120 | + structure = self.inputs.structure |
| 121 | + process_class = WorkflowFactory(self.inputs.sub_process_class) |
| 122 | + |
| 123 | + base_inputs = {'structure': structure, 'fixed_total_cell_magnetization': total_magnetization} |
| 124 | + |
| 125 | + builder = process_class.get_input_generator().get_builder(**base_inputs, **self.inputs.generator_inputs) |
| 126 | + builder._merge(**self.inputs.get('sub_process', {})) |
| 127 | + |
| 128 | + return builder |
| 129 | + |
| 130 | + def run_em(self): |
| 131 | + """Run the sub process at each scale factor to compute the structure volume and total energy.""" |
| 132 | + for total_magnetization in self.inputs.fixed_total_magnetizations: |
| 133 | + builder = self.get_sub_workchain_builder(total_magnetization) |
| 134 | + self.report( |
| 135 | + f'submitting `{builder.process_class.__name__}` for total_magnetization `{total_magnetization}`' |
| 136 | + ) |
| 137 | + self.to_context(children=append_(self.submit(builder))) |
| 138 | + |
| 139 | + def inspect_em(self): |
| 140 | + """Inspect all children workflows to make sure they finished successfully.""" |
| 141 | + if any(not child.is_finished_ok for child in self.ctx.children): |
| 142 | + return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.sub_process_class) |
| 143 | + |
| 144 | + for index, child in enumerate(self.ctx.children): |
| 145 | + energy = child.outputs.total_energy |
| 146 | + total_magnetization = child.outputs.total_magnetization |
| 147 | + |
| 148 | + fermi_energy_up = child.outputs.fermi_energy_up |
| 149 | + fermi_energy_down = child.outputs.fermi_energy_down |
| 150 | + |
| 151 | + self.report(f'Image {index}: total_magnetization={total_magnetization}, total energy={energy.value}') |
| 152 | + |
| 153 | + self.out(f'total_energies.{index}', energy) |
| 154 | + self.out(f'total_magnetizations.{index}', total_magnetization) |
| 155 | + self.out(f'fermi_energies_up.{index}', fermi_energy_up) |
| 156 | + self.out(f'fermi_energies_down.{index}', fermi_energy_down) |
0 commit comments