Skip to content

Commit

Permalink
Add simple unit tests for sample_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Sep 27, 2024
1 parent a681428 commit 602e308
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ __pycache__/

auxiliary_files/cosmo
auxiliary_files/cosmo.hpp
.DS_Store
5 changes: 5 additions & 0 deletions resspect/samples_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def sep_samples(all_ids: np.array, n_test_val: int,
keys are the sample names, values are the ids of
objects in each sample.
"""
if n_train + 2 * n_test_val > len(all_ids):
raise ValueError(
f"Unable to draw samples of sizes {n_train}, {n_test_val}, and {n_test_val} "
f"from only {len(all_ids)} indices."
)
samples = {}

# separate ids for training
Expand Down
52 changes: 52 additions & 0 deletions tests/test_sample_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Tests for samples_utils.py."""

import numpy as np
import pytest

from resspect.samples_utils import sep_samples


def test_sep_samples():
    """Check that sep_samples yields four correctly sized, disjoint partitions."""
    ids = np.arange(0, 100)
    partitions = sep_samples(ids, n_test_val=10, n_train=50)
    assert len(partitions) == 4

    # Expected size of every partition, in the order they are accumulated.
    expected_sizes = (("train", 50), ("val", 10), ("test", 10), ("query", 30))

    seen = np.array([], dtype=ids.dtype)
    running_total = 0
    for name, size in expected_sizes:
        members = partitions[name]

        # Correct size and no repeated IDs inside the partition.
        assert len(members) == size
        assert len(np.unique(members)) == size

        # The union only grows by `size` if this partition shares
        # no IDs with the partitions that came before it.
        seen = np.union1d(seen, members)
        running_total += size
        assert len(seen) == running_total


def test_sep_samples_too_many():
    """Test that we fail if we try to generate more samples than IDs.

    sep_samples needs ``n_train + 2 * n_test_val`` IDs (one training set plus
    equally sized validation and test sets) and raises ValueError when fewer
    are available.
    """
    all_ids = np.arange(0, 100)

    # 80 + 2 * 50 = 180 IDs requested from 100.
    with pytest.raises(ValueError):
        sep_samples(all_ids, n_test_val=50, n_train=80)

    # 80 + 2 * 15 = 110 IDs requested from 100.
    with pytest.raises(ValueError):
        sep_samples(all_ids, n_test_val=15, n_train=80)

    # We are okay with exactly the same number of samples and IDs
    # (80 + 2 * 10 = 100), but the 'query' bucket is empty.
    samples = sep_samples(all_ids, n_test_val=10, n_train=80)
    assert len(samples) == 4
    assert len(samples["query"]) == 0


# Allow running this file directly as a script; hands control to pytest.
if __name__ == "__main__":
    pytest.main()

0 comments on commit 602e308

Please sign in to comment.