diff --git a/src/resspect/samples_utils.py b/src/resspect/samples_utils.py index 3452db29..95fdc04e 100644 --- a/src/resspect/samples_utils.py +++ b/src/resspect/samples_utils.py @@ -46,6 +46,12 @@ def sep_samples(all_ids: np.array, n_test_val: int, keys are the sample names, values are the ids of objects in each sample. """ + if n_train + 2 * n_test_val > len(all_ids): + raise ValueError( + f"Unable to draw samples of sizes {n_train}, {n_test_val}, and {n_test_val} " + f"from only {len(all_ids)} indices." + ) + samples = {} # separate ids for training diff --git a/tests/resspect/test_bump.py b/tests/resspect/test_bump.py index c6dda80f..98a6e9da 100644 --- a/tests/resspect/test_bump.py +++ b/tests/resspect/test_bump.py @@ -4,28 +4,50 @@ import numpy as np import os -import pytest -from pandas import read_csv +from resspect.bump import bump, fit_bump, protected_exponent, protected_sig -from resspect import bump, fit_bump +def test_protected_exponent(): + """Test the protected_exponent() function.""" + values = np.arange(0, 20) + results = protected_exponent(values) -def test_bump(): - """ - Test the Bump function evaluation. - """ + # Input values 0-10 should all return exp(x). Anything above that should + # return exp(10.0). 
+ np.testing.assert_allclose(results[:11], np.exp(np.arange(0, 11))) + np.testing.assert_allclose(results[10:], [np.exp(10.0)] * 10) + + +def test_protected_sig(): + """Test the protected_sig() function.""" + values = np.arange(-20, 10) + results = protected_sig(values) + + expected_upper = 1.0 / (1.0 + np.exp(np.arange(10, -10, -1))) + expected_lower = 1.0 / (1.0 + np.exp(np.full(10, 10.0))) + + # Input values [-20, -10) should return 1.0 / (1.0 + exp(10.0)) + # and input values [-10, 10) should return 1.0 / (1.0 + exp(-x)) + np.testing.assert_allclose(results[:10], expected_lower) + np.testing.assert_allclose(results[10:], expected_upper) - time = np.array([0]) +def test_bump(): + """Test the Bump function evaluation.""" + time = np.arange(-1, 5, 1) p1 = 0.225 p2 = -2.5 p3 = 0.038 - + res = bump(time, p1, p2, p3) - assert not np.isnan(res).any() - + + # These were manually computed using the function, so this currently + # will only detect future changes in behavior (breakages). + expected = [0.86683499, 0.87300292, 0.87822063, 0.88254351, 0.88601519, 0.88866704] + np.testing.assert_allclose(res, expected) + def test_fit_bump(test_data_path): """ diff --git a/tests/resspect/test_sample_utils.py b/tests/resspect/test_sample_utils.py new file mode 100644 index 00000000..1bc4aebb --- /dev/null +++ b/tests/resspect/test_sample_utils.py @@ -0,0 +1,52 @@ +"""Tests for samples_utils.py.""" + +import numpy as np +import pytest + +from resspect.samples_utils import sep_samples + + +def test_sep_samples(): + """Test that we can generate separate samples.""" + all_ids = np.arange(0, 100) + samples = sep_samples(all_ids, n_test_val=10, n_train=50) + assert len(samples) == 4 + + # Check that each partition is the correct size and disjoint. 
+ assert len(samples["train"]) == 50 + assert len(np.unique(samples["train"])) == 50 + all_seen = np.copy(samples["train"]) + + assert len(samples["val"]) == 10 + assert len(np.unique(samples["val"])) == 10 + all_seen = np.union1d(all_seen, samples["val"]) + assert len(all_seen) == 60 + + assert len(samples["test"]) == 10 + assert len(np.unique(samples["test"])) == 10 + all_seen = np.union1d(all_seen, samples["test"]) + assert len(all_seen) == 70 + + assert len(samples["query"]) == 30 + assert len(np.unique(samples["query"])) == 30 + all_seen = np.union1d(all_seen, samples["query"]) + assert len(all_seen) == 100 + + +def test_sep_samples_too_many(): + """Test that we fail if we try to generate more samples than IDs.""" + all_ids = np.arange(0, 100) + with pytest.raises(ValueError): + samples = sep_samples(all_ids, n_test_val=50, n_train=80) + with pytest.raises(ValueError): + samples = sep_samples(all_ids, n_test_val=15, n_train=80) + + # Requesting exactly as many samples as there are IDs is allowed, + # but the 'query' bucket is then empty. + samples = sep_samples(all_ids, n_test_val=10, n_train=80) + assert len(samples) == 4 + assert len(samples["query"]) == 0 + + +if __name__ == '__main__': + pytest.main()