Skip to content

Commit ce710bb

Browse files
committed
Fix tests and respond to comments
1 parent fe7ad20 commit ce710bb

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

janus_core/processing/post_process.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def compute_vaf(
191191
index: SliceLike = (0, None, 1),
192192
filter_atoms: MaybeSequence[MaybeSequence[Optional[int]]] = ((None),),
193193
time_step: float = 1.0,
194-
) -> NDArray[float64]:
194+
) -> tuple[NDArray[float64], list[NDArray[float64]]]:
195195
"""
196196
Compute the velocity autocorrelation function (VAF) of `data`.
197197
@@ -219,8 +219,30 @@ def compute_vaf(
219219
220220
Returns
221221
-------
222-
MaybeSequence[NDArray[float64]]
222+
lags : numpy.ndarray
223+
Lags at which the VAFs have been computed.
224+
vafs : list[numpy.ndarray]
223225
Computed VAF(s).
226+
227+
Notes
228+
-----
229+
`filter_atoms` is given as a series of sequences of atoms, where
230+
each element in the series denotes a VAF subset to calculate and
231+
each sequence determines the atoms (by index) to be included in that VAF.
232+
233+
E.g.
234+
235+
.. code-block: Python
236+
237+
# Species indices in cell
238+
na = (1, 3, 5, 7)
239+
cl = (2, 4, 6, 8)
240+
241+
compute_vaf(..., filter_atoms=(na, cl))
242+
243+
Would compute separate VAFs for each species.
244+
245+
By default, one VAF will be computed for all atoms in the structure.
224246
"""
225247
# Ensure if passed scalars they are turned into correct dimensionality
226248
if not isinstance(filter_atoms, Sequence):

tests/test_post_process.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,21 +179,21 @@ def test_vaf(tmp_path):
179179
vaf_filter = ((3, 4), (1, 2, 3))
180180

181181
data = read(DATA_PATH / "lj-traj.xyz", index=":")
182-
vaf = post_process.compute_vaf(data)
182+
lags, vaf = post_process.compute_vaf(data)
183183
expected = np.loadtxt(DATA_PATH / "vaf-lj.dat")
184184

185185
assert isinstance(vaf, list)
186186
assert len(vaf) == 1
187187
assert isinstance(vaf[0], np.ndarray)
188188
assert vaf[0] == approx(expected, rel=1e-9)
189189

190-
vaf = post_process.compute_vaf(data, fft=True)
190+
lags, vaf = post_process.compute_vaf(data, fft=True)
191191

192192
assert isinstance(vaf, list)
193193
assert len(vaf) == 1
194194
assert isinstance(vaf[0], np.ndarray)
195195

196-
vaf = post_process.compute_vaf(
196+
lags, vaf = post_process.compute_vaf(
197197
data, filter_atoms=vaf_filter, filenames=[tmp_path / name for name in vaf_names]
198198
)
199199

@@ -205,5 +205,8 @@ def test_vaf(tmp_path):
205205
assert (tmp_path / name).exists()
206206
expected = np.loadtxt(DATA_PATH / name)
207207
written = np.loadtxt(tmp_path / name)
208+
w_lag, w_vaf = written[:, 0], written[:, 1]
209+
208210
assert vaf[i] == approx(expected, rel=1e-9)
209-
assert vaf[i] == approx(written, rel=1e-9)
211+
assert lags == approx(w_lag, rel=1e-9)
212+
assert vaf[i] == approx(w_vaf, rel=1e-9)

0 commit comments

Comments
 (0)