Skip to content

Commit 879595d

Browse files
SkafteNicki, Borda, stancld, pre-commit-ci[bot], mergify[bot]
authored
Fix precision issue in calibration error (#1919)
* fix implementation
* add tests
* changelog
* skip on older versions
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* skip testing on older

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Daniel Stancl <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent c7bca4e commit 879595d

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3131

3232
### Fixed
3333

34-
-
34+
- Fixed bug in `CalibrationError` where calculations for double precision input was performed in float precision ([#1919](https://github.com/Lightning-AI/torchmetrics/pull/1919))
35+
36+
37+
- Fixed bug related to the `prefix/postfix` arguments in `MetricCollection` and `ClasswiseWrapper` being duplicated ([#1918](https://github.com/Lightning-AI/torchmetrics/pull/1918))
3538

3639

3740
## [1.0.1] - 2023-07-13
@@ -44,8 +47,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4447
- Fixed bug related to expected input format of pycoco in `MeanAveragePrecision` ([#1913](https://github.com/Lightning-AI/torchmetrics/pull/1913))
4548

4649

47-
- Fixed bug related to the `prefix/postfix` arguments in `MetricCollection` and `ClasswiseWrapper` being duplicated ([#1918](https://github.com/Lightning-AI/torchmetrics/pull/1918))
48-
4950
## [1.0.0] - 2022-07-04
5051

5152
### Added

src/torchmetrics/functional/classification/calibration_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _ce_compute(
8484
8585
"""
8686
if isinstance(bin_boundaries, int):
87-
bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=torch.float, device=confidences.device)
87+
bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=confidences.dtype, device=confidences.device)
8888

8989
if norm not in {"l1", "l2", "max"}:
9090
raise ValueError(f"Argument `norm` is expected to be one of 'l1', 'l2', 'max' but got {norm}")

tests/unittests/classification/test_calibration_error.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
binary_calibration_error,
2525
multiclass_calibration_error,
2626
)
27-
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9
27+
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_13
2828

2929
from unittests import NUM_CLASSES
3030
from unittests.classification.inputs import _binary_cases, _multiclass_cases
@@ -108,7 +108,8 @@ def test_binary_calibration_error_differentiability(self, inputs):
108108
def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
109109
"""Test dtype support of the metric on CPU."""
110110
preds, target = inputs
111-
111+
if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13:
112+
pytest.xfail(reason="torch.linspace in metric not supported before pytorch v1.13 for cpu + half")
112113
if (preds < 0).any() and dtype == torch.half:
113114
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
114115
self.run_precision_test_cpu(
@@ -123,6 +124,8 @@ def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
123124
@pytest.mark.parametrize("dtype", [torch.half, torch.double])
124125
def test_binary_calibration_error_dtype_gpu(self, inputs, dtype):
125126
"""Test dtype support of the metric on GPU."""
127+
if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13:
128+
pytest.xfail(reason="torch.searchsorted in metric not supported before pytorch v1.13 for gpu + half")
126129
preds, target = inputs
127130
self.run_precision_test_gpu(
128131
preds=preds,
@@ -246,3 +249,23 @@ def test_multiclass_calibration_error_dtype_gpu(self, inputs, dtype):
246249
metric_args={"num_classes": NUM_CLASSES},
247250
dtype=dtype,
248251
)
252+
253+
254+
def test_corner_case_due_to_dtype():
255+
"""Test that metric works with edge case where the precision is really important for the right result.
256+
257+
See issue: https://github.com/Lightning-AI/torchmetrics/issues/1907
258+
259+
"""
260+
preds = torch.tensor(
261+
[0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.8000, 0.8000, 0.0100, 0.3300, 0.3400, 0.9900, 0.6100],
262+
dtype=torch.float64,
263+
)
264+
target = torch.tensor([1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0])
265+
266+
assert np.allclose(
267+
ECE(99).measure(preds.numpy(), target.numpy()), binary_calibration_error(preds, target, n_bins=99)
268+
), "The metric should be close to the netcal implementation"
269+
assert np.allclose(
270+
ECE(100).measure(preds.numpy(), target.numpy()), binary_calibration_error(preds, target, n_bins=100)
271+
), "The metric should be close to the netcal implementation"

0 commit comments

Comments (0)