Skip to content

Commit 423a23c

Browse files
Move Observable into observables.py
1 parent 3b4cb56 commit 423a23c

File tree

4 files changed

+113
-108
lines changed

4 files changed

+113
-108
lines changed

janus_core/calculations/md.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,8 @@ def _parse_correlations(self) -> None:
697697
"""Parse correlation kwargs into Correlations."""
698698
if self.correlation_kwargs:
699699
self._correlations = [
700-
Correlation(self.n_atoms, **cor) for cor in self.correlation_kwargs
700+
Correlation(n_atoms=self.n_atoms, **cor)
701+
for cor in self.correlation_kwargs
701702
]
702703
else:
703704
self._correlations = ()

janus_core/helpers/correlator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ class Correlation:
180180
----------
181181
n_atoms : int
182182
Number of possible atoms to track.
183-
a : tuple[Observable, dict]
183+
a : Union[Observable, tuple[Observable, tuple, dict]]
184184
Getter for a and kwargs.
185-
b : tuple[Observable, dict]
185+
b : Union[Observable, tuple[Observable, tuple, dict]]
186186
Getter for b and kwargs.
187187
name : str
188188
Name of correlation.
@@ -198,6 +198,7 @@ class Correlation:
198198

199199
def __init__(
200200
self,
201+
*,
201202
n_atoms: int,
202203
a: Union[Observable, tuple[Observable, tuple, dict]],
203204
b: Union[Observable, tuple[Observable, tuple, dict]],
@@ -214,9 +215,9 @@ def __init__(
214215
----------
215216
n_atoms : int
216217
Number of possible atoms to track.
217-
a : tuple[Observable, tuple, dict]
218+
a : Union[Observable, tuple[Observable, tuple, dict]]
218219
Getter for a and kwargs.
219-
b : tuple[Observable, tuple, dict]
220+
b : Union[Observable, tuple[Observable, tuple, dict]]
220221
Getter for b and kwargs.
221222
name : str
222223
Name of correlation.

janus_core/helpers/janus_types.py

Lines changed: 2 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
from numpy.typing import NDArray
1313

14+
from janus_core.helpers.observables import Observable
15+
1416
# General
1517

1618
T = TypeVar("T")
@@ -133,106 +135,6 @@ class EoSResults(TypedDict, total=False):
133135
e_0: float
134136

135137

136-
# pylint: disable=too-few-public-methods
137-
class Observable:
138-
"""
139-
Observable data that may be correlated.
140-
141-
Parameters
142-
----------
143-
dimension : int
144-
The dimension of the observed data.
145-
getter : Optional[callable]
146-
An optional callable to construct the Observable from.
147-
"""
148-
149-
def __init__(self, dimension: int = 1, *, getter: Optional[callable] = None):
150-
"""
151-
Initialise an observable with a given dimensionality.
152-
153-
Parameters
154-
----------
155-
dimension : int
156-
The dimension of the observed data.
157-
getter : Optional[callable]
158-
An optional callable to construct the Observable from.
159-
"""
160-
self._dimension = dimension
161-
self._getter = getter
162-
self.atoms = None
163-
164-
def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
165-
"""
166-
Call the user supplied getter if it exits.
167-
168-
Parameters
169-
----------
170-
atoms : Atoms
171-
Atoms object to extract values from.
172-
*args : tuple
173-
Additional positional arguments passed to getter.
174-
**kwargs : dict
175-
Additional kwargs passed getter.
176-
177-
Returns
178-
-------
179-
list[float]
180-
The observed value, with dimensions atoms by self.dimension.
181-
182-
Raises
183-
------
184-
ValueError
185-
If user supplied getter is None.
186-
"""
187-
if self._getter:
188-
value = self._getter(atoms, *args, **kwargs)
189-
if not isinstance(value, list):
190-
return [value]
191-
return value
192-
raise ValueError("No user getter supplied")
193-
194-
@property
195-
def dimension(self):
196-
"""
197-
Dimension of the observable. Commensurate with self.__call__.
198-
199-
Returns
200-
-------
201-
int
202-
Observables dimension.
203-
"""
204-
return self._dimension
205-
206-
def atom_count(self, n_atoms: int):
207-
"""
208-
Atom count to average over.
209-
210-
Parameters
211-
----------
212-
n_atoms : int
213-
Total possible atoms.
214-
215-
Returns
216-
-------
217-
int
218-
Atom count averaged over.
219-
"""
220-
if self.atoms:
221-
if isinstance(self.atoms, list):
222-
return len(self.atoms)
223-
if isinstance(self.atoms, int):
224-
return 1
225-
226-
start = self.atoms.start
227-
stop = self.atoms.stop
228-
step = self.atoms.step
229-
start = start if start is None else 0
230-
stop = stop if stop is None else n_atoms
231-
step = step if step is None else 1
232-
return len(range(start, stop, step))
233-
return 0
234-
235-
236138
class CorrelationKwargs(TypedDict, total=True):
237139
"""Arguments for on-the-fly correlations <ab>."""
238140

janus_core/helpers/observables.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,111 @@
11
"""Module for built-in correlation observables."""
22

3-
from typing import Optional, Union
3+
from typing import TYPE_CHECKING, Optional, Union
44

55
from ase import Atoms, units
66

7-
from janus_core.helpers.janus_types import Observable, SliceLike
7+
if TYPE_CHECKING:
8+
from janus_core.helpers.janus_types import SliceLike
9+
10+
11+
# pylint: disable=too-few-public-methods
12+
class Observable:
13+
"""
14+
Observable data that may be correlated.
15+
16+
Parameters
17+
----------
18+
dimension : int
19+
The dimension of the observed data.
20+
getter : Optional[callable]
21+
An optional callable to construct the Observable from.
22+
"""
23+
24+
def __init__(self, dimension: int = 1, *, getter: Optional[callable] = None):
25+
"""
26+
Initialise an observable with a given dimensionality.
27+
28+
Parameters
29+
----------
30+
dimension : int
31+
The dimension of the observed data.
32+
getter : Optional[callable]
33+
An optional callable to construct the Observable from.
34+
"""
35+
self._dimension = dimension
36+
self._getter = getter
37+
self.atoms = None
38+
39+
def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
40+
"""
41+
Call the user supplied getter if it exits.
42+
43+
Parameters
44+
----------
45+
atoms : Atoms
46+
Atoms object to extract values from.
47+
*args : tuple
48+
Additional positional arguments passed to getter.
49+
**kwargs : dict
50+
Additional kwargs passed getter.
51+
52+
Returns
53+
-------
54+
list[float]
55+
The observed value, with dimensions atoms by self.dimension.
56+
57+
Raises
58+
------
59+
ValueError
60+
If user supplied getter is None.
61+
"""
62+
if self._getter:
63+
value = self._getter(atoms, *args, **kwargs)
64+
if not isinstance(value, list):
65+
return [value]
66+
return value
67+
raise ValueError("No user getter supplied")
68+
69+
@property
70+
def dimension(self):
71+
"""
72+
Dimension of the observable. Commensurate with self.__call__.
73+
74+
Returns
75+
-------
76+
int
77+
Observables dimension.
78+
"""
79+
return self._dimension
80+
81+
def atom_count(self, n_atoms: int):
82+
"""
83+
Atom count to average over.
84+
85+
Parameters
86+
----------
87+
n_atoms : int
88+
Total possible atoms.
89+
90+
Returns
91+
-------
92+
int
93+
Atom count averaged over.
94+
"""
95+
if self.atoms:
96+
if isinstance(self.atoms, list):
97+
return len(self.atoms)
98+
if isinstance(self.atoms, int):
99+
return 1
100+
101+
start = self.atoms.start
102+
stop = self.atoms.stop
103+
step = self.atoms.step
104+
start = start if start is None else 0
105+
stop = stop if stop is None else n_atoms
106+
step = step if step is None else 1
107+
return len(range(start, stop, step))
108+
return 0
8109

9110

10111
class ComponentMixin:
@@ -162,7 +263,7 @@ class Velocity(Observable, ComponentMixin):
162263
def __init__(
163264
self,
164265
components: list[str],
165-
atoms: Optional[Union[list[int], SliceLike]] = None,
266+
atoms: Optional[Union[list[int], "SliceLike"]] = None,
166267
):
167268
"""
168269
Initialise the observable from a symbolic str component and atom index.

0 commit comments

Comments
 (0)