Skip to content

Commit

Permalink
test/feat: tested weight generator more thoroughly
Browse files Browse the repository at this point in the history
  • Loading branch information
MothNik committed May 20, 2024
1 parent 6885f7c commit e78823e
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions tests/test_for_utils/test_whittaker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,12 @@ def test_get_checked_lambda(
("error", TypeError), # Number 5
],
)
def test_weight_generator(
def test_weight_generator_identical_weights(
combination: Tuple[Any, Union[np.ndarray, float, Type[Exception]]]
) -> None:
"""
Tests the weight generator.
Tests the weight generator when provided with weights that are identical for all
signals.
The ``combination`` parameter defines
Expand All @@ -260,16 +261,38 @@ def test_weight_generator(
# otherwise, the output is compared to the expected output
# Case 1: the expected output is a scalar
if isinstance(expected_output, (float, int)):
for w in get_weight_generator(weights=weights, n_series=n_series):
assert isinstance(w, (float, int))
assert w == expected_output
for wght in get_weight_generator(weights=weights, n_series=n_series):
assert isinstance(wght, (float, int))
assert wght == expected_output

return

# Case 2: the expected output is an array
for w in get_weight_generator(weights=weights, n_series=n_series):
assert isinstance(w, np.ndarray)
assert np.array_equal(w, expected_output)
for wght in get_weight_generator(weights=weights, n_series=n_series):
assert isinstance(wght, np.ndarray)
assert np.array_equal(wght, expected_output)


def test_weight_generator_different_weights() -> None:
"""
Tests the weight generator when provided with weights that are different for each
signal.
"""

# the weights are defined
weights = np.array(
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0],
]
)
weights_ref = weights.copy()

# the generator is tested
for idx, wght in enumerate(get_weight_generator(weights=weights, n_series=3)):
assert np.array_equal(wght, weights_ref[idx, ::])


@pytest.mark.parametrize("combination", [(True, 244_9755_000.0), (False, 490_000.0)])
Expand Down

0 comments on commit e78823e

Please sign in to comment.