Skip to content

Commit debca0a

Browse files
authored
Refactoring model system (#170)
* work with basesections.v2 * remove the redundant child model_system subsection * move positions Quantity to ModelSystem * AtomsState(Element), remove helper functions, modify normalize() to check consistency between chemical_symbol and atomic_number * remove positions and velocities from Cell * move ASE Atoms object creation to ModelSystem * add sub_systems to ModelSystem * add total_charge and total_spin to ModelSystem * add ModelSystem.sub_systems instead of ModelSystem.model_system * move comparison functions from AtomicCell to ModelSystem * add ParticleState and CoarseGrainedState * Refactor symmetry resolution: use dict mapping for cell types * remove atom_types and related tests * remove too many warnings in get_chemical_symbols * ModelSystem.particle_states instead of atom_states * remove redundant positions length checks * n_particles and particle_indices (instead of n_atoms and atom_indices) * add a note in GeometricSpace.get_geometric_space_for_atomic_cell * modify to_ase_atoms, test_numerical_settings.py * modify conftest so that model_system.positions are populated * import the v2 Entity in physical_property.py * Fix DOSProfile normalization factor resolution to use model_system.particle_states * Remove unused variable from SlaterKosterBond normalization test * fix APW-refs fixture in conftest.py * fix spectral profile tests * Remove redundant try/except around get_geometric_space_for_atomic_cell * add a test for composition formula
1 parent 7f85cfc commit debca0a

File tree

15 files changed

+935
-1424
lines changed

15 files changed

+935
-1424
lines changed

src/nomad_simulations/schema_packages/atoms_state.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pint
66
from nomad.datamodel.data import ArchiveSection
7-
from nomad.datamodel.metainfo.basesections import Entity
7+
from nomad.datamodel.metainfo.basesections.v2 import Entity
88
from nomad.metainfo import MEnum, Quantity, SubSection
99
from nomad.units import ureg
1010

@@ -545,12 +545,20 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
545545
)
546546

547547

548-
class AtomsState(Entity):
548+
class ParticleState(Entity):
549+
"""
550+
Generic base section representing the state of a particle in a simulation.
551+
This can be extended to include any common quantities in the future.
552+
"""
553+
554+
pass
555+
556+
557+
class AtomsState(ParticleState):
549558
"""
550559
A base section to define each atom state information.
551560
"""
552561

553-
# TODO check what happens with ghost atoms that can have `chemical_symbol='X'`
554562
chemical_symbol = Quantity(
555563
type=MEnum(ase.data.chemical_symbols[1:]),
556564
description="""
@@ -565,8 +573,6 @@ class AtomsState(Entity):
565573
""",
566574
)
567575

568-
orbitals_state = SubSection(sub_section=OrbitalsState.m_def, repeats=True)
569-
570576
charge = Quantity(
571577
type=np.int32,
572578
default=0,
@@ -581,6 +587,16 @@ class AtomsState(Entity):
581587
""",
582588
)
583589

590+
spin = Quantity(
591+
type=np.int32,
592+
default=0,
593+
description="""
594+
Total spin quantum number, S.
595+
""",
596+
)
597+
598+
orbitals_state = SubSection(sub_section=OrbitalsState.m_def, repeats=True)
599+
584600
core_hole = SubSection(sub_section=CoreHole.m_def, repeats=False)
585601

586602
hubbard_interactions = SubSection(
@@ -633,3 +649,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
633649
self.chemical_symbol = self.resolve_chemical_symbol(logger=logger)
634650
if self.atomic_number is None:
635651
self.atomic_number = self.resolve_atomic_number(logger=logger)
652+
653+
654+
class CoarseGrainedState(ParticleState):
655+
pass

src/nomad_simulations/schema_packages/general.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from nomad.datamodel.metainfo.basesections import Activity, Entity
1313
from nomad.metainfo import Datetime, Quantity, SchemaPackage, Section, SubSection
1414

15+
from nomad_simulations.schema_packages.atoms_state import AtomsState
1516
from nomad_simulations.schema_packages.model_method import ModelMethod
1617
from nomad_simulations.schema_packages.model_system import ModelSystem
1718
from nomad_simulations.schema_packages.outputs import Outputs
@@ -195,7 +196,7 @@ class Simulation(BaseSimulation, Schema):
195196
def _set_system_branch_depth(
196197
self, system_parent: ModelSystem, branch_depth: int = 0
197198
):
198-
for system_child in system_parent.model_system:
199+
for system_child in system_parent.sub_systems:
199200
system_child.branch_depth = branch_depth + 1
200201
self._set_system_branch_depth(
201202
system_parent=system_child, branch_depth=branch_depth + 1
@@ -222,10 +223,10 @@ def set_composition_formula(
222223
to the atom indices stored in system.
223224
"""
224225
if not subsystems:
225-
if system.atom_indices is not None and atom_labels:
226-
subsystem_labels = [atom_labels[i] for i in system.atom_indices]
227-
elif system.atom_indices is not None:
228-
subsystem_labels = ['Unknown'] * len(system.atom_indices)
226+
if system.particle_indices is not None and atom_labels:
227+
subsystem_labels = [atom_labels[i] for i in system.particle_indices]
228+
elif system.particle_indices is not None:
229+
subsystem_labels = ['Unknown'] * len(system.particle_indices)
229230
else:
230231
subsystem_labels = []
231232
else:
@@ -249,22 +250,25 @@ def get_composition_recurs(system: ModelSystem, atom_labels: list[str]) -> None:
249250
atom_labels (list[str]): The global list of atom labels corresponding
250251
to the atom indices stored in system.
251252
"""
252-
subsystems = system.model_system
253+
subsystems = system.sub_systems
253254
set_composition_formula(
254255
system=system, subsystems=subsystems, atom_labels=atom_labels
255256
)
256257
if subsystems:
257258
for subsystem in subsystems:
258259
get_composition_recurs(system=subsystem, atom_labels=atom_labels)
259260

260-
atoms_state = (
261-
system_parent.cell[0].atoms_state if system_parent.cell is not None else []
262-
)
261+
# Pull chemical symbols straight from AtomsState.chemical_symbol
263262
atom_labels = (
264-
[atom.chemical_symbol for atom in atoms_state]
265-
if atoms_state is not None
263+
[
264+
atom.chemical_symbol
265+
for atom in system_parent.particle_states
266+
if isinstance(atom, AtomsState) and atom.chemical_symbol is not None
267+
]
268+
if system_parent.particle_states
266269
else []
267270
)
271+
268272
get_composition_recurs(system=system_parent, atom_labels=atom_labels)
269273

270274
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
@@ -284,7 +288,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
284288
# Setting up the `branch_depth` in the parent-child tree
285289
for system_parent in self.model_system:
286290
system_parent.branch_depth = 0
287-
if len(system_parent.model_system) == 0:
291+
if len(system_parent.sub_systems) == 0:
288292
continue
289293
self._set_system_branch_depth(system_parent=system_parent)
290294

src/nomad_simulations/schema_packages/model_method.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -483,59 +483,74 @@ def resolve_orbital_references(
483483
model_index: int = -1,
484484
) -> Optional[list[OrbitalsState]]:
485485
"""
486-
Resolves the references to the `OrbitalsState` sections from the child `ModelSystem` section.
486+
Resolves references to the `OrbitalsState` sections from the top-level `ModelSystem`
487+
that has child system(s) typed 'active_atom'. This uses the new design:
488+
489+
- The parent ModelSystem stores per-atom data in `particle_states`.
490+
- The child system(s) typed 'active_atom' list indices in `particle_indices`.
491+
- We gather OrbitalsState from each relevant particle_states entry.
487492
488493
Args:
489494
model_systems (list[ModelSystem]): The list of `ModelSystem` sections.
490495
logger (BoundLogger): The logger to log messages.
491-
model_index (int, optional): The `ModelSystem` section index from which resolve the references. Defaults to -1.
496+
model_index (int, optional): The ModelSystem index to use. Defaults to -1 (the last).
492497
493498
Returns:
494-
Optional[list[OrbitalsState]]: The resolved references to the `OrbitalsState` sections.
499+
Optional[list[OrbitalsState]]: The resolved references to the OrbitalsState sections.
495500
"""
501+
# Check that the requested ModelSystem exists
496502
try:
497503
model_system = model_systems[model_index]
498504
except IndexError:
499-
logger.warning(
500-
f'The `ModelSystem` section with index {model_index} was not found.'
501-
)
505+
logger.warning(f'No ModelSystem at index {model_index}.')
502506
return None
503507

504-
# If `ModelSystem` is not representative, the normalization will not run
508+
# If the system is not representative, bail out of normalization
505509
if is_not_representative(model_system=model_system, logger=logger):
506510
return None
507511

508-
# If `AtomicCell` is not found, the normalization will not run
509-
if not model_system.cell:
510-
logger.warning('`AtomicCell` section was not found.')
512+
# If no child ModelSystem sections exist, bail out of normalization
513+
if not model_system.sub_systems:
514+
logger.warning(
515+
'No child ModelSystem found; cannot find active_atom references.'
516+
)
511517
return None
512-
atomic_cell = model_system.cell[0]
513518

514-
# If there is no child `ModelSystem`, the normalization will not run
515-
atoms_state = atomic_cell.atoms_state
516-
model_system_child = model_system.model_system
517-
if not atoms_state or not model_system_child:
518-
logger.warning('No `AtomsState` or child `ModelSystem` section were found.')
519+
# If no particle_states are present at the top level, we have no orbitals
520+
if not model_system.particle_states:
521+
logger.warning('No particle_states in the parent ModelSystem.')
519522
return None
520523

521-
# We flatten the `OrbitalsState` sections from the `ModelSystem` section
522-
orbitals_ref = []
523-
for active_atom in model_system_child:
524-
# If the child is not an "active_atom", the normalization will not run
525-
if active_atom.type != 'active_atom':
524+
orbitals_ref: list[OrbitalsState] = []
525+
526+
# For each child in sub_systems, if type='active_atom', gather orbitals
527+
for child_sys in model_system.sub_systems:
528+
if child_sys.type != 'active_atom':
529+
continue
530+
# if no particle_indices => skip
531+
if not child_sys.particle_indices:
532+
logger.warning('Child system is active_atom but no particle_indices.')
526533
continue
527-
indices = active_atom.atom_indices
528-
for index in indices:
529-
try:
530-
active_atoms_state = atoms_state[index]
531-
except IndexError:
534+
535+
# For each index in child_sys.particle_indices => fetch from parent’s particle_states
536+
for idx in child_sys.particle_indices:
537+
if idx < 0 or idx >= len(model_system.particle_states):
532538
logger.warning(
533-
f'The `AtomsState` section with index {index} was not found.'
539+
f'Particle index {idx} out of range for particle_states.'
534540
)
535541
continue
536-
orbitals_state = active_atoms_state.orbitals_state
537-
for orbital in orbitals_state:
538-
orbitals_ref.append(orbital)
542+
active_atom_state = model_system.particle_states[idx]
543+
544+
# If no orbitals_state => skip
545+
if not active_atom_state.orbitals_state:
546+
logger.warning(
547+
f'No orbitals_state found in particle_states[{idx}].'
548+
)
549+
continue
550+
551+
orbitals_ref.extend(active_atom_state.orbitals_state)
552+
553+
# Return the collected orbitals
539554
return orbitals_ref
540555

541556
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:

0 commit comments

Comments
 (0)