Skip to content

Commit

Permalink
Merge pull request #14 from LSSTDESC/simple_unit_tests
Browse files Browse the repository at this point in the history
Add some simple unit tests for some of the supporting functions
  • Loading branch information
jeremykubica authored Oct 1, 2024
2 parents 17a4c33 + 614bd89 commit 423984b
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 11 deletions.
6 changes: 6 additions & 0 deletions src/resspect/samples_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def sep_samples(all_ids: np.array, n_test_val: int,
keys are the sample names, values are the ids of
objects in each sample.
"""
if n_train + 2 * n_test_val > len(all_ids):
raise ValueError(
f"Unable to draw samples of sizes {n_train}, {n_test_val}, and {n_test_val} "
f"from only {len(all_ids)} indices."
)

samples = {}

# separate ids for training
Expand Down
44 changes: 33 additions & 11 deletions tests/resspect/test_bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,50 @@

import numpy as np
import os
import pytest

from pandas import read_csv
from resspect.bump import bump, fit_bump, protected_exponent, protected_sig

from resspect import bump, fit_bump

def test_protected_exponent():
"""Test the protected_exponent() function."""
values = np.arange(0, 20)
results = protected_exponent(values)

def test_bump():
"""
Test the Bump function evaluation.
"""
# Input values 0-10 should all return exp(x). Anything above that should
# return exp(10.0).
np.testing.assert_allclose(results[:11], np.exp(np.arange(0, 11)))
np.testing.assert_allclose(results[10:], [np.exp(10.0)] * 10)


def test_protected_sig():
"""Test the protected_sig() function."""
values = np.arange(-20, 10)
results = protected_sig(values)

expected_upper = 1.0 / (1.0 + np.exp(np.arange(10, -10, -1)))
expected_lower = 1.0 / (1.0 + np.exp(np.full(10, 10.0)))

# Input values [-20, -10] should return 1.0 / (1.0 + exp(10.0))
# and input values [-10, 10] should return 1.0 / (1.0 + exp(-x))
np.testing.assert_allclose(results[:10], expected_lower)
np.testing.assert_allclose(results[10:], expected_upper)


time = np.array([0])
def test_bump():
"""Test the Bump function evaluation."""
time = np.arange(-1, 5, 1)
p1 = 0.225
p2 = -2.5
p3 = 0.038

res = bump(time, p1, p2, p3)

assert not np.isnan(res).any()


# These were manually computed using the function and so this currently
# only will detect future changes in behavior (breakages).
expected = [0.86683499, 0.87300292, 0.87822063, 0.88254351, 0.88601519, 0.88866704]
np.testing.assert_allclose(res, expected)


def test_fit_bump(test_data_path):
"""
Expand Down
52 changes: 52 additions & 0 deletions tests/resspect/test_sample_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Tests for sample_utils.py."""

import numpy as np
import pytest

from resspect.samples_utils import sep_samples


def test_sep_samples():
"""Test that we can generate separate samples."""
all_ids = np.arange(0, 100)
samples = sep_samples(all_ids, n_test_val=10, n_train=50)
assert len(samples) == 4

# Check that each partition is the correct size and disjoint.
assert len(samples["train"]) == 50
assert len(np.unique(samples["train"])) == 50
all_seen = np.copy(samples["train"])

assert len(samples["val"]) == 10
assert len(np.unique(samples["val"])) == 10
all_seen = np.union1d(all_seen, samples["val"])
assert len(all_seen) == 60

assert len(samples["test"]) == 10
assert len(np.unique(samples["test"])) == 10
all_seen = np.union1d(all_seen, samples["test"])
assert len(all_seen) == 70

assert len(samples["query"]) == 30
assert len(np.unique(samples["query"])) == 30
all_seen = np.union1d(all_seen, samples["query"])
assert len(all_seen) == 100


def test_sep_samples_too_many():
"""Test that we fail if we try to generate more samples than IDs."""
all_ids = np.arange(0, 100)
with pytest.raises(ValueError):
samples = sep_samples(all_ids, n_test_val=50, n_train=80)
with pytest.raises(ValueError):
samples = sep_samples(all_ids, n_test_val=15, n_train=80)

# We are okay with exactly the same number of samples and IDs.
# But the 'query' bucket is empty
samples = sep_samples(all_ids, n_test_val=10, n_train=80)
assert len(samples) == 4
assert len(samples["query"]) == 0


if __name__ == '__main__':
pytest.main()

0 comments on commit 423984b

Please sign in to comment.