|
1 | 1 | """Convert NFF Dataset to CHGNet StructureData""" |
2 | 2 |
|
3 | | -from typing import Dict, List |
4 | | - |
| 3 | +from typing import Dict, List, TYPE_CHECKING |
| 4 | +import functools |
| 5 | +import random |
5 | 6 | import torch |
6 | | -from chgnet.data.dataset import StructureData |
| 7 | +import numpy as np |
7 | 8 | from pymatgen.core.structure import Structure |
8 | 9 | from pymatgen.io.ase import AseAtomsAdaptor |
9 | 10 |
|
10 | 11 | from nff.data import Dataset |
11 | 12 | from nff.io import AtomsBatch |
12 | 13 | from nff.utils.cuda import batch_detach, detach |
| 14 | +from chgnet.graph import CrystalGraph, CrystalGraphConverter |
| 15 | +from collections.abc import Sequence |
| 16 | + |
| 17 | +datatype = torch.float32 |
| 18 | + |
| 19 | +class StructureData(Dataset): |
| 20 | + """A simple torch Dataset of structures.""" |
| 21 | + |
| 22 | + def __init__( |
| 23 | + self, |
| 24 | + structures: list[Structure], |
| 25 | + energies: list[float], |
| 26 | + forces: list[Sequence[Sequence[float]]], |
| 27 | + stresses: list[Sequence[Sequence[float]]] | None = None, |
| 28 | + magmoms: list[Sequence[Sequence[float]]] | None = None, |
| 29 | + structure_ids: list[str] | None = None, |
| 30 | + graph_converter: CrystalGraphConverter | None = None, |
| 31 | + ) -> None: |
| 32 | + """Initialize the dataset. |
| 33 | +
|
| 34 | + Args: |
| 35 | + structures (list[dict]): pymatgen Structure objects. |
| 36 | + energies (list[float]): [data_size, 1] |
| 37 | + forces (list[list[float]]): [data_size, n_atoms, 3] |
| 38 | + stresses (list[list[float]], optional): [data_size, 3, 3] |
| 39 | + magmoms (list[list[float]], optional): [data_size, n_atoms, 1] |
| 40 | + structure_ids (list[str], optional): a list of ids to track the structures |
| 41 | + graph_converter (CrystalGraphConverter, optional): Converts the structures |
| 42 | + to graphs. If None, it will be set to CHGNet 0.3.0 converter |
| 43 | + with AtomGraph cutoff = 6A. |
| 44 | +
|
| 45 | + Raises: |
| 46 | + RuntimeError: if the length of structures and labels (energies, forces, |
| 47 | + stresses, magmoms) are not equal. |
| 48 | + """ |
| 49 | + for idx, struct in enumerate(structures): |
| 50 | + if not isinstance(struct, Structure): |
| 51 | + raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}") |
| 52 | + for name in "energies forces stresses magmoms structure_ids".split(): |
| 53 | + labels = locals()[name] |
| 54 | + if labels is not None and len(labels) != len(structures): |
| 55 | + raise RuntimeError( |
| 56 | + f"Inconsistent number of structures and labels: " |
| 57 | + f"{len(structures)=}, len({name})={len(labels)}" |
| 58 | + ) |
| 59 | + self.structures = structures |
| 60 | + self.energies = energies |
| 61 | + self.forces = forces |
| 62 | + self.stresses = stresses |
| 63 | + self.magmoms = magmoms |
| 64 | + self.structure_ids = structure_ids |
| 65 | + self.keys = np.arange(len(structures)) |
| 66 | + random.shuffle(self.keys) |
| 67 | + self.graph_converter = graph_converter or CrystalGraphConverter( |
| 68 | + atom_graph_cutoff=6, bond_graph_cutoff=3 |
| 69 | + ) |
| 70 | + self.failed_idx: list[int] = [] |
| 71 | + self.failed_graph_id: dict[str, str] = {} |
| 72 | + |
| 73 | + def __len__(self) -> int: |
| 74 | + """Get the number of structures in this dataset.""" |
| 75 | + return len(self.keys) |
| 76 | + |
| 77 | + @functools.cache # Cache loaded structures |
| 78 | + def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]: |
| 79 | + """Get one graph for a structure in this dataset. |
| 80 | +
|
| 81 | + Args: |
| 82 | + idx (int): Index of the structure |
| 83 | +
|
| 84 | + Returns: |
| 85 | + crystal_graph (CrystalGraph): graph of the crystal structure |
| 86 | + targets (dict): list of targets. i.e. energy, force, stress |
| 87 | + """ |
| 88 | + if idx not in self.failed_idx: |
| 89 | + graph_id = self.keys[idx] |
| 90 | + try: |
| 91 | + struct = self.structures[graph_id] |
| 92 | + if self.structure_ids is not None: |
| 93 | + mp_id = self.structure_ids[graph_id] |
| 94 | + else: |
| 95 | + mp_id = graph_id |
| 96 | + crystal_graph = self.graph_converter( |
| 97 | + struct, graph_id=graph_id, mp_id=mp_id |
| 98 | + ) |
| 99 | + targets = { |
| 100 | + "e": torch.tensor(self.energies[graph_id], dtype=datatype), |
| 101 | + "f": torch.tensor(self.forces[graph_id], dtype=datatype), |
| 102 | + } |
| 103 | + if self.stresses is not None: |
| 104 | + # Convert VASP stress |
| 105 | + targets["s"] = torch.tensor( |
| 106 | + self.stresses[graph_id], dtype=datatype |
| 107 | + ) * (-0.1) |
| 108 | + if self.magmoms is not None: |
| 109 | + mag = self.magmoms[graph_id] |
| 110 | + # use absolute value for magnetic moments |
| 111 | + if mag is None: |
| 112 | + targets["m"] = None |
| 113 | + else: |
| 114 | + targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype)) |
| 115 | + |
| 116 | + return crystal_graph, targets |
| 117 | + |
| 118 | + # Omit structures with isolated atoms. Return another randomly selected |
| 119 | + # structure |
| 120 | + except Exception: |
| 121 | + struct = self.structures[graph_id] |
| 122 | + self.failed_graph_id[graph_id] = struct.composition.formula |
| 123 | + self.failed_idx.append(idx) |
| 124 | + idx = random.randint(0, len(self) - 1) |
| 125 | + return self.__getitem__(idx) |
| 126 | + else: |
| 127 | + idx = random.randint(0, len(self) - 1) |
| 128 | + return self.__getitem__(idx) |
13 | 129 |
|
14 | 130 |
|
15 | 131 | def convert_nff_to_chgnet_structure_data( |
|
0 commit comments