Skip to content

Commit 2015c6c

Browse files
committed
chgnet io
1 parent 65e76b7 commit 2015c6c

File tree

1 file changed

+119
-2
lines changed

1 file changed

+119
-2
lines changed

nff/io/chgnet.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,132 @@
11
"""Convert NFF Dataset to CHGNet StructureData"""
22

33
from typing import Dict, List
4-
4+
import functools
5+
import random
56
import torch
6-
from chgnet.data.dataset import StructureData
7+
import numpy as np
78
from pymatgen.core.structure import Structure
89
from pymatgen.io.ase import AseAtomsAdaptor
910

1011
from nff.data import Dataset
1112
from nff.io import AtomsBatch
1213
from nff.utils.cuda import batch_detach, detach
14+
from chgnet.graph import CrystalGraph, CrystalGraphConverter
15+
if TYPE_CHECKING:
16+
from collections.abc import Sequence
17+
18+
datatype = torch.float32
19+
20+
class StructureData(Dataset):
21+
"""A simple torch Dataset of structures."""
22+
23+
def __init__(
24+
self,
25+
structures: list[Structure],
26+
energies: list[float],
27+
forces: list[Sequence[Sequence[float]]],
28+
stresses: list[Sequence[Sequence[float]]] | None = None,
29+
magmoms: list[Sequence[Sequence[float]]] | None = None,
30+
structure_ids: list[str] | None = None,
31+
graph_converter: CrystalGraphConverter | None = None,
32+
) -> None:
33+
"""Initialize the dataset.
34+
35+
Args:
36+
structures (list[dict]): pymatgen Structure objects.
37+
energies (list[float]): [data_size, 1]
38+
forces (list[list[float]]): [data_size, n_atoms, 3]
39+
stresses (list[list[float]], optional): [data_size, 3, 3]
40+
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
41+
structure_ids (list[str], optional): a list of ids to track the structures
42+
graph_converter (CrystalGraphConverter, optional): Converts the structures
43+
to graphs. If None, it will be set to CHGNet 0.3.0 converter
44+
with AtomGraph cutoff = 6A.
45+
46+
Raises:
47+
RuntimeError: if the length of structures and labels (energies, forces,
48+
stresses, magmoms) are not equal.
49+
"""
50+
for idx, struct in enumerate(structures):
51+
if not isinstance(struct, Structure):
52+
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
53+
for name in "energies forces stresses magmoms structure_ids".split():
54+
labels = locals()[name]
55+
if labels is not None and len(labels) != len(structures):
56+
raise RuntimeError(
57+
f"Inconsistent number of structures and labels: "
58+
f"{len(structures)=}, len({name})={len(labels)}"
59+
)
60+
self.structures = structures
61+
self.energies = energies
62+
self.forces = forces
63+
self.stresses = stresses
64+
self.magmoms = magmoms
65+
self.structure_ids = structure_ids
66+
self.keys = np.arange(len(structures))
67+
random.shuffle(self.keys)
68+
self.graph_converter = graph_converter or CrystalGraphConverter(
69+
atom_graph_cutoff=6, bond_graph_cutoff=3
70+
)
71+
self.failed_idx: list[int] = []
72+
self.failed_graph_id: dict[str, str] = {}
73+
74+
def __len__(self) -> int:
75+
"""Get the number of structures in this dataset."""
76+
return len(self.keys)
77+
78+
@functools.cache # Cache loaded structures
79+
def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
80+
"""Get one graph for a structure in this dataset.
81+
82+
Args:
83+
idx (int): Index of the structure
84+
85+
Returns:
86+
crystal_graph (CrystalGraph): graph of the crystal structure
87+
targets (dict): list of targets. i.e. energy, force, stress
88+
"""
89+
if idx not in self.failed_idx:
90+
graph_id = self.keys[idx]
91+
try:
92+
struct = self.structures[graph_id]
93+
if self.structure_ids is not None:
94+
mp_id = self.structure_ids[graph_id]
95+
else:
96+
mp_id = graph_id
97+
crystal_graph = self.graph_converter(
98+
struct, graph_id=graph_id, mp_id=mp_id
99+
)
100+
targets = {
101+
"e": torch.tensor(self.energies[graph_id], dtype=datatype),
102+
"f": torch.tensor(self.forces[graph_id], dtype=datatype),
103+
}
104+
if self.stresses is not None:
105+
# Convert VASP stress
106+
targets["s"] = torch.tensor(
107+
self.stresses[graph_id], dtype=datatype
108+
) * (-0.1)
109+
if self.magmoms is not None:
110+
mag = self.magmoms[graph_id]
111+
# use absolute value for magnetic moments
112+
if mag is None:
113+
targets["m"] = None
114+
else:
115+
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
116+
117+
return crystal_graph, targets
118+
119+
# Omit structures with isolated atoms. Return another randomly selected
120+
# structure
121+
except Exception:
122+
struct = self.structures[graph_id]
123+
self.failed_graph_id[graph_id] = struct.composition.formula
124+
self.failed_idx.append(idx)
125+
idx = random.randint(0, len(self) - 1)
126+
return self.__getitem__(idx)
127+
else:
128+
idx = random.randint(0, len(self) - 1)
129+
return self.__getitem__(idx)
13130

14131

15132
def convert_nff_to_chgnet_structure_data(

0 commit comments

Comments
 (0)