Skip to content

Commit 7385124

Browse files
committed
Update MACE, CHGNet IO
1 parent e645de4 commit 7385124

File tree

3 files changed

+135
-32
lines changed

3 files changed

+135
-32
lines changed

nff/io/chgnet.py

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,131 @@
11
"""Convert NFF Dataset to CHGNet StructureData"""
22

3-
from typing import Dict, List
4-
3+
from typing import Dict, List, TYPE_CHECKING
4+
import functools
5+
import random
56
import torch
6-
from chgnet.data.dataset import StructureData
7+
import numpy as np
78
from pymatgen.core.structure import Structure
89
from pymatgen.io.ase import AseAtomsAdaptor
910

1011
from nff.data import Dataset
1112
from nff.io import AtomsBatch
1213
from nff.utils.cuda import batch_detach, detach
14+
from chgnet.graph import CrystalGraph, CrystalGraphConverter
15+
from collections.abc import Sequence
16+
17+
datatype = torch.float32
18+
19+
class StructureData(Dataset):
    """A simple dataset of structures, lazily converted to CHGNet crystal graphs.

    Each item is a ``(CrystalGraph, targets)`` pair. Structures whose graph
    conversion fails are remembered in ``failed_idx`` / ``failed_graph_id`` and
    are transparently replaced by another randomly drawn structure.

    NOTE(review): ``Dataset`` is the name imported from ``nff.data`` in this
    file and its ``__init__`` is never called here -- confirm that this base
    class (rather than ``torch.utils.data.Dataset``) is the intended one.
    """

    def __init__(
        self,
        structures: list[Structure],
        energies: list[float],
        forces: list[Sequence[Sequence[float]]],
        stresses: list[Sequence[Sequence[float]]] | None = None,
        magmoms: list[Sequence[Sequence[float]]] | None = None,
        structure_ids: list[str] | None = None,
        graph_converter: CrystalGraphConverter | None = None,
    ) -> None:
        """Initialize the dataset.

        Args:
            structures (list[Structure]): pymatgen Structure objects.
            energies (list[float]): [data_size, 1]
            forces (list[list[float]]): [data_size, n_atoms, 3]
            stresses (list[list[float]], optional): [data_size, 3, 3]
            magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
            structure_ids (list[str], optional): a list of ids to track the structures
            graph_converter (CrystalGraphConverter, optional): Converts the structures
                to graphs. If None, a ``CrystalGraphConverter`` with
                ``atom_graph_cutoff=6`` and ``bond_graph_cutoff=3`` is created.

        Raises:
            ValueError: if an entry of ``structures`` is not a pymatgen Structure.
            RuntimeError: if the length of structures and labels (energies, forces,
                stresses, magmoms, structure_ids) are not equal.
        """
        for idx, struct in enumerate(structures):
            if not isinstance(struct, Structure):
                raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")

        # Explicit mapping instead of poking into ``locals()``.
        label_sets = {
            "energies": energies,
            "forces": forces,
            "stresses": stresses,
            "magmoms": magmoms,
            "structure_ids": structure_ids,
        }
        for name, labels in label_sets.items():
            if labels is not None and len(labels) != len(structures):
                raise RuntimeError(
                    f"Inconsistent number of structures and labels: "
                    f"{len(structures)=}, len({name})={len(labels)}"
                )

        self.structures = structures
        self.energies = energies
        self.forces = forces
        self.stresses = stresses
        self.magmoms = magmoms
        self.structure_ids = structure_ids

        # ``keys`` is a shuffled permutation mapping dataset index -> structure index.
        self.keys = np.arange(len(structures))
        random.shuffle(self.keys)

        self.graph_converter = graph_converter or CrystalGraphConverter(
            atom_graph_cutoff=6, bond_graph_cutoff=3
        )

        # Dataset indices whose conversion failed, and the formula of each
        # failed structure keyed by its structure index (graph id).
        self.failed_idx: list[int] = []
        self.failed_graph_id: dict[int, str] = {}

        # Per-instance cache of successfully converted items. This replaces
        # ``functools.cache`` on ``__getitem__``, which would hold a global
        # reference to every instance and memoize random fallback samples.
        self._graph_cache: dict[int, tuple[CrystalGraph, dict]] = {}

    def __len__(self) -> int:
        """Get the number of structures in this dataset."""
        return len(self.keys)

    def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
        """Get one graph for a structure in this dataset.

        If the requested structure cannot be converted, it is recorded as
        failed and another randomly selected structure is returned instead.

        Args:
            idx (int): Index of the structure

        Returns:
            crystal_graph (CrystalGraph): graph of the crystal structure
            targets (dict): targets keyed by "e", "f" and, when available,
                "s" and "m".

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if every structure in the dataset failed conversion.
        """
        n_items = len(self)
        # Normalize negative indices so that bookkeeping uses one canonical key.
        if -n_items <= idx < 0:
            idx += n_items

        # Iterative retry instead of recursion: many consecutive failures must
        # not exhaust the interpreter's recursion limit.
        while True:
            cached = self._graph_cache.get(idx)
            if cached is not None:
                return cached

            if idx not in self.failed_idx:
                # Deliberately outside the ``try``: an out-of-range index is a
                # caller error and must propagate.
                graph_id = self.keys[idx]
                try:
                    item = self._build_item(graph_id)
                except Exception:
                    # Best-effort by design: omit structures that cannot be
                    # converted (e.g. isolated atoms) and fall through to
                    # draw a replacement.
                    struct = self.structures[graph_id]
                    self.failed_graph_id[graph_id] = struct.composition.formula
                    self.failed_idx.append(idx)
                else:
                    self._graph_cache[idx] = item
                    return item

            if len(self.failed_idx) >= n_items:
                raise RuntimeError(
                    "All structures in the dataset failed graph conversion: "
                    f"{len(self.failed_idx)} of {n_items}"
                )
            idx = random.randint(0, n_items - 1)

    def _build_item(self, graph_id) -> tuple[CrystalGraph, dict]:
        """Convert the structure at ``graph_id`` to a graph and collect its targets."""
        struct = self.structures[graph_id]
        if self.structure_ids is not None:
            mp_id = self.structure_ids[graph_id]
        else:
            mp_id = graph_id
        crystal_graph = self.graph_converter(struct, graph_id=graph_id, mp_id=mp_id)

        targets = {
            "e": torch.tensor(self.energies[graph_id], dtype=datatype),
            "f": torch.tensor(self.forces[graph_id], dtype=datatype),
        }
        if self.stresses is not None:
            # Convert VASP stress (presumably a unit and sign-convention
            # change -- TODO confirm against the training labels).
            targets["s"] = torch.tensor(self.stresses[graph_id], dtype=datatype) * (-0.1)
        if self.magmoms is not None:
            mag = self.magmoms[graph_id]
            # Use absolute value for magnetic moments; keep None for
            # structures without magmom labels.
            if mag is None:
                targets["m"] = None
            else:
                targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))

        return crystal_graph, targets
13129

14130

15131
def convert_nff_to_chgnet_structure_data(

nff/io/mace.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import os
33
import urllib
4+
from pathlib import Path
45
from collections.abc import Iterable
56
from typing import Dict, List, Optional, Tuple, Union
67

@@ -17,6 +18,7 @@
1718
from e3nn import o3
1819
from nff.data import Dataset
1920
from nff.utils.cuda import detach
21+
from mace.calculators.foundations_models import download_mace_mp_checkpoint, mace_mp_names
2022

2123
# get the path to NFF models dir, which is the parent directory of this file
2224
module_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "..", "..", "..", "models"))
@@ -37,7 +39,7 @@ def _check_non_zero(std):
3739
return std
3840

3941

40-
def get_mace_mp_model_path(model: Optional[str] = None, supress_print=True) -> str:
42+
def get_mace_foundtion_model_path(model: Optional[str] = None, supress_print=True) -> str:
4143
"""Get the default MACE MP model. Replicated from the MACE codebase,
4244
Copyright (c) 2022 ACEsuit/mace and licensed under the MIT license.
4345
@@ -52,32 +54,17 @@ def get_mace_mp_model_path(model: Optional[str] = None, supress_print=True) -> s
5254
Returns:
5355
str: path to the model
5456
"""
55-
if model in (None, "medium") and os.path.isfile(LOCAL_MODEL_PATH):
56-
model_path = LOCAL_MODEL_PATH
57+
try:
58+
if model in mace_mp_names or str(model).startswith("https:"):
59+
model_path = download_mace_mp_checkpoint(model)
60+
else:
61+
if not Path(model).exists():
62+
raise FileNotFoundError(f"{model} not found locally")
63+
model_path = model
5764
if not supress_print:
58-
print(f"Using local medium Materials Project MACE model for MACECalculator {model}")
59-
elif model in (None, "small", "medium", "large") or str(model).startswith("https:"):
60-
try:
61-
checkpoint_url = (
62-
MACE_URLS.get(model, MACE_URLS["medium"]) if model in (None, "small", "medium", "large") else model
63-
)
64-
cache_dir = os.path.expanduser("~/.cache/mace")
65-
checkpoint_url_name = "".join(c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_")
66-
model_path = f"{cache_dir}/{checkpoint_url_name}"
67-
if not os.path.isfile(model_path):
68-
os.makedirs(cache_dir, exist_ok=True)
69-
# download and save to disk
70-
urllib.request.urlretrieve(checkpoint_url, model_path)
71-
if not supress_print:
72-
print(f"Downloading MACE model from {checkpoint_url!r}")
73-
print(f"Cached MACE model to {model_path}")
74-
if not supress_print:
75-
msg = f"Loading Materials Project MACE with {model_path}"
76-
print(msg)
77-
except Exception as exc:
78-
raise RuntimeError("Model download failed and no local model found") from exc
79-
else:
80-
raise RuntimeError("Model download failed and no local model found")
65+
print(f"Using MACE mdoel with {model_path}")
66+
except Exception as exc:
67+
raise RuntimeError("Model download failed or no local model found") from exc
8168

8269
return model_path
8370

nff/nn/models/mace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
NffBatch,
2626
get_atomic_number_table_from_zs,
2727
get_init_kwargs_from_model,
28-
get_mace_mp_model_path,
28+
get_mace_foundtion_model_path,
2929
)
3030

3131

@@ -293,8 +293,8 @@ def load_foundations(
293293
Returns:
294294
NffScaleMACE: NffScaleMACE foundational model.
295295
"""
296-
mace_model_path = get_mace_mp_model_path(model)
297-
mace_model = torch.load(mace_model_path, map_location=map_location)
296+
mace_model_path = get_mace_foundtion_model_path(model)
297+
mace_model = torch.load(mace_model_path, map_location=map_location, weights_only=False)
298298
init_params = get_init_kwargs_from_model(mace_model)
299299
model_dtype = get_model_dtype(mace_model)
300300
if default_dtype == "":

0 commit comments

Comments
 (0)