Add comprehensive validation tests for Loss metric [skip ci] #3631

Goutam16-Withcode wants to merge 2 commits into pytorch:master from
Conversation
- Add test_loss_with_empty_batch: Handle empty batch gracefully
- Add test_loss_with_invalid_output_structure: Validate output tuple format
- Add test_loss_with_nan_loss: Handle NaN loss values with math.isnan
- Add test_loss_compute_before_update: Verify NotComputableError raised
- Add test_loss_multiple_updates_and_compute: Test batch aggregation
- Add test_loss_with_custom_batch_size_fn: Support custom batch sizing

These tests improve code coverage and ensure better error handling for edge cases in the Loss metric implementation.

Fixes and improvements:
- Proper NaN detection using math.isnan for float results
- Better error handling validation
- Comprehensive test coverage for edge cases
Pull request overview
Adds additional unit tests around ignite.metrics.Loss to exercise edge cases and error handling, aiming to improve coverage and confidence in metric behavior.
Changes:
- Added new tests for empty batches, invalid output structures, NaN loss results, compute-before-update, multi-update aggregation, and custom batch sizing.
- Introduced NaN validation using math.isnan on computed results.
- Expanded assertions around Loss metric update/compute flows (though some are currently quite weak/redundant).
tests/ignite/metrics/test_loss.py
Outdated
```python
# Should handle empty batches gracefully
loss.update((y_pred, y))
assert loss._num_examples == 0
```
test_loss_with_empty_batch claims to validate graceful handling of empty batches, but it only asserts _num_examples == 0 after update. It doesn’t verify that compute() still raises NotComputableError, nor that the metric state isn’t poisoned (e.g., _sum becoming NaN if the criterion returns NaN for an empty batch) and that a subsequent non-empty update() produces a valid result. Consider strengthening the assertions to cover the intended behavior, or clarify/adjust the test expectations to match what Loss.update actually guarantees.
Suggested change:
```diff
-# Should handle empty batches gracefully
-loss.update((y_pred, y))
-assert loss._num_examples == 0
+# Before any update, compute should not be possible
+with pytest.raises(NotComputableError):
+    loss.compute()
+# Should handle empty batches gracefully: no examples counted, still not computable
+loss.update((y_pred, y))
+assert loss._num_examples == 0
+with pytest.raises(NotComputableError):
+    loss.compute()
+# A subsequent non-empty batch should produce a valid, finite loss
+y_pred2 = torch.randn(4, 3)
+y2 = torch.tensor([0, 1, 2, 1], dtype=torch.long)
+loss.update((y_pred2, y2))
+result = loss.compute()
+assert torch.isfinite(torch.as_tensor(result)).item()
```
tests/ignite/metrics/test_loss.py
Outdated
```python
# Test with single tensor instead of tuple
y_pred = torch.randn(4, 3)

with pytest.raises((ValueError, TypeError)):
```
test_loss_with_invalid_output_structure uses pytest.raises((ValueError, TypeError)), but with the current Loss.update implementation a single Tensor input deterministically triggers a ValueError from tuple-unpacking. Narrowing this to the specific exception (and ideally matching the message once Loss.update validates input types explicitly) will make the test more precise and less likely to mask unrelated errors.
Suggested change:
```diff
-with pytest.raises((ValueError, TypeError)):
+with pytest.raises(ValueError):
```
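The reviewer's point about tuple-unpacking can be illustrated with plain Python. This is a sketch assuming `Loss.update` begins with `y_pred, y = output`, as the comment implies: unpacking an iterable of the wrong length raises `ValueError`, never `TypeError`, which is why the narrower exception type is safe here.

```python
# Sketch: a 2-D tensor iterates over its rows, so `y_pred, y = tensor`
# fails with ValueError when the tensor has more than two rows.
# A plain list stands in for the tensor in this dependency-free example.
def unpack_like_loss_update(output):
    y_pred, y = output  # mirrors the assumed first line of Loss.update
    return y_pred, y

try:
    unpack_like_loss_update([1, 2, 3])  # stand-in for a tensor with >2 rows
except ValueError as exc:
    print(type(exc).__name__)  # ValueError ("too many values to unpack")
```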
tests/ignite/metrics/test_loss.py
Outdated
```python
def test_loss_compute_before_update():
    """Test Loss metric compute before any update."""
    loss = Loss(nll_loss)

    with pytest.raises(NotComputableError):
        loss.compute()
```
test_loss_compute_before_update duplicates test_zero_div earlier in this file, but with weaker assertions (no match on the error message). To avoid redundant coverage and improve stability, consider removing this test and relying on test_zero_div, or consolidate into a single test that asserts both the exception type and the expected message.
Suggested change (delete the duplicate test):
```diff
-def test_loss_compute_before_update():
-    """Test Loss metric compute before any update."""
-    loss = Loss(nll_loss)
-    with pytest.raises(NotComputableError):
-        loss.compute()
```
tests/ignite/metrics/test_loss.py
Outdated
```python
# Should aggregate losses from multiple batches
result = loss.compute()
assert isinstance(result, (float, torch.Tensor))
```
test_loss_multiple_updates_and_compute doesn’t actually validate aggregation correctness: Loss.compute() always returns a float, and the current assertion only checks the type. To make this test meaningful, assert the computed value equals the expected aggregated average (e.g., compare to nll_loss on each batch and the weighted average, or reuse the known expected loss from y_test_1).
Suggested change:
```diff
-# Should aggregate losses from multiple batches
-result = loss.compute()
-assert isinstance(result, (float, torch.Tensor))
+# Should aggregate losses from multiple batches: since both batches
+# come from y_test_1() and have the same size, the overall average
+# loss should be the mean of the two per-batch losses.
+expected_1 = nll_loss(y_pred_1, y_1)
+expected_2 = nll_loss(y_pred_2, y_2)
+expected = (expected_1 + expected_2) / 2.0
+result = loss.compute()
+assert_almost_equal(result, expected)
```
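The aggregation the suggested assertion relies on can be sketched without torch. The `LossAggregator` below is a hypothetical stand-in for how the review describes the metric behaving, not ignite's actual implementation: a running example-weighted sum plus an example count, so two equal-sized batches average to the plain mean of their per-batch losses.

```python
class LossAggregator:
    """Minimal sketch of batch-weighted loss averaging (hypothetical)."""

    def __init__(self):
        self._sum = 0.0
        self._num_examples = 0

    def update(self, batch_mean_loss, batch_size):
        # Accumulate the per-batch mean loss weighted by the batch size.
        self._sum += batch_mean_loss * batch_size
        self._num_examples += batch_size

    def compute(self):
        if self._num_examples == 0:
            raise ValueError("at least one example is required before compute()")
        return self._sum / self._num_examples

agg = LossAggregator()
agg.update(0.5, 4)    # first batch: mean loss 0.5 over 4 examples
agg.update(1.0, 4)    # second batch: mean loss 1.0 over 4 examples
print(agg.compute())  # equal sizes -> plain mean of the two batch losses: 0.75
```

With unequal batch sizes the result becomes a weighted average rather than the plain mean, which is why the review note stresses that both batches "have the same size".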
- Strengthen test_loss_with_empty_batch with compute() verification
- Narrow exception type in test_loss_with_invalid_output_structure to ValueError
- Remove duplicate test_loss_compute_before_update (duplicates test_zero_div)
- Add concrete value assertion in test_loss_multiple_updates_and_compute
- Verify custom batch_size callable is actually invoked

These improvements address code review suggestions:
- Better edge case coverage
- More precise exception handling
- Elimination of redundant tests
- Meaningful assertions instead of type-only checks
- Verification of custom callable usage
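One way to verify that a custom batch_size callable is actually invoked is a counting wrapper, sketched below with the standard library only. The `make_counting_batch_size` helper is hypothetical; in the real test the returned wrapper would be passed as the metric's batch_size argument.

```python
def make_counting_batch_size(fn):
    """Wrap a batch_size callable so a test can assert it was invoked."""
    calls = []

    def wrapper(y):
        calls.append(y)   # record each invocation and its argument
        return fn(y)      # delegate to the real batch-size function

    return wrapper, calls

# Hypothetical usage: wrapper would be passed to the metric under test.
wrapper, calls = make_counting_batch_size(len)
n = wrapper([0, 1, 2, 1])
assert n == 4 and len(calls) == 1  # correct size, invoked exactly once
```

`unittest.mock.Mock(side_effect=len)` with an assertion on `call_count` is an equivalent, more idiomatic pytest-era alternative.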
@Goutam16-Withcode Thanks for the PR, our general workflow is issue -> pull request. This pull request may not be accepted. You are free to open an issue though.
Summary

This PR adds 6 comprehensive validation test cases to improve code coverage and error handling for the Loss metric.

Changes
- test_loss_with_empty_batch: Handle empty batch gracefully
- test_loss_with_invalid_output_structure: Validate output tuple format
- test_loss_with_nan_loss: Handle NaN loss values with math.isnan
- test_loss_compute_before_update: Verify NotComputableError raised
- test_loss_multiple_updates_and_compute: Test batch aggregation
- test_loss_with_custom_batch_size_fn: Support custom batch sizing

Improvements
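The NaN handling the summary describes ("proper NaN detection using math.isnan for float results") can be sketched as follows. `check_finite_loss` is a hypothetical helper for illustration, not part of ignite; it shows why `math.isnan` is the right tool once the computed result is a plain float, since `x != x` style checks are easy to get wrong.

```python
import math

def check_finite_loss(result):
    """Reject NaN loss values; math.isnan works on plain Python floats."""
    if math.isnan(result):
        raise ValueError("Loss is NaN")
    return result

assert check_finite_loss(0.25) == 0.25
try:
    check_finite_loss(float("nan"))
except ValueError as exc:
    print(exc)  # prints "Loss is NaN"
```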
Testing
Checklist