diff --git a/resspect/cosmo_metric_utils.py b/resspect/cosmo_metric_utils.py index 5bc99bb3..e11799ca 100644 --- a/resspect/cosmo_metric_utils.py +++ b/resspect/cosmo_metric_utils.py @@ -83,10 +83,10 @@ def fish_deriv_m(redshift, model, step, screen=False): at given redshifts. """ - Ob0=0.022 - Om0=model[1] - Ode0 =model[2] - cosmo = w0waCDM(model[0], Ob0, Om0, Ode0, model[3],model[4], name="w0waCDM") + Ob0 = 0.022 + Om0 = model[1] + Ode0 = model[2] + cosmo = w0waCDM(H0=model[0], Ob0=Ob0, Om0=Om0, Ode0=Ode0, w0=model[3], wa=model[4], name="w0waCDM") cosmo=assign_cosmo(cosmo, model) diff --git a/tests/conftest.py b/tests/conftest.py index b3f5319e..eb681c86 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,10 @@ def test_data_path(): return Path(__file__).parent.parent / "data" / "tests" +@pytest.fixture +def test_des_data_path(): + return Path(__file__).parent.parent / "data" / "tests" / "DES_data" + @pytest.fixture(scope="session") def base_temp(tmp_path_factory): diff --git a/tests/test_batch_functions.py b/tests/test_batch_functions.py index 02183401..8b2c4327 100644 --- a/tests/test_batch_functions.py +++ b/tests/test_batch_functions.py @@ -3,13 +3,15 @@ import numpy as np -@pytest.mark.skip(reason='Test failing with numpy.AxisError - Check inputs') def test_entropy_from_probs_b_M_C(): from resspect import batch_functions np.random.seed(42) - x = np.random.rand(100) + num_data_points = 3 + committee_size = 10 + num_classes = 5 + x = np.random.rand(num_data_points, committee_size, num_classes) foo = batch_functions.entropy_from_probs_b_M_C(x) diff --git a/tests/test_learn_loop.py b/tests/test_learn_loop.py index 049dc0f9..110a4fde 100644 --- a/tests/test_learn_loop.py +++ b/tests/test_learn_loop.py @@ -1,33 +1,28 @@ - +import os import pytest - - -# ToDo @emilleishida - Check this test later -@pytest.mark.skip("Test failing for now") -def test_can_run_learn_loop(extract_feature): - """Just a sanity test""" - - from resspect.learn_loop import learn_loop - - learn_loop(nloops=1, - features_method="bazin", - strategy="RandomSampling", - path_to_features=extract_feature, - output_metrics_file="just_a_name.csv", - output_queried_file="just_other_name.csv") - - -@pytest.fixture(scope="function") -def extract_feature(path_to_test_data): - from resspect import fit_snpcc - - path_to_data_dir = path_to_test_data - output_file = 'output_file.dat' - - fit_snpcc(path_to_data_dir=path_to_data_dir, features_file=output_file) - - return output_file - +import tempfile + +from resspect import fit_snpcc +from resspect.learn_loop import learn_loop + +def test_can_run_learn_loop(test_des_data_path): + """Test that learn_loop can load data and run.""" + with tempfile.TemporaryDirectory() as dir_name: + # Create the feature files to use for the learning loop. + output_file = os.path.join(dir_name, "output_file.dat") + fit_snpcc( + path_to_data_dir=test_des_data_path, + features_file=output_file, + ) + + learn_loop( + nloops=1, + features_method="bazin", + strategy="RandomSampling", + path_to_features=output_file, + output_metrics_file=os.path.join(dir_name,"just_a_name.csv"), + output_queried_file=os.path.join(dir_name,"just_other_name.csv"), + ) if __name__ == '__main__': pytest.main()