Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement output trajectory merge in Cp2kBaseWorkChain #209

Merged
merged 16 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# For further information on the license, see the LICENSE.txt file. #
###############################################################################

FROM aiidateam/aiida-core:2.1.2
FROM aiidateam/aiida-core:2.3.1

# To prevent the container to exit prematurely.
ENV KILL_ALL_RPOCESSES_TIMEOUT=50
Expand Down
6 changes: 6 additions & 0 deletions aiida_cp2k/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
###############################################################################
"""AiiDA-CP2K utils"""

from .datatype_helpers import (
merge_trajectory_data_non_unique,
merge_trajectory_data_unique,
)
from .input_generator import (
Cp2kInput,
add_ext_restart_section,
Expand Down Expand Up @@ -42,4 +46,6 @@
"merge_Dict",
"ot_has_small_bandgap",
"resize_unit_cell",
"merge_trajectory_data_unique",
"merge_trajectory_data_non_unique",
]
110 changes: 87 additions & 23 deletions aiida_cp2k/utils/datatype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import re
from collections.abc import Sequence

from aiida.common import InputValidationError
from aiida.plugins import DataFactory
import numpy as np
from aiida import common, engine, orm, plugins


def _unpack(adict):
Expand Down Expand Up @@ -50,7 +50,9 @@ def _kind_element_from_kind_section(section):
try:
kind = section["_"]
except KeyError:
raise InputValidationError("No default parameter '_' found in KIND section.")
raise common.InputValidationError(
"No default parameter '_' found in KIND section."
)

try:
element = section["ELEMENT"]
Expand All @@ -60,7 +62,7 @@ def _kind_element_from_kind_section(section):
try:
element = match["sym"]
except TypeError:
raise InputValidationError(
raise common.InputValidationError(
f"Unable to figure out atomic symbol from KIND '{kind}'."
)

Expand Down Expand Up @@ -125,7 +127,7 @@ def _write_gdt(inp, entries, folder, key, fname):
def validate_basissets_namespace(basissets, _):
    """An input_namespace validator to ensure passed down basis sets have the correct type.

    Delegates to the generic ``_validate_gdt_namespace`` check, requiring every
    entry to be of the ``gaussian.basisset`` data type.
    """
    return _validate_gdt_namespace(
        basissets, plugins.DataFactory("gaussian.basisset"), "basis set"
    )


Expand Down Expand Up @@ -176,7 +178,7 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == element]

if not bsets:
raise InputValidationError(
raise common.InputValidationError(
f"No basis set found for kind {kind} or element {element}"
f" in basissets input namespace and not explicitly set."
)
Expand All @@ -203,7 +205,7 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == element]

if not bsets:
raise InputValidationError(
raise common.InputValidationError(
f"'BASIS_SET {bstype} {bsname}' for element {element} (from kind {kind})"
" not found in basissets input namespace"
)
Expand All @@ -213,7 +215,7 @@ def validate_basissets(inp, basissets, structure):
basissets_used.add(bset)
break
else:
raise InputValidationError(
raise common.InputValidationError(
f"'BASIS_SET {bstype} {bsname}' for element {element} (from kind {kind})"
" not found in basissets input namespace"
)
Expand All @@ -222,14 +224,14 @@ def validate_basissets(inp, basissets, structure):
if not structure and any(
bset not in basissets_used for bset in basissets_specified
):
raise InputValidationError(
raise common.InputValidationError(
"No explicit structure given and basis sets not referenced in input"
)

if isinstance(inp["FORCE_EVAL"], Sequence) and any(
kind.name not in explicit_kinds for kind in structure.kinds
):
raise InputValidationError(
raise common.InputValidationError(
"Automated BASIS_SET keyword creation is not yet supported with multiple FORCE_EVALs."
" Please explicitly reference a BASIS_SET for each KIND."
)
Expand All @@ -250,13 +252,13 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == kind.symbol]

if not bsets:
raise InputValidationError(
raise common.InputValidationError(
f"No basis set found in the given basissets for kind '{kind.name}' of your structure."
)

for _, bset in bsets:
if bset.element != kind.symbol:
raise InputValidationError(
raise common.InputValidationError(
f"Basis set '{bset.name}' for '{bset.element}' specified"
f" for kind '{kind.name}' (of '{kind.symbol}')."
)
Expand All @@ -274,7 +276,7 @@ def validate_basissets(inp, basissets, structure):

for bset in basissets_specified:
if bset not in basissets_used:
raise InputValidationError(
raise common.InputValidationError(
f"Basis set '{bset.name}' ('{bset.element}') specified in the basissets"
f" input namespace but not referenced by either input or structure."
)
Expand All @@ -287,7 +289,9 @@ def write_basissets(inp, basissets, folder):

def validate_pseudos_namespace(pseudos, _):
    """An input_namespace validator to ensure passed down pseudopotentials have the correct type.

    Delegates to the generic ``_validate_gdt_namespace`` check, requiring every
    entry to be of the ``gaussian.pseudo`` data type.
    """
    return _validate_gdt_namespace(
        pseudos, plugins.DataFactory("gaussian.pseudo"), "pseudo"
    )


def validate_pseudos(inp, pseudos, structure):
Expand Down Expand Up @@ -318,7 +322,7 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[element]
except KeyError:
raise InputValidationError(
raise common.InputValidationError(
f"No pseudopotential found for kind {kind} or element {element}"
f" in pseudos input namespace and not explicitly set."
)
Expand All @@ -335,19 +339,19 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[element]
except KeyError:
raise InputValidationError(
raise common.InputValidationError(
f"'POTENTIAL {ptype} {pname}' for element {element} (from kind {kind})"
" not found in pseudos input namespace"
)

if pname not in pseudo.aliases:
raise InputValidationError(
raise common.InputValidationError(
f"'POTENTIAL {ptype} {pname}' for element {element} (from kind {kind})"
" not found in pseudos input namespace"
)

if pseudo.element != element:
raise InputValidationError(
raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' for '{pseudo.element}' specified"
f" for element '{element}'."
)
Expand All @@ -358,14 +362,14 @@ def validate_pseudos(inp, pseudos, structure):
if not structure and any(
pseudo not in pseudos_used for pseudo in pseudos_specified
):
raise InputValidationError(
raise common.InputValidationError(
"No explicit structure given and pseudo not referenced in input"
)

if isinstance(inp["FORCE_EVAL"], Sequence) and any(
kind.name not in explicit_kinds for kind in structure.kinds
):
raise InputValidationError(
raise common.InputValidationError(
"Automated POTENTIAL keyword creation is not yet supported with multiple FORCE_EVALs."
" Please explicitly reference a POTENTIAL for each KIND."
)
Expand All @@ -383,13 +387,13 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[kind.symbol]
except KeyError:
raise InputValidationError(
raise common.InputValidationError(
f"No basis set found in the given basissets"
f" for kind '{kind.name}' (or '{kind.symbol}') of your structure."
)

if pseudo.element != kind.symbol:
raise InputValidationError(
raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' for '{pseudo.element}' specified"
f" for kind '{kind.name}' (of '{kind.symbol}')."
)
Expand All @@ -402,7 +406,7 @@ def validate_pseudos(inp, pseudos, structure):

for pseudo in pseudos_specified:
if pseudo not in pseudos_used:
raise InputValidationError(
raise common.InputValidationError(
f"Pseudopodential '{pseudo.name}' specified in the pseudos input namespace"
f" but not referenced by either input or structure."
)
Expand All @@ -411,3 +415,63 @@ def validate_pseudos(inp, pseudos, structure):
def write_pseudos(inp, pseudos, folder):
    """Writes the unified POTENTIAL file with the used pseudos.

    Delegates to ``_write_gdt`` with the input keyword "POTENTIAL_FILE_NAME"
    and the output file name "POTENTIAL".
    """
    _write_gdt(inp, pseudos, folder, "POTENTIAL_FILE_NAME", "POTENTIAL")


def _merge_trajectories_into_dictionary(*trajectories, unique_stepids=False):
    """Concatenate the arrays of several trajectories into one dictionary.

    :param trajectories: TrajectoryData-like nodes to merge; all of them must
        provide the same set of array names.
    :param unique_stepids: if True, de-duplicate by step id. ``np.unique``
        returns the sorted unique step ids together with the indices of their
        first occurrence, so for duplicated steps the data of the first
        trajectory containing them is kept.
    :returns: a dictionary ``{array_name: merged array}`` or ``None`` when no
        trajectories are given.
    :raises ValueError: if an array name is missing from any trajectory.
    """
    # BUG FIX: the original guard was `if len(trajectories) < 0`, which can
    # never be true (lengths are non-negative), so an empty call crashed on
    # `trajectories[0]` below instead of returning None.
    if not trajectories:
        return None

    final_trajectory_dict = {}
    array_names = trajectories[0].get_arraynames()

    for array_name in array_names:
        if any(array_name not in traj.get_arraynames() for traj in trajectories):
            raise ValueError(
                f"Array name '{array_name}' not found in all trajectories."
            )
        merged_array = np.concatenate(
            [traj.get_array(array_name) for traj in trajectories], axis=0
        )
        final_trajectory_dict[array_name] = merged_array

    # If unique_stepids is True, we only keep the unique stepids.
    # The other arrays are then also reduced to the unique stepids.
    if unique_stepids:
        stepids = np.concatenate([traj.get_stepids() for traj in trajectories], axis=0)
        final_trajectory_dict["stepids"], unique_indices = np.unique(
            stepids, return_index=True
        )

        for array_name in array_names:
            final_trajectory_dict[array_name] = final_trajectory_dict[array_name][
                unique_indices
            ]

    return final_trajectory_dict


def _dictionary_to_trajectory(trajectory_dict, symbols):
    """Build a TrajectoryData node from a dictionary of arrays.

    :param trajectory_dict: mapping ``{array_name: array}``; must contain a
        "positions" entry. The caller's dictionary is left untouched.
    :param symbols: chemical symbols passed to ``set_trajectory``.
    :returns: an (unstored) ``orm.TrajectoryData`` node.
    """
    # Work on a shallow copy: the original implementation popped "positions"
    # from the caller's dictionary in place, mutating the argument.
    remaining_arrays = dict(trajectory_dict)
    final_trajectory = orm.TrajectoryData()
    final_trajectory.set_trajectory(
        symbols=symbols, positions=remaining_arrays.pop("positions")
    )
    for array_name, array in remaining_arrays.items():
        final_trajectory.set_array(array_name, array)

    return final_trajectory


@engine.calcfunction
def merge_trajectory_data_unique(*trajectories):
    """Merge trajectories into one TrajectoryData, keeping each step id once."""
    merged_arrays = _merge_trajectories_into_dictionary(
        *trajectories, unique_stepids=True
    )
    return _dictionary_to_trajectory(merged_arrays, trajectories[0].symbols)


@engine.calcfunction
def merge_trajectory_data_non_unique(*trajectories):
    """Merge trajectories into one TrajectoryData, keeping duplicate step ids."""
    merged_arrays = _merge_trajectories_into_dictionary(
        *trajectories, unique_stepids=False
    )
    return _dictionary_to_trajectory(merged_arrays, trajectories[0].symbols)
16 changes: 16 additions & 0 deletions aiida_cp2k/workchains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,27 @@ def setup(self):
super().setup()
self.ctx.inputs = common.AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k'))

def _collect_all_trajetories(self):
    """Collect all trajectories from the children calculations."""
    trajectories = []
    for child in self.ctx.children:
        if not isinstance(child, orm.CalcJobNode):
            continue
        try:
            trajectories.append(child.outputs.output_trajectory)
        except AttributeError:
            # A calculation without an output trajectory is simply skipped.
            pass
    return trajectories

def results(self):
    """Attach the work chain outputs, merging children trajectories when present."""
    super().results()

    # Expose the final parameters only when they were modified with respect
    # to the ones originally supplied by the user.
    if self.inputs.cp2k.parameters != self.ctx.inputs.parameters:
        self.out('final_input_parameters', self.ctx.inputs.parameters)

    collected_trajectories = self._collect_all_trajetories()
    if collected_trajectories:
        self.report("Work chain completed successfully, collecting all trajectories")
        self.out(
            "output_trajectory",
            utils.merge_trajectory_data_unique(*collected_trajectories),
        )

def overwrite_input_structure(self):
if "output_structure" in self.ctx.children[self.ctx.iteration-1].outputs:
self.ctx.inputs.structure = self.ctx.children[self.ctx.iteration-1].outputs.output_structure
Expand Down
39 changes: 31 additions & 8 deletions examples/workchains/example_base_md_reftraj_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""An example testing the restart calculation handler for geo_opt run in CP2K."""

import os
import random
import sys

import ase.io
Expand Down Expand Up @@ -44,14 +43,9 @@ def example_base(cp2k_code):

# Trajectory.
steps = 20
positions = np.array(
[[[2, 2, 2.73 + 0.05 * random.random()], [2, 2, 2]] for i in range(steps)]
)
positions = np.array([[[2, 2, 2.73 + 0.01 * i], [2, 2, 2]] for i in range(steps)])
cells = np.array(
[
[[4, 0, 0], [0, 4, 0], [0, 0, 4.75 + 0.05 * random.random()]]
for i in range(steps)
]
[[[4, 0, 0], [0, 4, 0], [0, 0, 4.75 + 0.01 * i]] for i in range(steps)]
)
symbols = ["H", "H"]
trajectory = TrajectoryData()
Expand Down Expand Up @@ -172,6 +166,8 @@ def example_base(cp2k_code):
"ERROR, EXT_RESTART section is NOT present in the final_input_parameters."
)
sys.exit(1)

# Check stepids extracted from each individual calculation.
stepids = np.concatenate(
[
called.outputs.output_trajectory.get_stepids()
Expand All @@ -188,6 +184,33 @@ def example_base(cp2k_code):
)
sys.exit(1)

# Check the final trajectory.
final_trajectory = outputs["output_trajectory"]

if np.all(final_trajectory.get_stepids() == np.arange(1, steps + 1)):
print("OK, final trajectory stepids are correct.")
else:
print(
f"ERROR, final trajectory stepids are NOT correct. Expected: {np.arange(1, steps + 1)} but got: {final_trajectory.get_stepids()}"
)
sys.exit(1)

if final_trajectory.get_positions().shape == (steps, len(structure.sites), 3):
print("OK, the shape of the positions array is correct.")
else:
print(
f"ERROR, the shape of the positions array is NOT correct. Expected: {(steps, len(structure.sites), 3)} but got: {final_trajectory.get_positions().shape}"
)
sys.exit(1)

if final_trajectory.get_cells().shape == (steps, 3, 3):
print("OK, the shape of the cells array is correct.")
else:
print(
f"ERROR, the shape of the cells array is NOT correct. Expected: {(steps, 3, 3)} but got: {final_trajectory.get_cells().shape}"
)
sys.exit(1)


@click.command("cli")
@click.argument("codelabel")
Expand Down
Loading
Loading