Commit 1d0b0d1

jax predict and further preprocessing options

1 parent 611cff0

15 files changed: +429, -24 lines

equitrain/argparser.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -440,6 +440,12 @@ def get_args_parser(script_type: str) -> argparse.ArgumentParser:
         parser.add_argument(
             '--output-dir', help='Output directory', type=str, default=''
         )
+        parser.add_argument(
+            '--niggli-reduce',
+            help='Apply Niggli reduction to periodic cells before writing HDF5 data',
+            action='store_true',
+            default=False,
+        )
 
     elif script_type == 'train':
         add_common_file_args(parser)
@@ -510,6 +516,12 @@ def get_args_parser(script_type: str) -> argparse.ArgumentParser:
         add_common_data_args(parser)
         add_model_args(parser)
         add_loss_weights_args(parser)
+        parser.add_argument(
+            '--niggli-reduce',
+            help='Apply Niggli reduction before graph construction at inference time',
+            action='store_true',
+            default=False,
+        )
         parser.add_argument(
             '--predict-file',
             help='File with data for which predictions should be computed',
```

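The flag is registered the same way in both parsers, so a quick smoke test can confirm the wiring. A minimal sketch, assuming no other argument of the `'predict'` parser is mandatory at parse time:

```python
# Hypothetical smoke test for the new flag; relies only on get_args_parser
# as shown in the diff above.
from equitrain.argparser import get_args_parser

parser = get_args_parser('predict')

# Enabled explicitly on the command line.
assert parser.parse_args(['--niggli-reduce']).niggli_reduce is True

# Defaults to False when omitted, matching default=False above.
assert parser.parse_args([]).niggli_reduce is False
```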
equitrain/backends/jax_backend.py

Lines changed: 10 additions & 3 deletions

```diff
@@ -470,8 +470,13 @@ def train(args):
     if r_max <= 0.0:
         raise RuntimeError('Model configuration must define a positive `r_max`.')
 
-    train_graphs = atoms_to_graphs(args.train_file, r_max, z_table)
-    valid_graphs = atoms_to_graphs(args.valid_file, r_max, z_table)
+    reduce_cells = getattr(args, 'niggli_reduce', False)
+    train_graphs = atoms_to_graphs(
+        args.train_file, r_max, z_table, niggli_reduce=reduce_cells
+    )
+    valid_graphs = atoms_to_graphs(
+        args.valid_file, r_max, z_table, niggli_reduce=reduce_cells
+    )
 
     if not train_graphs:
         raise RuntimeError('Training dataset is empty.')
@@ -755,7 +760,9 @@ def _host(tree):
 
     test_metrics = None
     if getattr(args, 'test_file', None):
-        test_graphs = atoms_to_graphs(args.test_file, r_max, z_table)
+        test_graphs = atoms_to_graphs(
+            args.test_file, r_max, z_table, niggli_reduce=reduce_cells
+        )
         test_loader = build_loader(
             test_graphs,
             batch_size=args.batch_size,
```

equitrain/backends/jax_evaluate.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -44,7 +44,12 @@ def evaluate(args):
     if r_max <= 0.0:
         raise RuntimeError('Model configuration must define a positive `r_max`.')
 
-    test_graphs = atoms_to_graphs(args.test_file, r_max, z_table)
+    test_graphs = atoms_to_graphs(
+        args.test_file,
+        r_max,
+        z_table,
+        niggli_reduce=getattr(args, 'niggli_reduce', False),
+    )
     if not test_graphs:
         raise RuntimeError('Test dataset is empty.')
```

equitrain/backends/jax_predict.py

Lines changed: 130 additions & 0 deletions (new file)

```python
from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util as jtu
from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable

from equitrain.argparser import check_args_complete
from equitrain.data.backend_jax import atoms_to_graphs, build_loader, make_apply_fn


def _is_multi_device() -> bool:
    return jax.local_device_count() > 1


def _prepare_single_batch(graph):
    def _to_device_array(x):
        if x is None:
            return None
        return jnp.asarray(x)

    return jtu.tree_map(_to_device_array, graph, is_leaf=lambda leaf: leaf is None)


def _stack_or_none(chunks):
    if not chunks:
        return None
    return np.concatenate(chunks, axis=0)


def predict(args):
    check_args_complete(args, 'predict')
    backend = getattr(args, 'backend', 'torch') or 'torch'
    if backend != 'jax':
        raise NotImplementedError(
            f'JAX predict backend invoked with unsupported backend="{backend}".'
        )

    if getattr(args, 'predict_file', None) is None:
        raise ValueError('--predict-file is a required argument for JAX prediction.')
    if getattr(args, 'model', None) is None:
        raise ValueError('--model is a required argument for JAX prediction.')

    if _is_multi_device():
        raise NotImplementedError(
            'JAX prediction currently supports single-device runs only. '
            'Set XLA flags to limit execution to one device.'
        )

    bundle = _load_bundle(args.model, dtype=args.dtype)

    atomic_numbers = bundle.config.get('atomic_numbers')
    if not atomic_numbers:
        raise RuntimeError('Model configuration is missing `atomic_numbers`.')
    z_table = JaxAtomicNumberTable(atomic_numbers)

    r_max = (
        float(args.r_max)
        if getattr(args, 'r_max', None)
        else float(bundle.config.get('r_max', 0.0))
    )
    if r_max <= 0.0:
        raise RuntimeError(
            'Model configuration must define a positive `r_max`, or override via --r-max.'
        )

    graphs = atoms_to_graphs(
        args.predict_file,
        r_max,
        z_table,
        niggli_reduce=getattr(args, 'niggli_reduce', False),
    )
    loader = build_loader(
        graphs,
        batch_size=args.batch_size,
        shuffle=False,
        max_nodes=args.batch_max_nodes,
        max_edges=args.batch_max_edges,
        drop=getattr(args, 'batch_drop', False),
    )
    if loader is None:
        raise RuntimeError('Prediction dataset is empty.')

    wrapper = _create_wrapper(
        bundle,
        compute_force=getattr(args, 'forces_weight', 0.0) > 0.0,
        compute_stress=getattr(args, 'stress_weight', 0.0) > 0.0,
    )
    apply_fn = make_apply_fn(wrapper, num_species=len(z_table))
    apply_fn = jax.jit(apply_fn)

    energies: list[np.ndarray] = []
    forces: list[np.ndarray] = []
    stresses: list[np.ndarray] = []

    for batch in loader:
        micro_batches = batch if isinstance(batch, list) else [batch]
        for micro in micro_batches:
            prepared = _prepare_single_batch(micro)
            outputs = jax.device_get(apply_fn(bundle.params, prepared))
            energy_pred = np.asarray(outputs['energy'])
            energies.append(energy_pred.reshape(-1))

            if outputs.get('forces') is not None:
                forces.append(np.asarray(outputs['forces']))
            if outputs.get('stress') is not None:
                stresses.append(np.asarray(outputs['stress']))

    return _stack_or_none(energies), _stack_or_none(forces), _stack_or_none(stresses)


def _load_bundle(model_path: str, dtype: str):
    from equitrain.backends.jax_utils import load_model_bundle as _load_model_bundle

    return _load_model_bundle(model_path, dtype=dtype)


def _create_wrapper(bundle, *, compute_force: bool, compute_stress: bool):
    from equitrain.backends.jax_wrappers import MaceWrapper as JaxMaceWrapper

    return JaxMaceWrapper(
        module=bundle.module,
        config=bundle.config,
        compute_force=compute_force,
        compute_stress=compute_stress,
    )


__all__ = ['predict']
```

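Because `predict` pulls everything from the argument namespace, it can also be driven programmatically. A hedged sketch follows; the file paths are hypothetical, and `check_args_complete` may insist on further fields not listed here:

```python
# Sketch only: field names mirror the attributes predict() reads above.
from types import SimpleNamespace

from equitrain.backends.jax_predict import predict

args = SimpleNamespace(
    backend='jax',
    model='checkpoints/mace.eqx',   # hypothetical model bundle path
    predict_file='predict.h5',      # hypothetical HDF5 input
    dtype='float32',
    r_max=None,                     # fall back to the model config's r_max
    batch_size=16,
    batch_max_nodes=None,
    batch_max_edges=None,
    batch_drop=False,
    forces_weight=1.0,              # > 0.0 switches force outputs on
    stress_weight=0.0,              # 0.0 leaves stresses as None
    niggli_reduce=True,
)

energies, forces, stresses = predict(args)  # stresses is None here
```

Note that force and stress computation is gated on the loss weights rather than on dedicated flags, so callers opt in by setting `forces_weight` or `stress_weight` above zero.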
equitrain/backends/torch_predict.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -11,6 +11,7 @@
 from equitrain.data.atomic import AtomicNumberTable
 from equitrain.data.backend_torch.atoms_to_graphs import AtomsToGraphs
 from equitrain.data.backend_torch.loaders import get_dataloader
+from equitrain.data.configuration import niggli_reduce_inplace
 
 
 def predict_graphs(
@@ -59,6 +60,7 @@ def predict_atoms(
     pin_memory=False,
     batch_size=12,
     device=None,
+    niggli_reduce: bool = False,
 ) -> list[torch.Tensor]:
     atoms_to_graphs = AtomsToGraphs(
         z_table,
@@ -72,7 +74,12 @@ def predict_atoms(
         r_pbc=True,
     )
 
-    graph_list = [atoms_to_graphs.convert(atom) for atom in atoms_list]
+    graph_list = []
+    for atom in atoms_list:
+        atoms_copy = atom.copy()
+        if niggli_reduce:
+            niggli_reduce_inplace(atoms_copy)
+        graph_list.append(atoms_to_graphs.convert(atoms_copy))
 
     return predict_graphs(
         model,
@@ -93,6 +100,7 @@ def predict_structures(
     pin_memory=False,
     batch_size=12,
     device=None,
+    niggli_reduce: bool = False,
 ) -> list[torch.Tensor]:
     atoms_list = [AseAtomsAdaptor.get_atoms(structure) for structure in structure_list]
     return predict_atoms(
@@ -104,6 +112,7 @@ def predict_structures(
         pin_memory=pin_memory,
         batch_size=batch_size,
         device=device,
+        niggli_reduce=niggli_reduce,
     )
```

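The rewritten loop copies each `Atoms` object before (optionally) reducing it, so the caller's `atoms_list` is never mutated, with or without the flag. The same pattern can be illustrated with plain ASE, independent of equitrain:

```python
# Pure-ASE illustration of the copy-then-reduce pattern used above.
from ase import Atoms
from ase.build.tools import niggli_reduce

# A deliberately skewed periodic cell that Niggli reduction will tidy up.
original = Atoms('Cu', positions=[[0.0, 0.0, 0.0]], pbc=True,
                 cell=[[3.6, 0.0, 0.0], [3.4, 3.6, 0.0], [0.0, 0.0, 3.6]])

reduced = original.copy()
niggli_reduce(reduced)  # mutates only the copy

assert len(reduced) == len(original)                              # same atoms
assert abs(reduced.get_volume() - original.get_volume()) < 1e-10  # same volume
```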
equitrain/data/backend_jax/atoms_to_graphs.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -10,14 +10,19 @@
 from mace_jax.data.utils import Configuration as JaxConfiguration
 from mace_jax.data.utils import graph_from_configuration
 
-from equitrain.data.configuration import Configuration as EqConfiguration
+from equitrain.data.configuration import (
+    Configuration as EqConfiguration,
+    niggli_reduce_inplace,
+)
 from equitrain.data.format_hdf5.dataset import HDF5Dataset
 
 
 def atoms_to_graphs(
     data_path: Path | str,
     r_max: float,
     z_table: JaxAtomicNumberTable,
+    *,
+    niggli_reduce: bool = False,
 ) -> list[jraph.GraphsTuple]:
     if data_path is None:
         return []
@@ -27,6 +32,9 @@ def atoms_to_graphs(
     try:
         for idx in range(len(dataset)):
             atoms = dataset[idx]
+            if niggli_reduce:
+                atoms = atoms.copy()
+                niggli_reduce_inplace(atoms)
             eq_conf = EqConfiguration.from_atoms(atoms)
             jax_conf = JaxConfiguration(
                 atomic_numbers=eq_conf.atomic_numbers,
```

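The new argument is keyword-only, so existing positional call sites keep working unchanged. A hedged usage sketch (file name and species list hypothetical):

```python
from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable

from equitrain.data.backend_jax import atoms_to_graphs

z_table = JaxAtomicNumberTable([1, 8])  # e.g. an O/H-only dataset
graphs = atoms_to_graphs(
    'valid.h5',          # hypothetical HDF5 file
    4.5,                 # r_max, in the dataset's length units
    z_table,
    niggli_reduce=True,  # keyword-only; omitting it preserves old behaviour
)
```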
equitrain/data/backend_torch/loaders.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -80,7 +80,13 @@ def get_dataloader(
     if data_file is None:
         return None
 
-    data_set = HDF5GraphDataset(data_file, r_max=r_max, atomic_numbers=atomic_numbers)
+    niggli_reduce = getattr(args, 'niggli_reduce', False)
+    data_set = HDF5GraphDataset(
+        data_file,
+        r_max=r_max,
+        atomic_numbers=atomic_numbers,
+        niggli_reduce=niggli_reduce,
+    )
 
     pin_memory = _should_pin_memory(args.pin_memory, accelerator)
     num_workers = _resolve_num_workers(args.workers, accelerator)
```

equitrain/data/configuration.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -156,3 +156,17 @@ def get_forces(self, apply_constraint=False):
 
     def get_stress(self, apply_constraint=False):
         return self.stress
+
+
+def niggli_reduce_inplace(atoms):
+    """Apply an in-place Niggli reduction when periodic directions exist."""
+    from ase.build.tools import niggli_reduce as _niggli_reduce
+
+    pbc = getattr(atoms, 'pbc', None)
+    if pbc is None:
+        return atoms
+    if not np.any(pbc):
+        return atoms
+
+    _niggli_reduce(atoms)
+    return atoms
```

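Because of the `pbc` guard, isolated molecules pass through untouched, so callers may apply the helper unconditionally. A small sketch of the no-op path, using only the import location shown above:

```python
from ase import Atoms

from equitrain.data.configuration import niggli_reduce_inplace

# pbc defaults to all-False for a bare molecule, so the helper returns
# the same object without attempting a reduction.
water = Atoms('OH2', positions=[[0.0, 0.0, 0.0],
                                [0.96, 0.0, 0.0],
                                [-0.24, 0.93, 0.0]])
assert niggli_reduce_inplace(water) is water
```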
equitrain/data/format_hdf5/dataset.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -5,7 +5,7 @@
 from ase import Atoms
 
 from equitrain.data.atomic import AtomicNumberTable
-from equitrain.data.configuration import CachedCalc
+from equitrain.data.configuration import CachedCalc, niggli_reduce_inplace
 
 
 class HDF5Dataset:
@@ -295,11 +295,13 @@ def __init__(
         r_max: float,
         atomic_numbers: AtomicNumberTable,
         *,
+        niggli_reduce: bool = False,
         atoms_to_graphs_cls=None,
         **kwargs,
     ):
         super().__init__(filename, mode='r', **kwargs)
 
+        self._niggli_reduce = niggli_reduce
         if atoms_to_graphs_cls is None:
             from equitrain.data.backend_torch import (
                 AtomsToGraphs as atoms_to_graphs_cls,
@@ -317,6 +319,8 @@ def __init__(
 
     def __getitem__(self, index):
         atoms = super().__getitem__(index)
+        if self._niggli_reduce:
+            niggli_reduce_inplace(atoms)
         graph = self.converter.convert(atoms)
         graph.idx = index
```

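Downstream, the option reaches the torch path through the constructor shown above. A hedged construction sketch; the import location and the `AtomicNumberTable` constructor are assumptions based on how both are used elsewhere in this commit:

```python
from equitrain.data.atomic import AtomicNumberTable
from equitrain.data.format_hdf5.dataset import HDF5GraphDataset

z_table = AtomicNumberTable([1, 6, 8])  # hypothetical species set
dataset = HDF5GraphDataset(
    'train.h5',             # hypothetical file written by the preprocessor
    r_max=5.0,
    atomic_numbers=z_table,
    niggli_reduce=True,     # each Atoms object is reduced before conversion
)
graph = dataset[0]
```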
equitrain/data/format_lmdb/lmdb.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -130,6 +130,7 @@ def convert_lmdb_to_hdf5(
     dst: Path | str,
     *,
     config: Mapping | None = None,
+    atoms_transform=None,
     overwrite: bool = False,
     show_progress: bool = False,
 ) -> Path:
@@ -145,6 +146,9 @@ def convert_lmdb_to_hdf5(
         ``overwrite`` is ``True``.
     config:
         Optional dictionary passed to ``AseDBDataset`` (e.g. metadata entries).
+    atoms_transform:
+        Optional callable applied to each ``Atoms`` object prior to storage
+        (e.g. lattice reductions).
     overwrite:
         When ``False`` (default) an existing destination file raises ``FileExistsError``.
     show_progress:
@@ -174,6 +178,8 @@ def convert_lmdb_to_hdf5(
     with HDF5Dataset(dst, mode='w') as storage:
         for index, record in enumerate(iterator):
             atoms = lmdb_entry_to_atoms(record)
+            if atoms_transform is not None:
+                atoms = atoms_transform(atoms)
             storage[index] = atoms
 
     return dst
```

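The new hook composes directly with the helper from `equitrain/data/configuration.py`, since `niggli_reduce_inplace` returns the (possibly reduced) `Atoms` object. A hedged sketch with hypothetical paths; the source path is passed positionally because its parameter name is not shown in this diff:

```python
from equitrain.data.configuration import niggli_reduce_inplace
from equitrain.data.format_lmdb.lmdb import convert_lmdb_to_hdf5

convert_lmdb_to_hdf5(
    'dataset.lmdb',                         # hypothetical LMDB source
    'dataset.h5',                           # hypothetical HDF5 destination
    atoms_transform=niggli_reduce_inplace,  # applied before each write
    overwrite=True,
)
```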