Skip to content

Commit 4f1b0f6

Browse files
Fix data model (libAtoms#117)
* Update ASE * Fix duplicate calculated results * Fix getting Atoms results * Fix derived data * Test extxyz file * Fix to ASE function * Test to ASE function * additional tests * Update tests/test_abstract_model.py --------- Co-authored-by: ElliottKasoar <[email protected]> Co-authored-by: Tamas K Stenczel <[email protected]>
1 parent e317bee commit 4f1b0f6

File tree

3 files changed

+280
-22
lines changed

3 files changed

+280
-22
lines changed

abcd/model.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def __iter__(self):
146146

147147
@classmethod
148148
def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
149-
"""ASE's original implementation"""
149+
"""Extract data from Atoms info, arrays and results."""
150+
if not isinstance(atoms, Atoms):
151+
raise ValueError("atoms must be an ASE Atoms object.")
150152

151153
reserved_keys = {
152154
"n_atoms",
@@ -157,11 +159,13 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
157159
"derived",
158160
"formula",
159161
}
162+
160163
arrays_keys = set(atoms.arrays.keys())
161164
info_keys = set(atoms.info.keys())
162-
results_keys = (
163-
set(atoms.calc.results.keys()) if store_calc and atoms.calc else {}
164-
)
165+
if store_calc and atoms.calc:
166+
results_keys = atoms.calc.results.keys() - (arrays_keys | info_keys)
167+
else:
168+
results_keys = set()
165169

166170
all_keys = (reserved_keys, arrays_keys, info_keys, results_keys)
167171
if len(set.union(*all_keys)) != sum(map(len, all_keys)):
@@ -172,46 +176,43 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
172176

173177
n_atoms = len(atoms)
174178

175-
dct = {
179+
data = {
176180
"n_atoms": n_atoms,
177181
"cell": atoms.cell.tolist(),
178182
"pbc": atoms.pbc.tolist(),
179183
"formula": atoms.get_chemical_formula(),
180184
}
181185

182-
info_keys.update({"n_atoms", "cell", "pbc", "formula"})
186+
info_keys.update(data.keys())
183187

184188
for key, value in atoms.arrays.items():
185189
if isinstance(value, np.ndarray):
186-
dct[key] = value.tolist()
190+
data[key] = value.tolist()
187191
else:
188-
dct[key] = value
192+
data[key] = value
189193

190194
for key, value in atoms.info.items():
191195
if isinstance(value, np.ndarray):
192-
dct[key] = value.tolist()
196+
data[key] = value.tolist()
193197
else:
194-
dct[key] = value
198+
data[key] = value
195199

196200
if store_calc and atoms.calc:
197-
dct["calculator_name"] = atoms.calc.__class__.__name__
198-
dct["calculator_parameters"] = atoms.calc.todict()
201+
data["calculator_name"] = atoms.calc.__class__.__name__
202+
data["calculator_parameters"] = atoms.calc.todict()
199203
info_keys.update({"calculator_name", "calculator_parameters"})
200204

201205
for key, value in atoms.calc.results.items():
202-
203206
if isinstance(value, np.ndarray):
204-
if value.shape[0] == n_atoms:
205-
arrays_keys.update(key)
206-
else:
207-
info_keys.update(key)
208-
dct[key] = value.tolist()
207+
data[key] = value.tolist()
208+
else:
209+
data[key] = value
209210

210211
item.arrays_keys = list(arrays_keys)
211212
item.info_keys = list(info_keys)
212213
item.results_keys = list(results_keys)
213214

214-
item.update(dct)
215+
item.update(data)
215216

216217
if extra_info:
217218
item.info_keys.extend(extra_info.keys())
@@ -240,6 +241,7 @@ def to_ase(self):
240241
# atoms.calc = get_calculator(data['results']['calculator_name'])(**params)
241242

242243
params = self.pop("calculator_parameters", {})
244+
info_keys -= {"calculator_parameters"}
243245

244246
atoms.calc = SinglePointCalculator(atoms, **params)
245247
atoms.calc.results.update((key, self[key]) for key in results_keys)
@@ -256,14 +258,14 @@ def pre_save(self):
256258

257259
if cell:
258260
volume = abs(np.linalg.det(cell)) # atoms.get_volume()
259-
self["volume"] = volume
260261
self.derived_keys.append("volume")
262+
self["volume"] = volume
261263

262264
virial = self.get("virial")
263265
if virial:
264266
# pressure P = -1/3 Tr(stress) = -1/3 Tr(virials/volume)
265-
self["pressure"] = -1 / 3 * np.trace(virial / volume)
266267
self.derived_keys.append("pressure")
268+
self["pressure"] = -1 / 3 * np.trace(virial / volume)
267269

268270
# 'elements': Counter(atoms.get_chemical_symbols()),
269271
self["elements"] = Counter(str(element) for element in self["numbers"])

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ numpy = "^1.26"
1616
tqdm = "^4.66"
1717
pymongo = "^4.7.3"
1818
matplotlib = "^3.9"
19-
ase = "3.22.1"
19+
ase = "^3.23"
2020
lark = "^1.1.9"
2121

2222
[tool.poetry.group.dev.dependencies]

tests/test_abstract_model.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import io
2+
3+
import ase
4+
import pytest
5+
from pytest import approx
6+
7+
from io import StringIO
8+
from ase.io import read, write
9+
import numpy as np
10+
11+
from abcd.model import AbstractModel
12+
from ase.calculators.lj import LennardJones
13+
14+
15+
@pytest.fixture
16+
def extxyz_file():
17+
return StringIO(
18+
"""2
19+
Properties=species:S:1:pos:R:3:forces:R:3 energy=-1 pbc="F T F" info="test"
20+
Si 0.0 0.0 0.0 0.4 0.6 -0.4
21+
Si 0.0 0.0 0.0 -0.1 -0.5 -0.6
22+
"""
23+
)
24+
25+
26+
def test_from_atoms(extxyz_file):
27+
"""Test extracting data from ASE Atoms object."""
28+
expected_forces = np.array([[0.4, 0.6, -0.4], [-0.1, -0.5, -0.6]])
29+
expected_stress = np.array([-1.0, -1.0, -1.0, -2.1, 2.0, 1.8])
30+
31+
atoms = read(extxyz_file, format="extxyz")
32+
atoms.calc.results["stress"] = expected_stress
33+
data = AbstractModel.from_atoms(atoms)
34+
35+
# Test info
36+
info_keys = {
37+
"pbc",
38+
"n_atoms",
39+
"cell",
40+
"formula",
41+
"calculator_name",
42+
"calculator_parameters",
43+
"info",
44+
}
45+
assert info_keys == set(data.info_keys)
46+
assert data["pbc"] == [False, True, False]
47+
assert data["n_atoms"] == 2
48+
assert len(data["cell"]) == 3
49+
assert all(arr == [0.0, 0.0, 0.0] for arr in data["cell"])
50+
assert data["formula"] == "Si2"
51+
assert data["info"] == "test"
52+
53+
# Test arrays
54+
assert {"numbers", "positions"} == set(data.arrays_keys)
55+
56+
# Test results
57+
assert {"energy", "stress", "forces"} == set(data.results_keys)
58+
assert data["energy"] == -1
59+
assert data["forces"] == pytest.approx(expected_forces)
60+
assert data["stress"] == pytest.approx(expected_stress)
61+
62+
# Test derived
63+
derived_keys = {
64+
"elements",
65+
"username",
66+
"uploaded",
67+
"modified",
68+
"volume",
69+
"hash",
70+
"hash_structure",
71+
}
72+
assert derived_keys == set(data.derived_keys)
73+
74+
75+
def test_from_atoms_no_calc(extxyz_file):
76+
"""Test extracting data from ASE Atoms object without results."""
77+
expected_stress = np.array([-1.0, -1.0, -1.0, -2.1, 2.0, 1.8])
78+
79+
atoms = read(extxyz_file, format="extxyz")
80+
atoms.calc.results["stress"] = expected_stress
81+
data = AbstractModel.from_atoms(atoms, store_calc=False)
82+
83+
# Test info
84+
assert {"pbc", "n_atoms", "cell", "formula", "info"} == set(data.info_keys)
85+
assert data["pbc"] == [False, True, False]
86+
assert data["n_atoms"] == 2
87+
assert len(data["cell"]) == 3
88+
assert all(arr == [0.0, 0.0, 0.0] for arr in data["cell"])
89+
assert data["formula"] == "Si2"
90+
assert data["info"] == "test"
91+
92+
# Test arrays
93+
assert {"numbers", "positions"} == set(data.arrays_keys)
94+
95+
# Test results
96+
results_keys = {
97+
"energy",
98+
"forces",
99+
"stress",
100+
"calculator_name",
101+
"calculator_parameters",
102+
}
103+
assert all(key not in data for key in results_keys)
104+
105+
# Test derived
106+
derived_keys = {
107+
"elements",
108+
"username",
109+
"uploaded",
110+
"modified",
111+
"volume",
112+
"hash",
113+
"hash_structure",
114+
}
115+
assert derived_keys == set(data.derived_keys)
116+
117+
118+
def test_to_ase(extxyz_file):
119+
"""Test returning data to ASE Atoms object with results."""
120+
atoms = read(extxyz_file, format="extxyz")
121+
data = AbstractModel.from_atoms(atoms, store_calc=True)
122+
123+
new_atoms = data.to_ase()
124+
125+
# Test info set
126+
assert new_atoms.cell == pytest.approx(atoms.cell)
127+
assert new_atoms.pbc == pytest.approx(atoms.pbc)
128+
assert new_atoms.positions == pytest.approx(atoms.positions)
129+
assert new_atoms.numbers == pytest.approx(atoms.numbers)
130+
131+
assert new_atoms.info["n_atoms"] == len(atoms)
132+
assert new_atoms.info["formula"] == atoms.get_chemical_formula()
133+
134+
assert new_atoms.calc.results["energy"] == pytest.approx(
135+
atoms.calc.results["energy"]
136+
)
137+
assert new_atoms.calc.results["forces"] == pytest.approx(
138+
atoms.calc.results["forces"]
139+
)
140+
141+
142+
def test_to_ase_no_results(extxyz_file):
143+
"""Test returning data to ASE Atoms object without results."""
144+
atoms = read(extxyz_file, format="extxyz")
145+
data = AbstractModel.from_atoms(atoms, store_calc=False)
146+
147+
new_atoms = data.to_ase()
148+
149+
# Test info set
150+
assert new_atoms.cell == pytest.approx(atoms.cell)
151+
assert new_atoms.pbc == pytest.approx(atoms.pbc)
152+
assert new_atoms.positions == pytest.approx(atoms.positions)
153+
assert new_atoms.numbers == pytest.approx(atoms.numbers)
154+
155+
assert new_atoms.info["n_atoms"] == len(atoms)
156+
assert new_atoms.info["formula"] == atoms.get_chemical_formula()
157+
158+
assert new_atoms.calc is None
159+
160+
161+
def test_from_atoms_len_atoms_3():
162+
atoms = ase.Atoms(
163+
"H3",
164+
positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
165+
pbc=True,
166+
cell=[2, 2, 2],
167+
)
168+
atoms.calc = LennardJones()
169+
atoms.calc.calculate(atoms)
170+
171+
# convert
172+
abcd_data = AbstractModel.from_atoms(atoms, store_calc=True)
173+
174+
assert set(abcd_data.info_keys) == {
175+
"pbc",
176+
"n_atoms",
177+
"cell",
178+
"formula",
179+
"calculator_name",
180+
"calculator_parameters",
181+
}
182+
assert set(abcd_data.arrays_keys) == {"numbers", "positions"}
183+
assert set(abcd_data.results_keys) == {
184+
"stress",
185+
"energy",
186+
"forces",
187+
"energies",
188+
"stresses",
189+
"free_energy",
190+
}
191+
192+
# check some values as well
193+
assert abcd_data["energy"] == atoms.get_potential_energy()
194+
assert abcd_data["forces"] == approx(atoms.get_forces())
195+
196+
197+
@pytest.mark.parametrize("store_calc", [True, False])
198+
def test_write_and_read(store_calc):
199+
# create atoms & add a calculator
200+
atoms = ase.Atoms(
201+
"H3",
202+
positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
203+
pbc=True,
204+
cell=[2, 2, 2],
205+
)
206+
atoms.calc = LennardJones()
207+
atoms.calc.calculate(atoms)
208+
209+
# dump to XYZ
210+
buffer = io.StringIO()
211+
write(buffer, atoms, format="extxyz")
212+
213+
# read back
214+
buffer.seek(0)
215+
atoms_read = read(buffer, format="extxyz")
216+
217+
# read in both of them
218+
abcd_data = AbstractModel.from_atoms(atoms, store_calc=store_calc)
219+
abcd_data_after_read = AbstractModel.from_atoms(atoms_read, store_calc=store_calc)
220+
221+
# check that all results are the same
222+
for key in ["info_keys", "arrays_keys", "derived_keys", "results_keys"]:
223+
assert set(getattr(abcd_data, key)) == set(
224+
getattr(abcd_data_after_read, key)
225+
), f"{key} mismatched"
226+
227+
# info & arrays same, except calc recognised as LJ when not from XYZ
228+
for key in set(abcd_data.info_keys + abcd_data.arrays_keys) - {
229+
"calculator_name",
230+
"calculator_parameters",
231+
}:
232+
assert (
233+
abcd_data[key] == abcd_data_after_read[key]
234+
), f"{key}'s value does not match"
235+
236+
# date & hashed will differ
237+
for key in set(abcd_data.derived_keys) - {
238+
"hash",
239+
"modified",
240+
"uploaded",
241+
"hash_structure", # see issue #118
242+
}:
243+
assert (
244+
abcd_data[key] == abcd_data_after_read[key]
245+
), f"{key}'s value does not match"
246+
247+
# expected differences - n.b. order of calls above
248+
assert abcd_data_after_read["modified"] > abcd_data["modified"]
249+
assert abcd_data_after_read["uploaded"] > abcd_data["uploaded"]
250+
assert abcd_data_after_read["hash"] != abcd_data["hash"]
251+
252+
# expect results to match within fp precision
253+
for key in set(abcd_data.results_keys):
254+
assert abcd_data[key] == approx(
255+
np.array(abcd_data_after_read[key])
256+
), f"{key}'s value does not match"

0 commit comments

Comments
 (0)