
Commit d2a9d59

Use str ids in AseReadDataset (#1555)
* use str as system ids in AseReadDataset
* add AseReadDataset to tests
* reduce number of relax steps in test
* mark tests GPU
* fix tests
1 parent 04ba469 commit d2a9d59
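
The central change is that AseReadDataset now falls back to the file path as a plain str for a structure's system id ("sid") instead of the raw Path object. The sketch below is illustrative only: the scratch directory and the bulk("Cu") structure are hypothetical, while the src/pattern config keys and the get_atoms call mirror the code and test fixture in this commit.

# Hypothetical usage sketch (not part of this commit's diff): a directory of
# CIF files read through AseReadDataset should now expose string sids.
from pathlib import Path
from tempfile import mkdtemp

from ase.build import bulk
from ase.io import write

from fairchem.core.datasets.ase_datasets import AseReadDataset

src = Path(mkdtemp())                            # hypothetical scratch directory
write(str(src / "structure_0.cif"), bulk("Cu"))  # a CIF file with no "sid" set

dataset = AseReadDataset(config={"src": str(src), "pattern": "*.cif"})
atoms = dataset.get_atoms(0)
# The fallback sid is now the file path as a str, not a pathlib.Path.
assert isinstance(atoms.info["sid"], str)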

7 files changed: +59 -50 lines changed

ruff.toml

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@ select = [
     "YTT", # flake8-2020
 ]
 ignore = [
+    "C408", # Unnecessary `dict` call (rewrite as a literal)
     "PLR", # Design related pylint codes
     "E501", # Line too long
     "B028", # No explicit stacklevel
@@ -55,7 +56,7 @@ ignore = [
     "RUF005", # concat lists
     "SIM108", # Use ternary operator
     "PT006", # Wrong type passed to first argument
-    "PYI024", # Use `typing.NamedTuple` instead of `collections.namedtuple`
+    "PYI024", # Use `typing.NamedTuple` instead of `collections.namedtuple`,
 ]
 unfixable = [
     "T20", # Removes print statements

src/fairchem/core/datasets/ase_datasets.py

Lines changed: 3 additions & 3 deletions
@@ -235,14 +235,14 @@ def _load_dataset_get_ids(self, config) -> list[Path]:

     def get_atoms(self, idx: int) -> ase.Atoms:
         try:
-            str_file = self.ids[idx]
-            atoms = ase.io.read(str_file, **self.ase_read_args)
+            file_path = self.ids[idx]
+            atoms = ase.io.read(file_path, **self.ase_read_args)
         except Exception as err:
             warnings.warn(f"{err} occured for: {idx}", stacklevel=2)
             raise err

         if "sid" not in atoms.info:
-            atoms.info["sid"] = str_file
+            atoms.info["sid"] = str(file_path)

         return atoms


tests/core/components/conftest.py

Lines changed: 1 addition & 7 deletions
@@ -7,20 +7,14 @@

 from __future__ import annotations

-from itertools import product
 from random import choice

-import numpy as np
 import pytest
-from ase.db import connect
-from pymatgen.core import Structure
-from pymatgen.core.periodic_table import Element

 from fairchem.core import FAIRChemCalculator, pretrained_mlip
-from fairchem.core.datasets import AseDBDataset


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="session")
 def calculator() -> FAIRChemCalculator:
     uma_sm_models = [
         model for model in pretrained_mlip.available_models if "uma-s" in model

tests/core/components/test_calculate_runners.py

Lines changed: 6 additions & 3 deletions
@@ -9,6 +9,8 @@

 import os

+import pytest
+
 from fairchem.core.components.calculate import (
     ElasticityRunner,
     RelaxationRunner,
@@ -17,6 +19,7 @@
 from fairchem.core.datasets.atoms_sequence import AtomsDatasetSequence


+@pytest.mark.gpu()
 def test_elasticity_runner(calculator, dummy_binary_dataset, tmp_path):
     elastic_runner = ElasticityRunner(
         calculator, input_data=AtomsDatasetSequence(dummy_binary_dataset)
@@ -55,6 +58,7 @@ def test_elasticity_runner(calculator, dummy_binary_dataset, tmp_path):
     assert len(results) == len(dummy_binary_dataset) // 2


+@pytest.mark.gpu()
 def test_singlepoint_runner(calculator, dummy_binary_dataset, tmp_path):
     # Test basic instantiation
     singlepoint_runner = SinglePointRunner(
@@ -75,7 +79,6 @@ def test_singlepoint_runner(calculator, dummy_binary_dataset, tmp_path):
         input_data=AtomsDatasetSequence(dummy_binary_dataset),
         calculate_properties=["energy", "forces"],
         normalize_properties_by={"energy": "natoms"},
-        save_target_properties=["energy"],
     )
     results_custom = singlepoint_runner_custom.calculate()
     assert len(results_custom) == len(dummy_binary_dataset)
@@ -95,6 +98,7 @@ def test_singlepoint_runner(calculator, dummy_binary_dataset, tmp_path):
     assert singlepoint_runner.save_state("dummy_checkpoint") is True


+@pytest.mark.gpu()
 def test_relaxation_runner(calculator, dummy_binary_dataset, tmp_path):
     # Test basic instantiation
     relaxation_runner = RelaxationRunner(
@@ -121,9 +125,8 @@ def test_relaxation_runner(calculator, dummy_binary_dataset, tmp_path):
         calculate_properties=["energy", "forces"],
         save_relaxed_atoms=False,
         normalize_properties_by={"energy": "natoms"},
-        save_target_properties=["energy"],
         fmax=0.1,  # relax_kwargs
-        steps=50,  # relax_kwargs
+        steps=5,  # relax_kwargs
     )
     results_custom = relaxation_runner_custom.calculate()
     assert len(results_custom) == len(dummy_binary_dataset)
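
The @pytest.mark.gpu() markers added above let these calculator tests be selected by hardware (e.g. pytest -m gpu on a GPU machine, or -m "not gpu" to deselect them). How the marker is registered is outside this diff; the snippet below is a generic conftest.py sketch of that pattern, assuming a "gpu" marker that is declared and auto-skipped when CUDA is unavailable. It is not taken from this repository's configuration.

# Generic pytest pattern (assumed, not from this repository): declare a "gpu"
# marker and skip gpu-marked tests when no CUDA device is available.
import pytest
import torch


def pytest_configure(config):
    config.addinivalue_line("markers", "gpu: test requires a CUDA-capable GPU")


def pytest_collection_modifyitems(config, items):
    if torch.cuda.is_available():
        return  # GPU present: run the gpu-marked tests normally
    skip_gpu = pytest.mark.skip(reason="no CUDA device available")
    for item in items:
        if "gpu" in item.keywords:
            item.add_marker(skip_gpu)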

tests/core/conftest.py

Lines changed: 27 additions & 16 deletions
@@ -16,12 +16,14 @@
 import numpy as np
 import pytest
 import torch
+from ase.calculators.singlepoint import SinglePointCalculator
 from ase.db import connect
+from ase.io import write
 from pymatgen.core import Structure
 from pymatgen.core.periodic_table import Element
 from syrupy.extensions.amber import AmberSnapshotExtension

-from fairchem.core.datasets import AseDBDataset
+from fairchem.core.datasets.ase_datasets import AseDBDataset, AseReadDataset
 from fairchem.core.units.mlip_unit.mlip_unit import (
     UNIT_INFERENCE_CHECKPOINT,
     UNIT_RESUME_CONFIG,
@@ -192,27 +194,36 @@ def dummy_binary_dataset_path(tmpdir_factory, dummy_element_refs):
             + 0.05 * rng.random() * dummy_element_refs.mean()
         )
         atoms = structure.to_ase_atoms()
-        db.write(
+        atoms.calc = SinglePointCalculator(
             atoms,
-            data={
-                "sid": f"structure_{i}",
-                "energy": energy,
-                "forces": rng.random((2, 3)),
-                "stress": rng.random((3, 3)),
-            },
+            energy=energy,
+            forces=rng.random((2, 3)),
+            stress=rng.random((3, 3)),
         )
+        # write to the lmdb file
+        db.write(atoms, data={"sid": f"structure_{i}"})

-    return tmpdir / "dummy.aselmdb"
+        # write it as a cif file as well
+        write(str(tmpdir / f"structure_{i}.cif"), atoms)
+
+    return tmpdir
+
+
+@pytest.fixture(scope="session", params=["asedb", "cif"])
+def dummy_binary_dataset(dummy_binary_dataset_path, request):
+    config = dict(src=str(dummy_binary_dataset_path))
+
+    if request.param == "cif":
+        config["pattern"] = "*.cif"
+        return AseReadDataset(config=config)
+    else:
+        return AseDBDataset(config=config)


 @pytest.fixture(scope="session")
-def dummy_binary_dataset(dummy_binary_dataset_path):
-    return AseDBDataset(
-        config={
-            "src": str(dummy_binary_dataset_path),
-            "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]},
-        }
-    )
+def dummy_binary_db_dataset(dummy_binary_dataset_path):
+    config = dict(src=str(dummy_binary_dataset_path))
+    return AseDBDataset(config=config)


 @pytest.fixture(autouse=True)
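
In the fixture above, energy, forces, and stress now reach the database as results of an attached SinglePointCalculator rather than as entries in the row's data dict (only "sid" remains user data). The snippet below is a small standalone ASE illustration of that round trip; the file name is hypothetical and it uses a plain SQLite .db backend instead of the fixture's aselmdb file.

# Standalone ASE illustration (hypothetical file name, plain sqlite backend):
# results attached via SinglePointCalculator round-trip through an ase.db file.
import numpy as np
from ase.build import bulk
from ase.calculators.singlepoint import SinglePointCalculator
from ase.db import connect

atoms = bulk("Cu")
atoms.calc = SinglePointCalculator(
    atoms,
    energy=-3.5,
    forces=np.zeros((len(atoms), 3)),
)

with connect("example.db") as db:
    db.write(atoms, data={"sid": "structure_0"})

row = connect("example.db").get(id=1)
print(row.energy)       # -3.5, read back as a calculator result
print(row.data["sid"])  # "structure_0", read back as user data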

tests/core/modules/test_element_references.py

Lines changed: 14 additions & 14 deletions
@@ -26,10 +26,10 @@


 @pytest.fixture(scope="session", params=(True, False))
-def element_refs(dummy_binary_dataset, max_num_elements, request):
+def element_refs(dummy_binary_db_dataset, max_num_elements, request):
     return fit_linear_references(
         ["energy"],
-        dataset=dummy_binary_dataset,
+        dataset=dummy_binary_db_dataset,
         batch_size=16,
         shuffle=False,
         max_num_elements=max_num_elements,
@@ -39,12 +39,12 @@ def element_refs(dummy_binary_dataset, max_num_elements, request):


 def test_apply_linear_references(
-    element_refs, dummy_binary_dataset, dummy_element_refs
+    element_refs, dummy_binary_db_dataset, dummy_element_refs
 ):
     max_noise = 0.05 * dummy_element_refs.mean()

     # check that removing element refs keeps only values within max noise
-    batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True)
+    batch = data_list_collater(list(dummy_binary_db_dataset), otf_graph=True)
     energy = batch.energy.clone().view(len(batch), -1)
     deref_energy = element_refs["energy"].dereference(energy, batch)
     assert all(deref_energy <= max_noise)
@@ -96,14 +96,14 @@ def test_create_element_references(element_refs, tmp_path):


 def test_fit_linear_references(
-    element_refs, dummy_binary_dataset, max_num_elements, dummy_element_refs
+    element_refs, dummy_binary_db_dataset, max_num_elements, dummy_element_refs
 ):
     # create the composition matrix
-    energy = np.array([d.energy for d in dummy_binary_dataset]).reshape(-1)
+    energy = np.array([d.energy for d in dummy_binary_db_dataset]).reshape(-1)
     cmatrix = np.vstack(
         [
             np.bincount(d.atomic_numbers.int().numpy(), minlength=max_num_elements + 1)
-            for d in dummy_binary_dataset
+            for d in dummy_binary_db_dataset
         ]
     )
     mask = cmatrix.sum(axis=0) != 0.0
@@ -130,30 +130,30 @@ def test_fit_linear_references(
     )


-def test_fit_seed_no_seed(dummy_binary_dataset, max_num_elements):
+def test_fit_seed_no_seed(dummy_binary_db_dataset, max_num_elements):
     refs_seed = fit_linear_references(
         ["energy"],
-        dataset=dummy_binary_dataset,
+        dataset=dummy_binary_db_dataset,
         batch_size=16,
-        num_batches=len(dummy_binary_dataset) // 16 - 2,
+        num_batches=len(dummy_binary_db_dataset) // 16 - 2,
         shuffle=True,
         max_num_elements=max_num_elements,
         seed=0,
     )
     refs_seed1 = fit_linear_references(
         ["energy"],
-        dataset=dummy_binary_dataset,
+        dataset=dummy_binary_db_dataset,
         batch_size=16,
-        num_batches=len(dummy_binary_dataset) // 16 - 2,
+        num_batches=len(dummy_binary_db_dataset) // 16 - 2,
         shuffle=True,
         max_num_elements=max_num_elements,
         seed=0,
     )
     refs_noseed = fit_linear_references(
         ["energy"],
-        dataset=dummy_binary_dataset,
+        dataset=dummy_binary_db_dataset,
         batch_size=16,
-        num_batches=len(dummy_binary_dataset) // 16 - 2,
+        num_batches=len(dummy_binary_db_dataset) // 16 - 2,
         shuffle=True,
         max_num_elements=max_num_elements,
         seed=1,

tests/core/modules/test_normalizer.py

Lines changed: 6 additions & 6 deletions
@@ -20,18 +20,18 @@


 @pytest.fixture(scope="session")
-def normalizers(dummy_binary_dataset):
+def normalizers(dummy_binary_db_dataset):
     return fit_normalizers(
         ["energy", "forces"],
         override_values={"forces": {"mean": 0.0}},
-        dataset=dummy_binary_dataset,
+        dataset=dummy_binary_db_dataset,
         batch_size=16,
         shuffle=False,
     )


-def test_norm_denorm(normalizers, dummy_binary_dataset, dummy_element_refs):
-    batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True)
+def test_norm_denorm(normalizers, dummy_binary_db_dataset, dummy_element_refs):
+    batch = data_list_collater(list(dummy_binary_db_dataset), otf_graph=True)
     # test norm and denorm
     for target, normalizer in normalizers.items():
         normed = normalizer.norm(batch[target])
@@ -43,7 +43,7 @@ def test_norm_denorm(normalizers, dummy_binary_dataset, dummy_element_refs):
     )


-def test_create_normalizers(normalizers, dummy_binary_dataset, tmp_path):
+def test_create_normalizers(normalizers, dummy_binary_db_dataset, tmp_path):
     # test that forces mean was overriden
     assert normalizers["forces"].mean.item() == 0.0

@@ -81,7 +81,7 @@ def test_create_normalizers(normalizers, dummy_binary_dataset, tmp_path):
     assert norm.state_dict() == sdict

     # from tensor directly
-    batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True)
+    batch = data_list_collater(list(dummy_binary_db_dataset), otf_graph=True)
     norm = create_normalizer(tensor=batch.energy)
     assert isinstance(norm, Normalizer)
     # assert norm.state_dict() == sdict
