Skip to content

Commit c9f58f8

Browse files
committed
adjusts tests_signal.py to increase test coverage (and code checks/formatting)
1 parent 18eadf3 commit c9f58f8

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

neurokit2/signal/signal_flatintervals.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def signal_flatintervals(signal, sampling_rate, threshold=0.01, duration_min=60)
4141
signal = np.concatenate([ecg, flatline_1, ecg, flatline_2, ecg, flatline_1])
4242
4343
nk.signal_flatintervals(signal)
44-
44+
4545
"""
4646

4747
# Identify flanks: +1 for beginning plateau; -1 for ending plateau.

tests/tests_signal.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -197,24 +197,45 @@ def test_signal_interpolate():
197197

198198

199199
def test_signal_flatintervals():
200+
# parameters
200201
sampling_rate = 128
201202
one_minute = 60
202203
duration_min = 60
203204

205+
# signal components
204206
ecg = nk.ecg_simulate(duration=10 * one_minute, sampling_rate=sampling_rate)
205207
flatline_1 = np.full(10 * one_minute * sampling_rate, -4.0)
206208
flatline_2 = np.zeros(10 * one_minute * sampling_rate)
207-
signal = np.concatenate([ecg, flatline_1, ecg, flatline_2, ecg, flatline_1])
208209

209-
flatintervals = np.array(nk.signal_flatintervals(signal, sampling_rate, duration_min=duration_min)) / sampling_rate
210+
# test signals
211+
signal1 = np.concatenate([ecg, flatline_1, ecg, flatline_2, ecg, flatline_1])
212+
signal2 = np.concatenate([flatline_1, ecg, flatline_2, ecg, flatline_1, ecg])
213+
ground_truth1 = np.array([(10, 20), (30, 40), (50, 60)]) * one_minute
214+
ground_truth2 = np.array([(0, 10), (20, 30), (40, 50)]) * one_minute
210215

211-
assert len(flatintervals) == 3
212-
ground_truth = np.array([(10, 20), (30, 40), (50, 60)]) * one_minute
216+
# flatline interval detection
217+
flatintervals1 = (
218+
np.array(nk.signal_flatintervals(signal1, sampling_rate, duration_min=duration_min)) / sampling_rate
219+
)
220+
flatintervals2 = (
221+
np.array(nk.signal_flatintervals(signal2, sampling_rate, duration_min=duration_min)) / sampling_rate
222+
)
223+
224+
# check correct number of flatline intervals
225+
assert len(flatintervals1) == 3
226+
assert len(flatintervals2) == 3
227+
228+
# check correct detection of flatline intervals
229+
for interval, truth in zip(flatintervals1, ground_truth1):
230+
interval_begin, interval_end = interval
231+
true_begin, true_end = truth
232+
assert interval_begin >= true_begin and interval_begin < true_begin + duration_min
233+
assert interval_end > true_end - duration_min and interval_end < true_end
213234

214-
for interval, truth in zip(flatintervals, ground_truth):
235+
for interval, truth in zip(flatintervals2, ground_truth2):
215236
interval_begin, interval_end = interval
216237
true_begin, true_end = truth
217-
assert interval_begin > true_begin and interval_begin < true_begin + duration_min
238+
assert interval_begin >= true_begin and interval_begin < true_begin + duration_min
218239
assert interval_end > true_end - duration_min and interval_end < true_end
219240

220241

0 commit comments

Comments
 (0)