Skip to content

Commit 96f4a91

Browse files
authored
Added check_simulation_cell in utils (#117)
* Added equal cell tolerance * Defining logic in __eq__ methods of Cell and AtomicCell * Moved testing to test_model_system Added __ne__ method * Add todo * Fix == testing failing * Added __ne__ testing
1 parent 8cf4685 commit 96f4a91

File tree

5 files changed

+238
-7
lines changed

5 files changed

+238
-7
lines changed

src/nomad_simulations/schema_packages/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323
class NOMADSimulationsEntryPoint(SchemaPackageEntryPoint):
2424
dos_energy_tolerance: float = Field(
2525
8.01088e-21,
26-
description='Tolerance of the DOS energies in Joules to match the reference of energies in the DOS normalize function.',
26+
description='Tolerance (in joules) of the DOS energies to match the reference of energies in the DOS normalize function.',
2727
)
2828
dos_intensities_threshold: float = Field(
2929
1e-8,
30-
description='Threshold value at which the DOS intensities are considered non-zero.',
30+
description='Threshold value (in joules^-1) at which the DOS intensities are considered non-zero.',
3131
)
3232
occupation_tolerance: float = Field(
3333
1e-3,
3434
description='Tolerance for the occupation of a eigenstate to be non-occupied.',
3535
)
3636
fermi_surface_tolerance: float = Field(
3737
1e-8,
38-
description='Tolerance for energies to be close to the Fermi level and hence define the Fermi surface of a material.',
38+
description='Tolerance (in joules) for energies to be close to the Fermi level and hence define the Fermi surface of a material.',
3939
)
4040
symmetry_tolerance: float = Field(
4141
0.1, description='Tolerance for the symmetry analyzer used from MatID.'
@@ -48,6 +48,10 @@ class NOMADSimulationsEntryPoint(SchemaPackageEntryPoint):
4848
64,
4949
description='Limite of the number of atoms in the unit cell to be treated for the system type classification from MatID to work. This is done to avoid overhead of the package.',
5050
)
51+
equal_cell_positions_tolerance: float = Field(
52+
1e-12,
53+
description='Tolerance (in meters) for the cell positions to be considered equal.',
54+
)
5155

5256
def load(self):
5357
from nomad_simulations.schema_packages.general import m_package

src/nomad_simulations/schema_packages/model_system.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,40 @@ class Cell(GeometricSpace):
296296
""",
297297
)
298298

299+
def _check_positions(self, positions_1, positions_2) -> list:
300+
# Check that all the `positions`` of `cell_1` match with the ones in `cell_2`
301+
check_positions = []
302+
for i1, pos1 in enumerate(positions_1):
303+
for i2, pos2 in enumerate(positions_2):
304+
if np.allclose(
305+
pos1, pos2, atol=configuration.equal_cell_positions_tolerance
306+
):
307+
check_positions.append([i1, i2])
308+
break
309+
return check_positions
310+
311+
def __eq__(self, other) -> bool:
312+
# TODO implement checks on `lattice_vectors` and other quantities to ensure the equality of primitive cells
313+
if not isinstance(other, Cell):
314+
return False
315+
316+
# If the `positions` are empty, return False
317+
if self.positions is None or other.positions is None:
318+
return False
319+
320+
# The `positions` should have the same length (same number of positions)
321+
if len(self.positions) != len(other.positions):
322+
return False
323+
n_positions = len(self.positions)
324+
325+
check_positions = self._check_positions(self.positions, other.positions)
326+
if len(check_positions) != n_positions:
327+
return False
328+
return True
329+
330+
def __ne__(self, other) -> bool:
331+
return not self.__eq__(other)
332+
299333
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
300334
super().normalize(archive, logger)
301335

@@ -339,6 +373,29 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg
339373
# Set the name of the section
340374
self.name = self.m_def.name
341375

376+
def __eq__(self, other) -> bool:
377+
if not isinstance(other, AtomicCell):
378+
return False
379+
380+
# Compare positions using the parent sections's `__eq__` method
381+
if not super().__eq__(other):
382+
return False
383+
384+
# Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2`
385+
check_positions = self._check_positions(self.positions, other.positions)
386+
try:
387+
for atom in check_positions:
388+
element_1 = self.atoms_state[atom[0]].chemical_symbol
389+
element_2 = other.atoms_state[atom[1]].chemical_symbol
390+
if element_1 != element_2:
391+
return False
392+
except Exception:
393+
return False
394+
return True
395+
396+
def __ne__(self, other) -> bool:
397+
return not self.__eq__(other)
398+
342399
def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
343400
"""
344401
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`

src/nomad_simulations/schema_packages/utils/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,20 @@
2121
from typing import TYPE_CHECKING
2222

2323
import numpy as np
24+
from nomad.config import config
2425

2526
if TYPE_CHECKING:
2627
from typing import Optional
2728

2829
from nomad.datamodel.data import ArchiveSection
2930
from structlog.stdlib import BoundLogger
3031

32+
from nomad_simulations.schema_packages.model_system import Cell
33+
34+
configuration = config.get_plugin_entry_point(
35+
'nomad_simulations.schema_packages:nomad_simulations_plugin'
36+
)
37+
3138

3239
def get_sibling_section(
3340
section: 'ArchiveSection',

tests/test_model_system.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import pytest
2323
from nomad.datamodel import EntryArchive
2424

25+
from nomad_simulations.schema_packages.atoms_state import AtomsState
2526
from nomad_simulations.schema_packages.model_system import (
27+
AtomicCell,
28+
Cell,
2629
ChemicalFormula,
2730
ModelSystem,
2831
Symmetry,
@@ -32,11 +35,162 @@
3235
from .conftest import generate_atomic_cell
3336

3437

38+
class TestCell:
39+
"""
40+
Test the `Cell` section defined in model_system.py
41+
"""
42+
43+
@pytest.mark.parametrize(
44+
'cell_1, cell_2, result',
45+
[
46+
(Cell(), None, False), # one cell is None
47+
(Cell(), Cell(), False), # both cells are empty
48+
(
49+
Cell(positions=[[1, 0, 0]]),
50+
Cell(),
51+
False,
52+
), # one cell has positions, the other is empty
53+
(
54+
Cell(positions=[[1, 0, 0], [0, 1, 0]]),
55+
Cell(positions=[[1, 0, 0]]),
56+
False,
57+
), # length mismatch
58+
(
59+
Cell(positions=[[1, 0, 0], [0, 1, 0]]),
60+
Cell(positions=[[1, 0, 0], [0, -1, 0]]),
61+
False,
62+
), # different positions
63+
(
64+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
65+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
66+
True,
67+
), # same ordered positions
68+
(
69+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
70+
Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]),
71+
True,
72+
), # different ordered positions but same cell
73+
],
74+
)
75+
def test_eq_ne(self, cell_1: Cell, cell_2: Cell, result: bool):
76+
"""
77+
Test the `__eq__` and `__ne__` operator functions of `Cell`.
78+
"""
79+
assert (cell_1 == cell_2) == result
80+
assert (cell_1 != cell_2) != result
81+
82+
3583
class TestAtomicCell:
3684
"""
3785
Test the `AtomicCell`, `Cell` and `GeometricSpace` classes defined in model_system.py
3886
"""
3987

88+
@pytest.mark.parametrize(
89+
'cell_1, cell_2, result',
90+
[
91+
(Cell(), None, False), # one cell is None
92+
(Cell(), Cell(), False), # both cells are empty
93+
(
94+
Cell(positions=[[1, 0, 0]]),
95+
Cell(),
96+
False,
97+
), # one cell has positions, the other is empty
98+
(
99+
Cell(positions=[[1, 0, 0], [0, 1, 0]]),
100+
Cell(positions=[[1, 0, 0]]),
101+
False,
102+
), # length mismatch
103+
(
104+
Cell(positions=[[1, 0, 0], [0, 1, 0]]),
105+
Cell(positions=[[1, 0, 0], [0, -1, 0]]),
106+
False,
107+
), # different positions
108+
(
109+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
110+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
111+
True,
112+
), # same ordered positions
113+
(
114+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
115+
Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]),
116+
True,
117+
), # different ordered positions but same cell
118+
(
119+
AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
120+
Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
121+
False,
122+
), # one atomic cell and another cell (missing chemical symbols)
123+
(
124+
AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
125+
AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
126+
False,
127+
), # missing chemical symbols
128+
(
129+
AtomicCell(
130+
positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
131+
atoms_state=[
132+
AtomsState(chemical_symbol='H'),
133+
AtomsState(chemical_symbol='H'),
134+
AtomsState(chemical_symbol='O'),
135+
],
136+
),
137+
AtomicCell(
138+
positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
139+
atoms_state=[
140+
AtomsState(chemical_symbol='H'),
141+
AtomsState(chemical_symbol='H'),
142+
AtomsState(chemical_symbol='O'),
143+
],
144+
),
145+
True,
146+
), # same ordered positions and chemical symbols
147+
(
148+
AtomicCell(
149+
positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
150+
atoms_state=[
151+
AtomsState(chemical_symbol='H'),
152+
AtomsState(chemical_symbol='H'),
153+
AtomsState(chemical_symbol='O'),
154+
],
155+
),
156+
AtomicCell(
157+
positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
158+
atoms_state=[
159+
AtomsState(chemical_symbol='H'),
160+
AtomsState(chemical_symbol='Cu'),
161+
AtomsState(chemical_symbol='O'),
162+
],
163+
),
164+
False,
165+
), # same ordered positions but different chemical symbols
166+
(
167+
AtomicCell(
168+
positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
169+
atoms_state=[
170+
AtomsState(chemical_symbol='H'),
171+
AtomsState(chemical_symbol='H'),
172+
AtomsState(chemical_symbol='O'),
173+
],
174+
),
175+
AtomicCell(
176+
positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
177+
atoms_state=[
178+
AtomsState(chemical_symbol='H'),
179+
AtomsState(chemical_symbol='O'),
180+
AtomsState(chemical_symbol='H'),
181+
],
182+
),
183+
True,
184+
), # different ordered positions but same chemical symbols
185+
],
186+
)
187+
def test_eq_ne(self, cell_1: Cell, cell_2: Cell, result: bool):
188+
"""
189+
Test the `__eq__` and `__ne__` operator functions of `AtomicCell`.
190+
"""
191+
assert (cell_1 == cell_2) == result
192+
assert (cell_1 != cell_2) != result
193+
40194
@pytest.mark.parametrize(
41195
'chemical_symbols, atomic_numbers, formula, lattice_vectors, positions, periodic_boundary_conditions',
42196
[

tests/test_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,25 @@ def test_get_sibling_section():
4545
parent_section.symmetry.append(sibling_section)
4646
assert get_sibling_section(section, '', logger) is None
4747
assert get_sibling_section(section, 'symmetry', logger) == sibling_section
48-
assert get_sibling_section(sibling_section, 'cell', logger) == section
48+
assert get_sibling_section(sibling_section, 'cell', logger).type == section.type
4949
assert get_sibling_section(section, 'symmetry', logger, index_sibling=2) is None
5050
section2 = AtomicCell(type='primitive')
5151
parent_section.cell.append(section2)
5252
assert (
53-
get_sibling_section(sibling_section, 'cell', logger, index_sibling=0) == section
53+
get_sibling_section(sibling_section, 'cell', logger, index_sibling=0).type
54+
== 'original'
5455
)
5556
assert (
56-
get_sibling_section(sibling_section, 'cell', logger, index_sibling=1)
57-
== section2
57+
get_sibling_section(sibling_section, 'cell', logger, index_sibling=0).type
58+
== section.type
59+
)
60+
assert (
61+
get_sibling_section(sibling_section, 'cell', logger, index_sibling=1).type
62+
== section2.type
63+
)
64+
assert (
65+
get_sibling_section(sibling_section, 'cell', logger, index_sibling=1).type
66+
== 'primitive'
5867
)
5968

6069

0 commit comments

Comments
 (0)