From 43185a5a6f7df499c282d31cafa25467d05f5d59 Mon Sep 17 00:00:00 2001 From: Jan Wodnicki <25946310+janwodnicki@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:24:34 +0100 Subject: [PATCH] [Fix] Squeeze y_values in signal_interpolate --- neurokit2/signal/signal_interpolate.py | 4 +++- tests/tests_signal.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/neurokit2/signal/signal_interpolate.py b/neurokit2/signal/signal_interpolate.py index c44e8053c1..9574ebbc78 100644 --- a/neurokit2/signal/signal_interpolate.py +++ b/neurokit2/signal/signal_interpolate.py @@ -101,6 +101,8 @@ def signal_interpolate( x_values = np.squeeze(x_values.values) if isinstance(x_new, pd.Series): x_new = np.squeeze(x_new.values) + if isinstance(y_values, pd.Series): + y_values = np.squeeze(y_values.values) if len(x_values) != len(y_values): raise ValueError("x_values and y_values must be of the same length.") @@ -158,7 +160,7 @@ def signal_interpolate( # scipy.interpolate.PchipInterpolator for constant extrapolation akin to the behavior of # scipy.interpolate.interp1d with fill_value=([y_values[0]], [y_values[-1]]. fill_value = ([interpolated[first_index]], [interpolated[last_index]]) - elif isinstance(fill_value, float) or isinstance(fill_value, int): + elif isinstance(fill_value, (float, int)): # if only a single integer or float is provided as a fill value, format as a tuple fill_value = ([fill_value], [fill_value]) diff --git a/tests/tests_signal.py b/tests/tests_signal.py index 399ee7babd..121ab2496a 100644 --- a/tests/tests_signal.py +++ b/tests/tests_signal.py @@ -208,14 +208,25 @@ def test_signal_filter_with_missing(): def test_signal_interpolate(): + # Test with arrays x_axis = np.linspace(start=10, stop=30, num=10) signal = np.cos(x_axis) + x_new = np.arange(1000) - interpolated = nk.signal_interpolate(x_axis, signal, x_new=np.arange(1000)) + interpolated = nk.signal_interpolate(x_axis, signal, x_new) assert len(interpolated) == 1000 assert interpolated[0] == signal[0] assert interpolated[-1] == signal[-1] + # Test with Series + x_axis = pd.Series(x_axis) + signal = pd.Series(signal) + x_new = pd.Series(x_new) + + interpolated = nk.signal_interpolate(x_axis, signal, x_new) + assert len(interpolated) == 1000 + assert interpolated[0] == signal.iloc[0] + assert interpolated[-1] == signal.iloc[-1] def test_signal_findpeaks():