Skip to content

Commit dc88090

Browse files
Remove getter
1 parent 423a23c commit dc88090

File tree

2 files changed

+3
-48
lines changed

2 files changed

+3
-48
lines changed

janus_core/helpers/observables.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,23 @@ class Observable:
1717
----------
1818
dimension : int
1919
The dimension of the observed data.
20-
getter : Optional[callable]
21-
An optional callable to construct the Observable from.
2220
"""
2321

24-
def __init__(self, dimension: int = 1, *, getter: Optional[callable] = None):
22+
def __init__(self, dimension: int = 1):
2523
"""
2624
Initialise an observable with a given dimensionality.
2725
2826
Parameters
2927
----------
3028
dimension : int
3129
The dimension of the observed data.
32-
getter : Optional[callable]
33-
An optional callable to construct the Observable from.
3430
"""
3531
self._dimension = dimension
36-
self._getter = getter
3732
self.atoms = None
3833

3934
def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
4035
"""
41-
Call the user supplied getter if it exits.
36+
Signature for returning observed value from atoms.
4237
4338
Parameters
4439
----------
@@ -53,18 +48,7 @@ def __call__(self, atoms: Atoms, *args, **kwargs) -> list[float]:
5348
-------
5449
list[float]
5550
The observed value, with dimensions atoms by self.dimension.
56-
57-
Raises
58-
------
59-
ValueError
60-
If user supplied getter is None.
6151
"""
62-
if self._getter:
63-
value = self._getter(atoms, *args, **kwargs)
64-
if not isinstance(value, list):
65-
return [value]
66-
return value
67-
raise ValueError("No user getter supplied")
6852

6953
@property
7054
def dimension(self):

tests/test_correlator.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Iterable
44
from pathlib import Path
55

6-
from ase import Atoms
76
from ase.io import read
87
from ase.units import GPa
98
import numpy as np
@@ -15,7 +14,6 @@
1514
from janus_core.calculations.single_point import SinglePoint
1615
from janus_core.helpers import post_process
1716
from janus_core.helpers.correlator import Correlator
18-
from janus_core.helpers.janus_types import Observable
1917
from janus_core.helpers.observables import Stress, Velocity
2018

2119
DATA_PATH = Path(__file__).parent / "data"
@@ -146,15 +144,6 @@ def test_md_correlations(tmp_path):
146144
calc_kwargs={"model": MODEL_PATH},
147145
)
148146

149-
def user_observable_a(atoms: Atoms, kappa, *, gamma) -> float:
150-
"""User specified getter for correlation."""
151-
return (
152-
gamma
153-
* kappa
154-
* atoms.get_stress(include_ideal_gas=True, voigt=True)[-1]
155-
/ GPa
156-
)
157-
158147
nve = NVE(
159148
struct=single_point.struct,
160149
temp=300.0,
@@ -164,15 +153,6 @@ def user_observable_a(atoms: Atoms, kappa, *, gamma) -> float:
164153
stats_every=1,
165154
file_prefix=file_prefix,
166155
correlation_kwargs=[
167-
{
168-
"a": (Observable(1, getter=user_observable_a), (2,), {"gamma": 2}),
169-
"b": Stress([("xy")]),
170-
"name": "user_correlation",
171-
"blocks": 1,
172-
"points": 11,
173-
"averaging": 1,
174-
"update_frequency": 1,
175-
},
176156
{
177157
"a": Stress([("xy")]),
178158
"b": Stress([("xy")]),
@@ -195,8 +175,7 @@ def user_observable_a(atoms: Atoms, kappa, *, gamma) -> float:
195175
assert cor_path.exists()
196176
with open(cor_path, encoding="utf8") as in_file:
197177
cor = load(in_file, Loader=Loader)
198-
assert len(cor) == 2
199-
assert "user_correlation" in cor
178+
assert len(cor) == 1
200179
assert "stress_xy_auto_cor" in cor
201180

202181
stress_cor = cor["stress_xy_auto_cor"]
@@ -206,11 +185,3 @@ def user_observable_a(atoms: Atoms, kappa, *, gamma) -> float:
206185
direct = correlate(pxy, pxy, fft=False)
207186
# input data differs due to i/o, error is expected 1e-5
208187
assert direct == approx(value, rel=1e-5)
209-
210-
user_cor = cor["user_correlation"]
211-
value, lags = user_cor["value"], stress_cor["lags"]
212-
assert len(value) == len(lags) == 11
213-
214-
direct = correlate([v * 4.0 for v in pxy], pxy, fft=False)
215-
# input data differs due to i/o, error is expected 1e-5
216-
assert direct == approx(value, rel=1e-5)

0 commit comments

Comments
 (0)