diff --git a/tests/metrics/expected_register.pkl b/tests/metrics/expected_register_p2t.pkl similarity index 100% rename from tests/metrics/expected_register.pkl rename to tests/metrics/expected_register_p2t.pkl diff --git a/tests/metrics/expected_register_t2p.pkl b/tests/metrics/expected_register_t2p.pkl new file mode 100644 index 000000000..e74109b39 Binary files /dev/null and b/tests/metrics/expected_register_t2p.pkl differ diff --git a/tests/metrics/test_confusion_matrix.py b/tests/metrics/test_confusion_matrix.py index f3ee8d8eb..396934fe0 100644 --- a/tests/metrics/test_confusion_matrix.py +++ b/tests/metrics/test_confusion_matrix.py @@ -104,9 +104,16 @@ def test_pairwise_iou_matching(target, prediction): assert torch.allclose(result, expected_result, atol=1e-4) -def test_match_prediction(target, prediction): +def test_match_prediction_to_target(target, prediction): result = match_predictions_to_targets(target, prediction, iou_threshold=0.5) - with open(Path(__file__).parent / "expected_register.pkl", mode="rb") as infile: + with open(Path(__file__).parent / "expected_register_p2t.pkl", mode="rb") as infile: + expected_result = pickle.load(infile) + assert result == expected_result + + +def test_match_target_to_prediction(target, prediction): + result = match_targets_to_predictions(target, prediction, iou_threshold=0.5) + with open(Path(__file__).parent / "expected_register_t2p.pkl", mode="rb") as infile: expected_result = pickle.load(infile) assert result == expected_result