diff --git a/tests/test_for_utils/test_whittaker_base.py b/tests/test_for_utils/test_whittaker_base.py index 0ee0412..f8a0e4e 100644 --- a/tests/test_for_utils/test_whittaker_base.py +++ b/tests/test_for_utils/test_whittaker_base.py @@ -229,11 +229,12 @@ def test_get_checked_lambda( ("error", TypeError), # Number 5 ], ) -def test_weight_generator( +def test_weight_generator_identical_weights( combination: Tuple[Any, Union[np.ndarray, float, Type[Exception]]] ) -> None: """ - Tests the weight generator. + Tests the weight generator when provided with weights that are identical for all + signals. The ``combination`` parameter defines @@ -260,16 +261,38 @@ def test_weight_generator( # otherwise, the output is compared to the expected output # Case 1: the expected output is a scalar if isinstance(expected_output, (float, int)): - for w in get_weight_generator(weights=weights, n_series=n_series): - assert isinstance(w, (float, int)) - assert w == expected_output + for wght in get_weight_generator(weights=weights, n_series=n_series): + assert isinstance(wght, (float, int)) + assert wght == expected_output return # Case 2: the expected output is an array - for w in get_weight_generator(weights=weights, n_series=n_series): - assert isinstance(w, np.ndarray) - assert np.array_equal(w, expected_output) + for wght in get_weight_generator(weights=weights, n_series=n_series): + assert isinstance(wght, np.ndarray) + assert np.array_equal(wght, expected_output) + + +def test_weight_generator_different_weights() -> None: + """ + Tests the weight generator when provided with weights that are different for each + signal. + + """ + + # the weights are defined + weights = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0], + ] + ) + weights_ref = weights.copy() + + # the generator is tested + for idx, wght in enumerate(get_weight_generator(weights=weights, n_series=3)): + assert np.array_equal(wght, weights_ref[idx, ::]) @pytest.mark.parametrize("combination", [(True, 244_9755_000.0), (False, 490_000.0)])