
Commit ac01168: improved hdf5 format
1 parent 2361431 commit ac01168

File tree: 8 files changed, +384 -42 lines changed


README.md

Lines changed: 7 additions & 0 deletions
```diff
@@ -129,6 +129,13 @@ equitrain-preprocess \
 
 The preprocessing command accepts `.xyz`, `.lmdb`/`.aselmdb`, and `.h5` inputs; LMDB datasets are automatically converted to the native HDF5 format before statistics are computed. XYZ files are parsed through ASE so that lattice vectors, species labels, and per-configuration metadata are retained. The generated HDF5 archive is a lightweight collection of numbered groups where each entry stores positions, atomic numbers, energy, optional forces and stress, the cell matrix, and periodic boundary conditions. Precomputed statistics (means, standard deviations, cutoff radius, atomic energies) are stored alongside and reused by the training entry points.
 
+Under the hood, each processed file is organised as:
+
+- `/structures`: per-configuration metadata (cell, energy, stress, weights, etc.) and pointers into the per-atom arrays.
+- `/positions`, `/forces`, `/atomic_numbers`: flat, chunked arrays sized by the total number of atoms across the dataset. Random reads only touch the slices required for a batch.
+
+This layout keeps the HDF5 file compact even for tens of millions of structures: chunked per-atom arrays avoid the pointer-chasing overhead of variable-length fields, enabling efficient multi-worker dataloaders that issue many small reads concurrently.
+
 <!-- TODO: change this following a notebook style -->
 #### Python Script:
```
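To make the new layout concrete, here is a minimal read sketch (not part of the commit) showing how one configuration could be pulled out of such a file with h5py. The dataset and field names (`structures`, `positions`, `atomic_numbers`, `offset`, `length`) come from the diff below; the file name `train.h5` is a placeholder.

```python
import h5py

# Minimal sketch: read structure 0 from a hypothetical preprocessed file 'train.h5'.
with h5py.File('train.h5', 'r') as f:
    entry = f['structures'][0]                 # per-configuration metadata record
    start = int(entry['offset'])               # index of this structure's first atom
    stop = start + int(entry['length'])        # one past its last atom

    positions = f['positions'][start:stop]     # (n_atoms, 3) slice; only the touched chunks are read
    numbers = f['atomic_numbers'][start:stop]  # per-atom species

print(float(entry['energy']), positions.shape, numbers.shape)
```

Because each record is addressed by an (offset, length) pair, a dataloader worker only deserialises the atoms it actually needs rather than unpacking a variable-length field per structure.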

equitrain/data/format_hdf5/dataset.py

Lines changed: 177 additions & 42 deletions
```diff
@@ -9,29 +9,60 @@
 
 
 class HDF5Dataset:
-    # ! check whether this can be pushed to repo
+    """
+    Lightweight, append-only HDF5 store for ASE ``Atoms`` objects.
+
+    Layout & performance
+    --------------------
+    - ``/structures``: per-configuration metadata (cell, PBC, energy, stress, weights,
+      etc.) plus the offset/length that point into the contiguous arrays below.
+      Structure-level quantities such as stress or dipole live here because they do
+      not scale with atom count.
+    - ``/positions``, ``/forces``, ``/atomic_numbers``: flat, chunked arrays that
+      contain per-atom data. Random access only touches the slices required for a
+      given structure.
+
+    Each per-atom array is chunked by a configurable number of atoms (default 1024),
+    which keeps small random reads in-cache for typical batch sizes and avoids the
+    pointer-chasing penalty of HDF5 variable-length fields. Appending new structures
+    only requires extending those arrays, so the file remains compact and performant
+    even with tens of millions of entries.
+    """
+
     MAGIC_STRING = 'ZVNjaWVuY2UgRXF1aXRyYWlu'
+    STRUCTURES_DATASET = 'structures'
+    POSITIONS_DATASET = 'positions'
+    FORCES_DATASET = 'forces'
+    ATOMIC_NUMBERS_DATASET = 'atomic_numbers'
+    _DEFAULT_CHUNK_ATOMS = 1024
 
     def __init__(self, filename: Path | str, mode: str = 'r'):
         filename = Path(filename)
 
-        if filename.exists():
-            self.file = h5py.File(filename, mode)
-            self.check_magic()
-        else:
-            self.file = h5py.File(filename, mode)
+        self.file = h5py.File(filename, mode)
+
+        if 'MAGIC' not in self.file:
             self.write_magic()
             self.create_dataset()
+        else:
+            self.check_magic()
+            if self.STRUCTURES_DATASET not in self.file:
+                raise OSError(
+                    'HDF5 file was created with an unsupported legacy format. '
+                    'Please regenerate the dataset with the current Equitrain version.'
+                )
 
     def create_dataset(self):
+        if self.STRUCTURES_DATASET in self.file:
+            return
+
         atom_dtype = np.dtype(
             [
-                ('atomic_numbers', h5py.special_dtype(vlen=np.int32)),
-                ('positions', h5py.special_dtype(vlen=np.float64)),
+                ('offset', np.int64),
+                ('length', np.int32),
                 ('cell', np.float64, (3, 3)),
                 ('pbc', np.bool_, (3,)),
                 ('energy', np.float64),
-                ('forces', h5py.special_dtype(vlen=np.float64)),
                 ('stress', np.float64, (6,)),
                 ('virials', np.float64, (3, 3)),
                 ('dipole', np.float64, (3,)),
@@ -42,13 +73,36 @@ def create_dataset(self):
                 ('dipole_weight', np.float32),
             ]
         )
-        # There are some parameters that should be accessible through
-        # command-line options, i.e. chunking and compression
+
         self.file.create_dataset(
-            'atoms',
-            shape=(0,), # Initially empty
-            maxshape=(None,), # Extendable along the first dimension
+            self.STRUCTURES_DATASET,
+            shape=(0,),
+            maxshape=(None,),
             dtype=atom_dtype,
+            chunks=True,
+        )
+
+        chunk_atoms = self._DEFAULT_CHUNK_ATOMS
+        self.file.create_dataset(
+            self.POSITIONS_DATASET,
+            shape=(0, 3),
+            maxshape=(None, 3),
+            dtype=np.float64,
+            chunks=(chunk_atoms, 3),
+        )
+        self.file.create_dataset(
+            self.FORCES_DATASET,
+            shape=(0, 3),
+            maxshape=(None, 3),
+            dtype=np.float64,
+            chunks=(chunk_atoms, 3),
+        )
+        self.file.create_dataset(
+            self.ATOMIC_NUMBERS_DATASET,
+            shape=(0,),
+            maxshape=(None,),
+            dtype=np.int32,
+            chunks=(chunk_atoms,),
         )
 
     def open(self, filename: Path | str, mode: str = 'r'):
@@ -82,21 +136,29 @@ def __getstate__(self):
         return d
 
     def __len__(self):
-        if 'atoms' not in self.file:
-            raise RuntimeError("Dataset 'atoms' does not exist")
-        return self.file['atoms'].shape[0]
+        return self.file[self.STRUCTURES_DATASET].shape[0]
 
     def __getitem__(self, i: int) -> Atoms:
-        entry = self.file['atoms'][i]
-        num_atoms = len(entry['positions']) // 3
+        structures = self.file[self.STRUCTURES_DATASET]
+        entry = structures[i]
+        offset = int(entry['offset'])
+        length = int(entry['length'])
+        end = offset + length
+
+        positions = self.file[self.POSITIONS_DATASET][offset:end]
+        forces = self.file[self.FORCES_DATASET][offset:end]
+        atomic_numbers = self.file[self.ATOMIC_NUMBERS_DATASET][offset:end]
+
         atoms = Atoms(
-            numbers=entry['atomic_numbers'],
-            positions=entry['positions'].reshape((num_atoms, 3)),
+            numbers=atomic_numbers.astype(np.int32, copy=False),
+            positions=positions,
             cell=entry['cell'],
             pbc=entry['pbc'],
         )
         atoms.calc = CachedCalc(
-            entry['energy'], entry['forces'].reshape((num_atoms, 3)), entry['stress']
+            float(entry['energy']),
+            forces,
+            entry['stress'],
         )
         atoms.info['virials'] = entry['virials']
         atoms.info['dipole'] = entry['dipole']
@@ -108,26 +170,99 @@ def __getitem__(self, i: int) -> Atoms:
         return atoms
 
     def __setitem__(self, i: int, atoms: Atoms) -> None:
-        dataset = self.file['atoms']
-        # Extend dataset if necessary
-        if i >= len(dataset):
-            dataset.resize(i + 1, axis=0)
-
-        dataset[i] = (
-            atoms.get_atomic_numbers().astype(np.int32),
-            atoms.get_positions().flatten().astype(np.float64),
-            atoms.get_cell().astype(np.float64),
-            atoms.get_pbc().astype(np.bool_),
-            np.float64(atoms.get_potential_energy()),
-            atoms.get_forces().flatten().astype(np.float64),
-            atoms.get_stress().astype(np.float64),
-            atoms.info['virials'].astype(np.float64),
-            atoms.info['dipole'].astype(np.float64),
-            np.float32(atoms.info.get('energy_weight', 1.0)),
-            np.float32(atoms.info.get('forces_weight', 1.0)),
-            np.float32(atoms.info.get('stress_weight', 1.0)),
-            np.float32(atoms.info.get('virials_weight', 1.0)),
-            np.float32(atoms.info.get('dipole_weight', 1.0)),
+        structures = self.file[self.STRUCTURES_DATASET]
+        positions_ds = self.file[self.POSITIONS_DATASET]
+        forces_ds = self.file[self.FORCES_DATASET]
+        atomic_numbers_ds = self.file[self.ATOMIC_NUMBERS_DATASET]
+
+        numbers = atoms.get_atomic_numbers().astype(np.int32, copy=True)
+        positions = np.asarray(atoms.get_positions(), dtype=np.float64)
+        forces = np.asarray(atoms.get_forces(), dtype=np.float64)
+        length = positions.shape[0]
+
+        if length != numbers.shape[0] or length != forces.shape[0]:
+            raise ValueError('Inconsistent atom count for positions/forces/numbers')
+
+        cell = atoms.get_cell().astype(np.float64)
+        pbc = atoms.get_pbc().astype(np.bool_)
+        energy = np.float64(atoms.get_potential_energy())
+        stress = np.asarray(atoms.get_stress(), dtype=np.float64).reshape(6)
+        virials = np.asarray(
+            atoms.info.get('virials', np.zeros((3, 3), dtype=np.float64)),
+            dtype=np.float64,
+        ).reshape(3, 3)
+        dipole = np.asarray(
+            atoms.info.get('dipole', np.zeros(3, dtype=np.float64)),
+            dtype=np.float64,
+        ).reshape(3)
+        energy_weight = np.float32(atoms.info.get('energy_weight', 1.0))
+        forces_weight = np.float32(atoms.info.get('forces_weight', 1.0))
+        stress_weight = np.float32(atoms.info.get('stress_weight', 1.0))
+        virials_weight = np.float32(atoms.info.get('virials_weight', 1.0))
+        dipole_weight = np.float32(atoms.info.get('dipole_weight', 1.0))
+
+        current_len = len(structures)
+        if i < current_len:
+            entry = structures[i]
+            if int(entry['length']) != length:
+                raise ValueError(
+                    'Cannot change number of atoms for an existing entry in HDF5Dataset'
+                )
+            offset = int(entry['offset'])
+            end = offset + length
+            positions_ds[offset:end] = positions
+            forces_ds[offset:end] = forces
+            atomic_numbers_ds[offset:end] = numbers
+            structures[i] = (
+                offset,
+                length,
+                cell,
+                pbc,
+                energy,
+                stress,
+                virials,
+                dipole,
+                energy_weight,
+                forces_weight,
+                stress_weight,
+                virials_weight,
+                dipole_weight,
+            )
+            return
+
+        if i > current_len:
+            raise IndexError(
+                'Cannot assign to non-contiguous index in HDF5Dataset; '
+                f'expected index {current_len}, received {i}'
+            )
+
+        offset = positions_ds.shape[0]
+        end = offset + length
+
+        positions_ds.resize(end, axis=0)
+        positions_ds[offset:end] = positions
+
+        forces_ds.resize(end, axis=0)
+        forces_ds[offset:end] = forces
+
+        atomic_numbers_ds.resize(end, axis=0)
+        atomic_numbers_ds[offset:end] = numbers
+
+        structures.resize(current_len + 1, axis=0)
+        structures[current_len] = (
+            offset,
+            length,
+            cell,
+            pbc,
+            energy,
+            stress,
+            virials,
+            dipole,
+            energy_weight,
+            forces_weight,
+            stress_weight,
+            virials_weight,
+            dipole_weight,
         )
 
     def check_magic(self):
```
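The commit itself ships no usage example, so the following is a hedged sketch of how the reworked `HDF5Dataset` could be exercised: append one structure and read it back. The import path mirrors the file shown above; the structure, its values, and the file name `water.h5` are placeholders, and opening with mode `'a'` is assumed to create the file when it does not yet exist (standard h5py behaviour).

```python
import numpy as np
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

from equitrain.data.format_hdf5.dataset import HDF5Dataset

# Hypothetical structure; energy/forces/stress are required by __setitem__,
# while virials/dipole fall back to zeros when absent from atoms.info.
atoms = Atoms('H2O', positions=np.random.rand(3, 3), cell=np.eye(3) * 10.0, pbc=True)
atoms.calc = SinglePointCalculator(
    atoms, energy=-14.2, forces=np.zeros((3, 3)), stress=np.zeros(6)
)

ds = HDF5Dataset('water.h5', mode='a')  # placeholder path
ds[len(ds)] = atoms                     # appends must target the next contiguous index
first = ds[0]                           # ASE Atoms with cached energy/forces/stress
print(len(ds), first.positions.shape, first.info['virials'].shape)
```

The contiguity check in `__setitem__` is what keeps the flat `/positions`, `/forces`, and `/atomic_numbers` arrays free of holes: an existing entry can be overwritten in place only if its atom count is unchanged, and new entries can only be appended at index `len(dataset)`.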
Lines changed: 32 additions & 0 deletions (new file)
# MACE-JAX Resources for Equitrain

Utilities and examples for working with pre-trained [MACE](https://github.com/ACEsuit/mace) foundation models in the JAX backend via the companion [mace-jax](https://github.com/ACEsuit/mace-jax) project.

## Contents

- `convert_foundation_to_jax.py` – downloads a Torch MACE foundation model (e.g. the `mp` “small” checkpoint), converts it to MACE-JAX parameters using `mace_jax.cli.mace_torch2jax`, and writes a ready-to-use bundle (`config.json` + `params.msgpack`).

## Usage

Activate an environment that has both `mace` and `mace-jax` installed (including the optional `cuequivariance` extras when available), then run:

```bash
python resources/models/mace-jax/convert_foundation_to_jax.py \
    --source mp \
    --model small \
    --output-dir resources/models/mace-jax/mp-small-jax
```

This produces a directory containing the serialized parameters and a JSON configuration that can be passed directly to Equitrain’s JAX backend (`--model path/to/bundle`) or loaded with the utilities in `mace_jax.tools`.

Use `--source` to pick a different foundation family (`mp`, `off`, `anicc`, `omol`) and `--model` to select a specific variant when multiple sizes exist.

## Dependencies

The script relies on the optional `mace` and `mace-jax` stacks, including their CUDA-enabled cuequivariance extensions. Install them via:

```bash
pip install equitrain[mace,jax]  # or the corresponding mace/mace-jax wheels
```

If the cuequivariance libraries are unavailable, the script will exit after downloading the Torch model; the export step itself requires the accelerated kernels to be importable. Run `python -c "import mace_jax, cuequivariance_ops_torch"` to check whether your environment is configured correctly.
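For orientation, a small sketch of how the exported bundle could be checked before handing it to Equitrain. Only the file names (`config.json`, `params.msgpack`), the bundle path from the example above, and the `--model` flag are taken from this README; everything else is illustrative.

```python
from pathlib import Path

# Hypothetical bundle directory produced by the conversion example above.
bundle = Path('resources/models/mace-jax/mp-small-jax')

for name in ('config.json', 'params.msgpack'):
    status = 'present' if (bundle / name).exists() else 'missing'
    print(f'{name}: {status}')

# The bundle directory itself is what the JAX backend expects, e.g.
#   --model resources/models/mace-jax/mp-small-jax
```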

0 commit comments
