Skip to content

Commit c581cb8

Browse files
Clarify values
Use kwargs only in inits atoms -> atoms_slice replace atom_count method with value_count method Stress now can slice atoms
1 parent bad0bd5 commit c581cb8

File tree

3 files changed

+85
-53
lines changed

3 files changed

+85
-53
lines changed

janus_core/helpers/correlator.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -243,18 +243,18 @@ def __init__(
243243
self._get_b = b
244244
self._b_args, self._b_kwargs = (), {}
245245

246-
self._a_atoms = self._get_a.atom_count(n_atoms)
247-
self._b_atoms = self._get_b.atom_count(n_atoms)
246+
a_values = self._get_a.value_count(n_atoms)
247+
b_values = self._get_b.value_count(n_atoms)
248+
249+
if a_values != b_values:
250+
raise ValueError("Observables have inconsistent sizes")
251+
self._values = a_values
248252

249253
self._correlators = []
250-
for _ in zip(range(self._get_a.dimension), range(self._get_b.dimension)):
251-
for _ in zip(
252-
range(max(1, self._a_atoms)),
253-
range(max(1, self._b_atoms)),
254-
):
255-
self._correlators.append(
256-
Correlator(blocks=blocks, points=points, averaging=averaging)
257-
)
254+
for _ in range(self._values):
255+
self._correlators.append(
256+
Correlator(blocks=blocks, points=points, averaging=averaging)
257+
)
258258
self._update_frequency = update_frequency
259259

260260
@property
@@ -298,11 +298,9 @@ def get(self) -> tuple[Iterable[float], Iterable[float]]:
298298
The correlation lag times t'.
299299
"""
300300
if self._correlators:
301-
avg_value, lags = self._correlators[0].get()
302-
for cor in self._correlators[1:]:
303-
value, _ = cor.get()
304-
avg_value += value
305-
return avg_value / max(1, min(self._a_atoms, self._b_atoms)), lags
301+
_, lags = self._correlators[0].get()
302+
avg_value = sum([cor.get()[0] for cor in self._correlators]) / self._values
303+
return avg_value, lags
306304
return [], []
307305

308306
def __str__(self) -> str:

janus_core/helpers/observables.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def __init__(self, dimension: int = 1):
3232
The dimension of the observed data.
3333
"""
3434
self._dimension = dimension
35-
self.atoms = None
3635

3736
def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
3837
"""
@@ -53,37 +52,33 @@ def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
5352
The observed value, with dimensions atoms by self.dimension.
5453
"""
5554

56-
@property
57-
def dimension(self):
55+
def value_count(self, n_atoms: int | None = None) -> int:
5856
"""
59-
Dimension of the observable. Commensurate with self.__call__.
57+
Count of values returned by __call__.
58+
59+
Parameters
60+
----------
61+
n_atoms : int | None
62+
Atom count to expand atoms_slice.
6063
6164
Returns
6265
-------
6366
int
64-
Observables dimension.
67+
The number of values returned by __call__.
6568
"""
66-
return self._dimension
69+
return self.dimension
6770

68-
def atom_count(self, n_atoms: int):
71+
@property
72+
def dimension(self):
6973
"""
70-
Atom count to average over.
71-
72-
Parameters
73-
----------
74-
n_atoms : int
75-
Total possible atoms.
74+
Dimension of the observable. Commensurate with self.__call__.
7675
7776
Returns
7877
-------
7978
int
80-
Atom count averaged over.
79+
Observables dimension.
8180
"""
82-
if self.atoms:
83-
if isinstance(self.atoms, list):
84-
return len(self.atoms)
85-
return slicelike_len_for(self.n_atoms)
86-
return 0
81+
return self._dimension
8782

8883

8984
class ComponentMixin:
@@ -163,18 +158,28 @@ class Stress(Observable, ComponentMixin):
163158
----------
164159
components : list[str]
165160
Symbols for correlated tensor components, xx, yy, etc.
161+
atoms_slice : list[int] | SliceLike | None = None
162+
List or slice of atoms to observe velocities from.
166163
include_ideal_gas : bool
167164
Calculate with the ideal gas contribution.
168165
"""
169166

170-
def __init__(self, components: list[str], *, include_ideal_gas: bool = True):
167+
def __init__(
168+
self,
169+
*,
170+
components: list[str],
171+
atoms_slice: list[int] | SliceLike | None = None,
172+
include_ideal_gas: bool = True,
173+
):
171174
"""
172175
Initialise the observable from a symbolic str component.
173176
174177
Parameters
175178
----------
176179
components : list[str]
177180
Symbols for tensor components, xx, yy, etc.
181+
atoms_slice : list[int] | SliceLike | None = None
182+
List or slice of atoms to observe velocities from.
178183
include_ideal_gas : bool
179184
Calculate with the ideal gas contribution.
180185
"""
@@ -194,6 +199,11 @@ def __init__(self, components: list[str], *, include_ideal_gas: bool = True):
194199
)
195200
self._set_components(components)
196201

202+
if atoms_slice:
203+
self.atoms_slice = atoms_slice
204+
else:
205+
self.atoms_slice = slice(0, None, 1)
206+
197207
Observable.__init__(self, len(components))
198208
self.include_ideal_gas = include_ideal_gas
199209

@@ -215,14 +225,18 @@ def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
215225
list[float]
216226
The stress components in GPa units.
217227
"""
228+
sliced_atoms = atoms[self.atoms_slice]
229+
sliced_atoms.calc = atoms.calc
218230
return (
219-
atoms.get_stress(include_ideal_gas=self.include_ideal_gas, voigt=True)
231+
sliced_atoms.get_stress(
232+
include_ideal_gas=self.include_ideal_gas, voigt=True
233+
)
220234
/ units.GPa
221235
)[self._indices]
222236

223237

224-
StressDiagonal = Stress(["xx", "yy", "zz"])
225-
ShearStress = Stress(["xy", "yz", "zx"])
238+
StressDiagonal = Stress(components=["xx", "yy", "zz"])
239+
ShearStress = Stress(components=["xy", "yz", "zx"])
226240

227241

228242
# pylint: disable=too-few-public-methods
@@ -234,14 +248,15 @@ class Velocity(Observable, ComponentMixin):
234248
----------
235249
components : list[str]
236250
Symbols for velocity components, x, y, z.
237-
atoms : Optional[Union[list[int], SliceLike]]
251+
atoms_slice : list[int] | SliceLike | None = None
238252
List or slice of atoms to observe velocities from.
239253
"""
240254

241255
def __init__(
242256
self,
257+
*,
243258
components: list[str],
244-
atoms: list[int] | SliceLike | None = None,
259+
atoms_slice: list[int] | SliceLike | None = None,
245260
):
246261
"""
247262
Initialise the observable from a symbolic str component and atom index.
@@ -250,17 +265,18 @@ def __init__(
250265
----------
251266
components : list[str]
252267
Symbols for tensor components, x, y, and z.
253-
atoms : Union[list[int], SliceLike]
268+
atoms_slice : Union[list[int], SliceLike]
254269
List or slice of atoms to observe velocities from.
255270
"""
256271
ComponentMixin.__init__(self, components={"x": 0, "y": 1, "z": 2})
257272
self._set_components(components)
258273

259274
Observable.__init__(self, len(components))
260-
if atoms:
261-
self.atoms = atoms
275+
276+
if atoms_slice:
277+
self.atoms_slice = atoms_slice
262278
else:
263-
atoms = slice(None, None, None)
279+
self.atoms_slice = slice(0, None, 1)
264280

265281
def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
266282
"""
@@ -280,4 +296,22 @@ def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
280296
list[float]
281297
The velocity values.
282298
"""
283-
return atoms.get_velocities()[self.atoms, :][:, self._indices].flatten()
299+
return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten()
300+
301+
def value_count(self, n_atoms: int | None = None) -> int:
302+
"""
303+
Count of values returned by __call__.
304+
305+
Parameters
306+
----------
307+
n_atoms : int | None
308+
Atom count to expand atoms_slice.
309+
310+
Returns
311+
-------
312+
int
313+
The number of values returned by __call__.
314+
"""
315+
if isinstance(self.atoms_slice, list):
316+
return len(self.atoms_slice) * self.dimension
317+
return slicelike_len_for(self.atoms_slice, self.n_atoms) * self.dimension

tests/test_correlator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,17 @@ def test_vaf(tmp_path):
9494
file_prefix=file_prefix,
9595
correlation_kwargs=[
9696
{
97-
"a": Velocity(["x", "y", "z"], na),
98-
"b": Velocity(["x", "y", "z"], na),
97+
"a": Velocity(components=["x", "y", "z"], atoms_slice=na),
98+
"b": Velocity(components=["x", "y", "z"], atoms_slice=na),
9999
"name": "vaf_Na",
100100
"blocks": 1,
101101
"points": 11,
102102
"averaging": 1,
103103
"update_frequency": 1,
104104
},
105105
{
106-
"a": Velocity(["x", "y", "z"], cl),
107-
"b": Velocity(["x", "y", "z"], cl),
106+
"a": Velocity(components=["x", "y", "z"], atoms_slice=cl),
107+
"b": Velocity(components=["x", "y", "z"], atoms_slice=cl),
108108
"name": "vaf_Cl",
109109
"blocks": 1,
110110
"points": 11,
@@ -128,8 +128,8 @@ def test_vaf(tmp_path):
128128
vaf = safe_load(cor)
129129
vaf_na = np.array(vaf["vaf_Na"]["value"])
130130
vaf_cl = np.array(vaf["vaf_Cl"]["value"])
131-
assert vaf_na == approx(vaf_post[0], rel=1e-5)
132-
assert vaf_cl == approx(vaf_post[1], rel=1e-5)
131+
assert vaf_na * 3 == approx(vaf_post[0], rel=1e-5)
132+
assert vaf_cl * 3 == approx(vaf_post[1], rel=1e-5)
133133

134134

135135
def test_md_correlations(tmp_path):
@@ -154,8 +154,8 @@ def test_md_correlations(tmp_path):
154154
file_prefix=file_prefix,
155155
correlation_kwargs=[
156156
{
157-
"a": Stress([("xy")]),
158-
"b": Stress([("xy")]),
157+
"a": Stress(components=[("xy")]),
158+
"b": Stress(components=[("xy")]),
159159
"name": "stress_xy_auto_cor",
160160
"blocks": 1,
161161
"points": 11,

0 commit comments

Comments
 (0)