Implement output trajectory merge in Cp2kBaseWorkChain (#209)
* Implement output trajectory merge in Cp2kBaseWorkChain
* Add a unit test for the merge_trajectory_data function.
* Add an example of the MD restart.
* Modify the reftraj example to simplify checking the trajectories.
---------

Co-authored-by: Carlo Antonio Pignedoli <[email protected]>
yakutovicha and cpignedoli authored Mar 13, 2024
1 parent 4714825 commit 8874fb7
Showing 9 changed files with 373 additions and 33 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -5,6 +5,7 @@
# For further information on the license, see the LICENSE.txt file. #
###############################################################################


FROM aiidateam/aiida-core-with-services:2.5.0

# To prevent the container from exiting prematurely.
2 changes: 1 addition & 1 deletion README.md
@@ -56,7 +56,7 @@ docker build -t aiida_cp2k_test .
Then, you can launch the container:

```bash
-DOKERID=`docker run -it aiida_cp2k_test`
+DOKERID=`docker run -d aiida_cp2k_test`
```
This stores the container ID in the variable `DOKERID`.
You can then run the tests with the following command:
10 changes: 9 additions & 1 deletion aiida_cp2k/parsers/__init__.py
@@ -145,12 +145,19 @@ def _parse_trajectory(self, structure):

positions_traj = []
stepids_traj = []
energies_traj = []
for frame in parse(output_xyz_pos):
_, positions = zip(*frame["atoms"])
positions_traj.append(positions)
-stepids_traj.append(int(frame["comment"].split()[2][:-1]))
+comment_split = frame["comment"].split(",")
+stepids_traj.append(int(comment_split[0].split()[-1]))
+energy_index = next(
+    (i for i, s in enumerate(comment_split) if "E =" in s), None
+)
+energies_traj.append(float(comment_split[energy_index].split()[-1]))
positions_traj = np.array(positions_traj)
stepids_traj = np.array(stepids_traj)
energies_traj = np.array(energies_traj)

cell_traj = None
cell_traj_fname = self.node.process_class._DEFAULT_TRAJECT_CELL_FILE_NAME
@@ -190,6 +197,7 @@ def _parse_trajectory(self, structure):
symbols=symbols,
positions=positions_traj,
)
trajectory.set_array("energies", energies_traj)
if forces_traj is not None:
trajectory.set_array("forces", forces_traj)

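For reference, a minimal sketch of what the new comment-line parsing does. The exact comment layout is an assumption inferred from the parsing logic above (CP2K position files typically carry `i = <step>, time = <t>, E = <energy>` comments):

```python
# Hypothetical frame comment, following the layout the parser above assumes.
comment = " i =        5, time =        2.500, E =    -17.1234567890"

comment_split = comment.split(",")
# Step id: the last whitespace-separated token of the first field ("i = 5").
stepid = int(comment_split[0].split()[-1])
# Energy: the last token of the first field containing "E =";
# a frame without an "E =" field would leave energy_index as None and raise here.
energy_index = next((i for i, s in enumerate(comment_split) if "E =" in s), None)
energy = float(comment_split[energy_index].split()[-1])

print(stepid, energy)  # 5 -17.123456789
```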
6 changes: 6 additions & 0 deletions aiida_cp2k/utils/__init__.py
@@ -6,6 +6,10 @@
###############################################################################
"""AiiDA-CP2K utils"""

from .datatype_helpers import (
merge_trajectory_data_non_unique,
merge_trajectory_data_unique,
)
from .input_generator import (
Cp2kInput,
add_ext_restart_section,
@@ -42,4 +46,6 @@
"merge_Dict",
"ot_has_small_bandgap",
"resize_unit_cell",
"merge_trajectory_data_unique",
"merge_trajectory_data_non_unique",
]
110 changes: 87 additions & 23 deletions aiida_cp2k/utils/datatype_helpers.py
@@ -9,8 +9,8 @@
import re
from collections.abc import Sequence

-from aiida.common import InputValidationError
-from aiida.plugins import DataFactory
+import numpy as np
+from aiida import common, engine, orm, plugins


def _unpack(adict):
@@ -50,7 +50,9 @@ def _kind_element_from_kind_section(section):
try:
kind = section["_"]
except KeyError:
-raise InputValidationError("No default parameter '_' found in KIND section.")
+raise common.InputValidationError(
+    "No default parameter '_' found in KIND section."
+)

try:
element = section["ELEMENT"]
@@ -60,7 +62,7 @@ def _kind_element_from_kind_section(section):
try:
element = match["sym"]
except TypeError:
-raise InputValidationError(
+raise common.InputValidationError(
f"Unable to figure out atomic symbol from KIND '{kind}'."
)

@@ -125,7 +127,7 @@ def _write_gdt(inp, entries, folder, key, fname):
def validate_basissets_namespace(basissets, _):
"""A input_namespace validator to ensure passed down basis sets have the correct type."""
return _validate_gdt_namespace(
-basissets, DataFactory("gaussian.basisset"), "basis set"
+basissets, plugins.DataFactory("gaussian.basisset"), "basis set"
)


@@ -176,7 +178,7 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == element]

if not bsets:
-raise InputValidationError(
+raise common.InputValidationError(
f"No basis set found for kind {kind} or element {element}"
f" in basissets input namespace and not explicitly set."
)
@@ -203,7 +205,7 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == element]

if not bsets:
-raise InputValidationError(
+raise common.InputValidationError(
f"'BASIS_SET {bstype} {bsname}' for element {element} (from kind {kind})"
" not found in basissets input namespace"
)
@@ -213,7 +215,7 @@ def validate_basissets(inp, basissets, structure):
basissets_used.add(bset)
break
else:
-raise InputValidationError(
+raise common.InputValidationError(
f"'BASIS_SET {bstype} {bsname}' for element {element} (from kind {kind})"
" not found in basissets input namespace"
)
@@ -222,14 +224,14 @@ def validate_basissets(inp, basissets, structure):
if not structure and any(
bset not in basissets_used for bset in basissets_specified
):
-raise InputValidationError(
+raise common.InputValidationError(
"No explicit structure given and basis sets not referenced in input"
)

if isinstance(inp["FORCE_EVAL"], Sequence) and any(
kind.name not in explicit_kinds for kind in structure.kinds
):
-raise InputValidationError(
+raise common.InputValidationError(
"Automated BASIS_SET keyword creation is not yet supported with multiple FORCE_EVALs."
" Please explicitly reference a BASIS_SET for each KIND."
)
@@ -250,13 +252,13 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == kind.symbol]

if not bsets:
-raise InputValidationError(
+raise common.InputValidationError(
f"No basis set found in the given basissets for kind '{kind.name}' of your structure."
)

for _, bset in bsets:
if bset.element != kind.symbol:
-raise InputValidationError(
+raise common.InputValidationError(
f"Basis set '{bset.name}' for '{bset.element}' specified"
f" for kind '{kind.name}' (of '{kind.symbol}')."
)
@@ -274,7 +276,7 @@ def validate_basissets(inp, basissets, structure):

for bset in basissets_specified:
if bset not in basissets_used:
-raise InputValidationError(
+raise common.InputValidationError(
f"Basis set '{bset.name}' ('{bset.element}') specified in the basissets"
f" input namespace but not referenced by either input or structure."
)
@@ -287,7 +289,9 @@ def validate_basissets(inp, basissets, structure):

def validate_pseudos_namespace(pseudos, _):
"""A input_namespace validator to ensure passed down pseudopentials have the correct type."""
return _validate_gdt_namespace(pseudos, DataFactory("gaussian.pseudo"), "pseudo")
return _validate_gdt_namespace(
pseudos, plugins.DataFactory("gaussian.pseudo"), "pseudo"
)


def validate_pseudos(inp, pseudos, structure):
@@ -318,7 +322,7 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[element]
except KeyError:
-raise InputValidationError(
+raise common.InputValidationError(
f"No pseudopotential found for kind {kind} or element {element}"
f" in pseudos input namespace and not explicitly set."
)
@@ -335,19 +339,19 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[element]
except KeyError:
-raise InputValidationError(
+raise common.InputValidationError(
f"'POTENTIAL {ptype} {pname}' for element {element} (from kind {kind})"
" not found in pseudos input namespace"
)

if pname not in pseudo.aliases:
-raise InputValidationError(
+raise common.InputValidationError(
f"'POTENTIAL {ptype} {pname}' for element {element} (from kind {kind})"
" not found in pseudos input namespace"
)

if pseudo.element != element:
-raise InputValidationError(
+raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' for '{pseudo.element}' specified"
f" for element '{element}'."
)
@@ -358,14 +362,14 @@ def validate_pseudos(inp, pseudos, structure):
if not structure and any(
pseudo not in pseudos_used for pseudo in pseudos_specified
):
-raise InputValidationError(
+raise common.InputValidationError(
"No explicit structure given and pseudo not referenced in input"
)

if isinstance(inp["FORCE_EVAL"], Sequence) and any(
kind.name not in explicit_kinds for kind in structure.kinds
):
-raise InputValidationError(
+raise common.InputValidationError(
"Automated POTENTIAL keyword creation is not yet supported with multiple FORCE_EVALs."
" Please explicitly reference a POTENTIAL for each KIND."
)
@@ -383,13 +387,13 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[kind.symbol]
except KeyError:
-raise InputValidationError(
+raise common.InputValidationError(
f"No pseudopotential found in the given pseudos"
f" for kind '{kind.name}' (or '{kind.symbol}') of your structure."
)

if pseudo.element != kind.symbol:
-raise InputValidationError(
+raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' for '{pseudo.element}' specified"
f" for kind '{kind.name}' (of '{kind.symbol}')."
)
@@ -402,7 +406,7 @@ def validate_pseudos(inp, pseudos, structure):

for pseudo in pseudos_specified:
if pseudo not in pseudos_used:
-raise InputValidationError(
+raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' specified in the pseudos input namespace"
f" but not referenced by either input or structure."
)
@@ -411,3 +415,63 @@ def validate_pseudos(inp, pseudos, structure):
def write_pseudos(inp, pseudos, folder):
"""Writes the unified POTENTIAL file with the used pseudos"""
_write_gdt(inp, pseudos, folder, "POTENTIAL_FILE_NAME", "POTENTIAL")


def _merge_trajectories_into_dictionary(*trajectories, unique_stepids=False):
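"""Concatenate the arrays of several TrajectoryData nodes into one dictionary."""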
if not trajectories:
return None
final_trajectory_dict = {}

array_names = trajectories[0].get_arraynames()

for array_name in array_names:
if any(array_name not in traj.get_arraynames() for traj in trajectories):
raise ValueError(
f"Array name '{array_name}' not found in all trajectories."
)
merged_array = np.concatenate(
[traj.get_array(array_name) for traj in trajectories], axis=0
)
final_trajectory_dict[array_name] = merged_array

# If unique_stepids is True, we only keep the unique stepids.
# The other arrays are then also reduced to the unique stepids.
if unique_stepids:
stepids = np.concatenate([traj.get_stepids() for traj in trajectories], axis=0)
final_trajectory_dict["stepids"], unique_indices = np.unique(
stepids, return_index=True
)

for array_name in array_names:
final_trajectory_dict[array_name] = final_trajectory_dict[array_name][
unique_indices
]

return final_trajectory_dict


def _dictionary_to_trajectory(trajectory_dict, symbols):
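"""Create a TrajectoryData node from a dictionary of trajectory arrays."""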
final_trajectory = orm.TrajectoryData()
final_trajectory.set_trajectory(
symbols=symbols, positions=trajectory_dict.pop("positions")
)
for array_name, array in trajectory_dict.items():
final_trajectory.set_array(array_name, array)

return final_trajectory


@engine.calcfunction
def merge_trajectory_data_unique(*trajectories):
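"""Merge trajectories, keeping each step id only once (its first occurrence)."""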
trajectory_dict = _merge_trajectories_into_dictionary(
*trajectories, unique_stepids=True
)
return _dictionary_to_trajectory(trajectory_dict, trajectories[0].symbols)


@engine.calcfunction
def merge_trajectory_data_non_unique(*trajectories):
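"""Merge trajectories, keeping all frames even when step ids repeat."""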
trajectory_dict = _merge_trajectories_into_dictionary(
*trajectories, unique_stepids=False
)
return _dictionary_to_trajectory(trajectory_dict, trajectories[0].symbols)
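A hedged usage sketch of the two new calcfunctions. The node PKs are placeholders, and variadic calcfunction arguments require a recent aiida-core (the Dockerfile above pins 2.5.0):

```python
# Hypothetical usage: merge the trajectories of two restarted MD runs.
from aiida import load_profile, orm

from aiida_cp2k.utils import merge_trajectory_data_unique

load_profile()
traj_a = orm.load_node(101)  # placeholder PK, e.g. steps 0..10
traj_b = orm.load_node(102)  # placeholder PK, e.g. steps 10..20 after a restart
merged = merge_trajectory_data_unique(traj_a, traj_b)
# np.unique(..., return_index=True) keeps the sorted unique step ids and the
# index of each id's first occurrence; those indices then slice every other
# array (positions, energies, ...) so the overlapping step appears only once.
print(merged.get_stepids())
```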
20 changes: 20 additions & 0 deletions aiida_cp2k/workchains/base.py
@@ -46,11 +46,31 @@ def setup(self):
super().setup()
self.ctx.inputs = common.AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k'))

def _collect_all_trajectories(self):
"""Collect all trajectories from the child calculations."""
trajectories = []
for called in self.ctx.children:
if isinstance(called, orm.CalcJobNode):
try:
trajectories.append(called.outputs.output_trajectory)
except AttributeError:
pass
return trajectories

def results(self):
super().results()
if self.inputs.cp2k.parameters != self.ctx.inputs.parameters:
self.out('final_input_parameters', self.ctx.inputs.parameters)

trajectories = self._collect_all_trajectories()
if trajectories:
self.report("Work chain completed successfully, collecting all trajectories")
if self.ctx.inputs.parameters.get("GLOBAL", {}).get("RUN_TYPE") == "GEO_OPT":
output_trajectory = utils.merge_trajectory_data_non_unique(*trajectories)
else:
output_trajectory = utils.merge_trajectory_data_unique(*trajectories)
self.out("output_trajectory", output_trajectory)

def overwrite_input_structure(self):
if "output_structure" in self.ctx.children[self.ctx.iteration-1].outputs:
self.ctx.inputs.structure = self.ctx.children[self.ctx.iteration-1].outputs.output_structure
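With this change, a Cp2kBaseWorkChain that ran (and possibly restarted) its calculations exposes the merged trajectory as an `output_trajectory` output. A minimal inspection sketch, assuming a finished work chain node with a placeholder PK:

```python
from aiida import load_profile, orm

load_profile()
wc = orm.load_node(2468)  # placeholder PK of a finished Cp2kBaseWorkChain
if "output_trajectory" in wc.outputs:
    traj = wc.outputs.output_trajectory
    print(traj.get_stepids().shape, traj.get_array("energies").shape)
```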
(The three remaining changed files are not shown here: the new MD restart example, the updated reftraj example, and the unit test for the trajectory merge.)