Skip to content

Commit 0b49b58

Browse files
authored
Merge pull request #9 from zoccoler/improve_tests
Add comprehensive test suite for signal classifier
2 parents 2c6f2c8 + 1bfa2be commit 0b49b58

File tree

10 files changed

+581
-11
lines changed

10 files changed

+581
-11
lines changed

.github/workflows/test_and_deploy.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ jobs:
4646
git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git
4747
powershell gl-ci-helpers/appveyor/install_opengl.ps1
4848
49+
# Install OpenMP for macOS to enable dtaidistance C extensions
50+
- name: Install OpenMP on macOS
51+
if: runner.os == 'macOS'
52+
run: |
53+
brew install libomp
54+
4955
# note: if you need dependencies from conda, considering using
5056
# setup-miniconda: https://github.com/conda-incubator/setup-miniconda
5157
# and

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ dependencies = [
3636
"tsfresh",
3737
"dtaidistance",
3838
"napari-signal-selector>=0.0.6",
39-
"cmap"
39+
"cmap",
40+
"packaging"
4041
]
4142

4243
[project.optional-dependencies]
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
import tempfile
5+
from pathlib import Path
6+
7+
from napari_signal_classifier._classification import (
8+
split_table_train_test, train_signal_classifier, predict_signal_labels,
9+
generate_sub_signals_table, train_sub_signal_classifier,
10+
generate_sub_signal_templates_from_annotations, predict_sub_signal_labels,
11+
_get_classifier_file_path
12+
)
13+
14+
15+
@pytest.fixture
16+
def sample_signal_table():
17+
"""Create a sample signal table for testing."""
18+
np.random.seed(42)
19+
n_labels = 10
20+
n_frames = 20
21+
22+
data = []
23+
for label in range(n_labels):
24+
for frame in range(n_frames):
25+
# Create different patterns for different annotations
26+
if label < 5: # Class 1
27+
intensity = 10 + 5 * np.sin(frame / 3) + np.random.normal(0, 0.5)
28+
annotation = 1
29+
else: # Class 2
30+
intensity = 15 + 3 * np.cos(frame / 2) + np.random.normal(0, 0.5)
31+
annotation = 2
32+
33+
data.append({
34+
'label': label,
35+
'frame': frame,
36+
'mean_intensity': intensity,
37+
'Annotations': annotation
38+
})
39+
40+
return pd.DataFrame(data)
41+
42+
43+
@pytest.fixture
44+
def sample_sub_signal_table():
45+
"""Create a sample table with sub-signal annotations."""
46+
np.random.seed(42)
47+
data = []
48+
49+
for label in range(5):
50+
for frame in range(30):
51+
# Create sub-signals at different positions
52+
if 5 <= frame < 10:
53+
annotation = 1 # Sub-signal type 1
54+
intensity = 20 + 3 * np.sin(frame)
55+
elif 15 <= frame < 20:
56+
annotation = 2 # Sub-signal type 2
57+
intensity = 15 + 2 * np.cos(frame)
58+
else:
59+
annotation = 0 # Background
60+
intensity = 10 + np.random.normal(0, 0.5)
61+
62+
data.append({
63+
'label': label,
64+
'frame': frame,
65+
'mean_intensity': intensity,
66+
'Annotations': annotation
67+
})
68+
69+
return pd.DataFrame(data)
70+
71+
72+
def test_get_classifier_file_path():
73+
"""Test classifier file path generation."""
74+
with tempfile.TemporaryDirectory() as tmpdir:
75+
# Test with directory path
76+
path = _get_classifier_file_path(tmpdir)
77+
assert path.name == 'signal_classifier.pkl'
78+
assert path.parent == Path(tmpdir)
79+
80+
# Test with None
81+
path = _get_classifier_file_path(None)
82+
assert path.name == 'signal_classifier.pkl'
83+
84+
# Test with specific file path
85+
file_path = Path(tmpdir) / 'custom.pkl'
86+
path = _get_classifier_file_path(str(file_path))
87+
assert path == file_path
88+
89+
90+
def test_split_table_train_test(sample_signal_table):
91+
"""Test train/test split of signal table."""
92+
train, test = split_table_train_test(
93+
sample_signal_table, train_size=0.8, random_state=42
94+
)
95+
96+
assert len(train) > 0
97+
assert len(test) > 0
98+
assert len(train) + len(test) == len(sample_signal_table)
99+
100+
# Check that labels are properly split
101+
train_labels = train['label'].unique()
102+
test_labels = test['label'].unique()
103+
assert len(set(train_labels).intersection(set(test_labels))) == 0
104+
105+
106+
def test_train_signal_classifier(sample_signal_table):
107+
"""Test signal classifier training."""
108+
with tempfile.TemporaryDirectory() as tmpdir:
109+
classifier_path = train_signal_classifier(
110+
sample_signal_table,
111+
classifier_path=tmpdir,
112+
train_size=0.6,
113+
random_state=42,
114+
n_estimators=10
115+
)
116+
117+
assert classifier_path is not None
118+
assert Path(classifier_path).exists()
119+
120+
121+
def test_predict_signal_labels(sample_signal_table):
122+
"""Test signal label prediction."""
123+
with tempfile.TemporaryDirectory() as tmpdir:
124+
# Train classifier
125+
classifier_path = train_signal_classifier(
126+
sample_signal_table,
127+
classifier_path=tmpdir,
128+
train_size=0.6,
129+
random_state=42,
130+
n_estimators=10
131+
)
132+
133+
# Predict
134+
result_table = predict_signal_labels(
135+
sample_signal_table,
136+
classifier_path
137+
)
138+
139+
assert 'Predictions' in result_table.columns
140+
assert result_table['Predictions'].dtype == int
141+
assert len(result_table) == len(sample_signal_table)
142+
143+
144+
def test_train_sub_signal_classifier(sample_sub_signal_table):
145+
"""Test sub-signal classifier training."""
146+
with tempfile.TemporaryDirectory() as tmpdir:
147+
classifier_path = train_sub_signal_classifier(
148+
sample_sub_signal_table,
149+
classifier_path=tmpdir,
150+
train_size=0.6,
151+
random_state=42,
152+
n_estimators=10
153+
)
154+
155+
assert classifier_path is not None
156+
assert Path(classifier_path).exists()
157+
158+
159+
def test_generate_sub_signal_templates_from_annotations(sample_sub_signal_table):
160+
"""Test sub-signal template generation."""
161+
templates = generate_sub_signal_templates_from_annotations(
162+
sample_sub_signal_table
163+
)
164+
165+
assert isinstance(templates, dict)
166+
assert len(templates) > 0
167+
assert all(isinstance(v, np.ndarray) for v in templates.values())
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from napari_signal_classifier._detection import (
6+
normalize, align_signals, generate_template_mean, generate_templates_by_category,
7+
detect_sub_signal_by_template, extract_sub_signals_by_templates
8+
)
9+
from napari_signal_classifier._sub_signals import SubSignalCollection
10+
11+
12+
@pytest.fixture
13+
def sample_sub_signal_table():
14+
"""Create a sample table with sub-signal annotations."""
15+
np.random.seed(42)
16+
data = []
17+
18+
for label in range(5):
19+
for frame in range(30):
20+
# Create sub-signals at different positions
21+
if 5 <= frame < 10:
22+
annotation = 1 # Sub-signal type 1
23+
intensity = 20 + 3 * np.sin(frame)
24+
elif 15 <= frame < 20:
25+
annotation = 2 # Sub-signal type 2
26+
intensity = 15 + 2 * np.cos(frame)
27+
else:
28+
annotation = 0 # Background
29+
intensity = 10 + np.random.normal(0, 0.5)
30+
31+
data.append({
32+
'label': label,
33+
'frame': frame,
34+
'mean_intensity': intensity,
35+
'Annotations': annotation
36+
})
37+
38+
return pd.DataFrame(data)
39+
40+
41+
def test_normalize():
42+
"""Test signal normalization."""
43+
signal = np.array([1, 2, 3, 4, 5])
44+
45+
# Test z-score normalization
46+
normalized_z = normalize(signal, method='zscores')
47+
assert np.isclose(np.mean(normalized_z), 0, atol=1e-10)
48+
assert np.isclose(np.std(normalized_z), 1, atol=1e-10)
49+
50+
# Test min-max normalization
51+
normalized_mm = normalize(signal, method='minmax')
52+
assert np.isclose(np.min(normalized_mm), 0)
53+
assert np.isclose(np.max(normalized_mm), 1)
54+
55+
# Test invalid method
56+
with pytest.raises(ValueError):
57+
normalize(signal, method='invalid')
58+
59+
60+
def test_align_signals():
61+
"""Test signal alignment using DTW."""
62+
reference = np.sin(np.linspace(0, 2*np.pi, 50))
63+
signal = np.sin(np.linspace(0, 2*np.pi, 50) + 0.1) # Slightly shifted
64+
65+
aligned = align_signals(reference, signal, detrend=False)
66+
67+
assert len(aligned) == len(reference)
68+
assert isinstance(aligned, np.ndarray)
69+
70+
71+
def test_generate_template_mean():
72+
"""Test template generation from replicates."""
73+
# Create similar signals with slight variations
74+
replicates = [
75+
np.sin(np.linspace(0, 2*np.pi, 50)) + np.random.normal(0, 0.1, 50)
76+
for _ in range(5)
77+
]
78+
79+
template = generate_template_mean(replicates, detrend=False)
80+
81+
assert len(template) == len(replicates[0])
82+
assert isinstance(template, np.ndarray)
83+
84+
85+
def test_detect_sub_signal_by_template():
86+
"""Test sub-signal detection using template matching."""
87+
# Create a composite signal with a known pattern
88+
template = np.sin(np.linspace(0, 2*np.pi, 20))
89+
composite = np.concatenate([
90+
np.random.normal(0, 0.5, 30),
91+
template,
92+
np.random.normal(0, 0.5, 30),
93+
template,
94+
np.random.normal(0, 0.5, 30)
95+
])
96+
97+
peaks = detect_sub_signal_by_template(
98+
composite, template, threshold=0.5
99+
)
100+
101+
assert len(peaks) >= 2 # Should detect at least 2 peaks
102+
assert isinstance(peaks, np.ndarray)
103+
104+
105+
def test_extract_sub_signals_by_templates(sample_sub_signal_table):
106+
"""Test sub-signal extraction using templates."""
107+
from napari_signal_classifier._classification import generate_sub_signal_templates_from_annotations
108+
109+
# Generate templates first
110+
templates = generate_sub_signal_templates_from_annotations(
111+
sample_sub_signal_table
112+
)
113+
114+
# Extract sub-signals
115+
collection = extract_sub_signals_by_templates(
116+
sample_sub_signal_table,
117+
'mean_intensity',
118+
'label',
119+
'frame',
120+
templates,
121+
threshold=0.5
122+
)
123+
124+
assert isinstance(collection, SubSignalCollection)
125+
assert len(collection.sub_signals) > 0
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from napari_signal_classifier._features import get_signal_features
6+
7+
8+
@pytest.fixture
9+
def sample_signal_table():
10+
"""Create a sample signal table for testing."""
11+
np.random.seed(42)
12+
n_labels = 10
13+
n_frames = 20
14+
15+
data = []
16+
for label in range(n_labels):
17+
for frame in range(n_frames):
18+
# Create different patterns for different annotations
19+
if label < 5: # Class 1
20+
intensity = 10 + 5 * np.sin(frame / 3) + np.random.normal(0, 0.5)
21+
annotation = 1
22+
else: # Class 2
23+
intensity = 15 + 3 * np.cos(frame / 2) + np.random.normal(0, 0.5)
24+
annotation = 2
25+
26+
data.append({
27+
'label': label,
28+
'frame': frame,
29+
'mean_intensity': intensity,
30+
'Annotations': annotation
31+
})
32+
33+
return pd.DataFrame(data)
34+
35+
36+
def test_get_signal_features(sample_signal_table):
37+
"""Test signal feature extraction."""
38+
features = get_signal_features(
39+
sample_signal_table,
40+
column_id='label',
41+
column_sort='frame',
42+
column_value='mean_intensity'
43+
)
44+
45+
assert isinstance(features, pd.DataFrame)
46+
assert len(features) == len(sample_signal_table['label'].unique())
47+
assert features.shape[1] > 0 # Should have multiple feature columns

0 commit comments

Comments
 (0)