Skip to content

Commit 6dcb61c

Browse files
ChristophReich1996pre-commit-ci[bot]SkafteNicki
authored
Add support for SQ & RQ as well as per-class metrics (#2381)
* Fix RQ and SQ * Change return type and refactor flag name * Fix typing * changelog * input/output docstring * guard against older versions --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent aead080 commit 6dcb61c

File tree

8 files changed

+199
-12
lines changed

8 files changed

+199
-12
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- Added `QualityWithNoReference` metric ([#2288](https://github.com/Lightning-AI/torchmetrics/pull/2288))
1919

2020

21+
- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))
22+
23+
2124
### Changed
2225

2326
- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))

src/torchmetrics/detection/_deprecated.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
from typing import Any, Collection
22

33
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
4+
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
45
from torchmetrics.utilities.prints import _deprecated_root_import_class
56

7+
if not _TORCH_GREATER_EQUAL_1_12:
8+
__doctest_skip__ = [
9+
"_PanopticQuality",
10+
"_PanopticQuality.*",
11+
"_ModifiedPanopticQuality",
12+
"_ModifiedPanopticQuality.*",
13+
]
14+
615

716
class _ModifiedPanopticQuality(ModifiedPanopticQuality):
817
"""Wrapper for deprecated import.

src/torchmetrics/detection/panoptic_qualities.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@
2626
_validate_inputs,
2727
)
2828
from torchmetrics.metric import Metric
29-
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
29+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
3030
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
3131

3232
if not _MATPLOTLIB_AVAILABLE:
3333
__doctest_skip__ = ["PanopticQuality.plot", "ModifiedPanopticQuality.plot"]
3434

3535

36+
if not _TORCH_GREATER_EQUAL_1_12:
37+
__doctest_skip__ = ["PanopticQuality", "PanopticQuality.*", "ModifiedPanopticQuality", "ModifiedPanopticQuality.*"]
38+
39+
3640
class PanopticQuality(Metric):
3741
r"""Compute the `Panoptic Quality`_ for panoptic segmentations.
3842
@@ -47,6 +51,23 @@ class PanopticQuality(Metric):
4751
Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
4852
computation.
4953
54+
As input to ``forward`` and ``update`` the metric accepts the following input:
55+
56+
- ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to
57+
be at least one spatial dimension.
58+
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to
59+
be at least one spatial dimension.
60+
61+
As output to ``forward`` and ``compute`` the metric returns the following output:
62+
63+
- ``quality`` (:class:`~torch.Tensor`): If ``return_sq_and_rq=False`` and ``return_per_class=False`` then a
64+
single scalar tensor is returned with average panoptic quality over all classes. If ``return_sq_and_rq=True``
65+
and ``return_per_class=False`` a tensor of length 3 is returned with panoptic, segmentation and recognition
66+
quality (in that order). If If ``return_sq_and_rq=False`` and ``return_per_class=True`` a tensor of length
67+
equal to the number of classes are returned, with panoptic quality for each class. Finally, if both arguments
68+
are ``True`` a tensor of shape ``(3, C)`` is returned with individual panoptic, segmentation and recognition
69+
quality for each class.
70+
5071
Args:
5172
things:
5273
Set of ``category_id`` for countable things.
@@ -55,6 +76,10 @@ class PanopticQuality(Metric):
5576
allow_unknown_preds_category:
5677
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
5778
computation or raise an exception when found.
79+
return_sq_and_rq:
80+
Boolean flag to specify if Segmentation Quality and Recognition Quality should be also returned.
81+
return_per_class:
82+
Boolean flag to specify if the per-class values should be returned or the class average.
5883
5984
6085
Raises:
@@ -80,6 +105,40 @@ class PanopticQuality(Metric):
80105
>>> panoptic_quality(preds, target)
81106
tensor(0.5463, dtype=torch.float64)
82107
108+
You can also return the segmentation and recognition quality alognside the PQ
109+
>>> from torch import tensor
110+
>>> from torchmetrics.detection import PanopticQuality
111+
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
112+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
113+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
114+
... [[0, 0], [7, 0], [6, 0], [1, 0]],
115+
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
116+
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
117+
... [[0, 1], [0, 1], [6, 0], [0, 1]],
118+
... [[0, 1], [0, 1], [6, 0], [1, 0]],
119+
... [[0, 1], [7, 0], [1, 0], [1, 0]],
120+
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
121+
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
122+
>>> panoptic_quality(preds, target)
123+
tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)
124+
125+
You can also specify to return the per-class metrics
126+
>>> from torch import tensor
127+
>>> from torchmetrics.detection import PanopticQuality
128+
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
129+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
130+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
131+
... [[0, 0], [7, 0], [6, 0], [1, 0]],
132+
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
133+
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
134+
... [[0, 1], [0, 1], [6, 0], [0, 1]],
135+
... [[0, 1], [0, 1], [6, 0], [1, 0]],
136+
... [[0, 1], [7, 0], [1, 0], [1, 0]],
137+
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
138+
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
139+
>>> panoptic_quality(preds, target)
140+
tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)
141+
83142
"""
84143

85144
is_differentiable: bool = False
@@ -98,16 +157,22 @@ def __init__(
98157
things: Collection[int],
99158
stuffs: Collection[int],
100159
allow_unknown_preds_category: bool = False,
160+
return_sq_and_rq: bool = False,
161+
return_per_class: bool = False,
101162
**kwargs: Any,
102163
) -> None:
103164
super().__init__(**kwargs)
165+
if not _TORCH_GREATER_EQUAL_1_12:
166+
raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later")
104167

105168
things, stuffs = _parse_categories(things, stuffs)
106169
self.things = things
107170
self.stuffs = stuffs
108171
self.void_color = _get_void_color(things, stuffs)
109172
self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
110173
self.allow_unknown_preds_category = allow_unknown_preds_category
174+
self.return_sq_and_rq = return_sq_and_rq
175+
self.return_per_class = return_per_class
111176

112177
# per category intermediate metrics
113178
num_categories = len(things) + len(stuffs)
@@ -154,7 +219,16 @@ def update(self, preds: Tensor, target: Tensor) -> None:
154219

155220
def compute(self) -> Tensor:
156221
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
157-
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)
222+
pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
223+
self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
224+
)
225+
if self.return_per_class:
226+
if self.return_sq_and_rq:
227+
return torch.stack((pq, sq, rq), dim=-1)
228+
return pq.view(1, -1)
229+
if self.return_sq_and_rq:
230+
return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
231+
return pq_avg
158232

159233
def plot(
160234
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -337,7 +411,10 @@ def update(self, preds: Tensor, target: Tensor) -> None:
337411

338412
def compute(self) -> Tensor:
339413
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
340-
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)
414+
_, _, _, pq_avg, _, _ = _panoptic_quality_compute(
415+
self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
416+
)
417+
return pq_avg
341418

342419
def plot(
343420
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None

src/torchmetrics/functional/detection/_deprecated.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
from torch import Tensor
44

55
from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality
6+
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
67
from torchmetrics.utilities.prints import _deprecated_root_import_func
78

9+
if not _TORCH_GREATER_EQUAL_1_12:
10+
__doctest_skip__ = ["_panoptic_quality", "_modified_panoptic_quality"]
11+
812

913
def _modified_panoptic_quality(
1014
preds: Tensor,

src/torchmetrics/functional/detection/_panoptic_quality_common.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def _panoptic_quality_compute(
449449
true_positives: Tensor,
450450
false_positives: Tensor,
451451
false_negatives: Tensor,
452-
) -> Tensor:
452+
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
453453
"""Compute the final panoptic quality from interim values.
454454
455455
Args:
@@ -459,11 +459,17 @@ def _panoptic_quality_compute(
459459
false_negatives: the FN value from the update step
460460
461461
Returns:
462-
Panoptic quality as a tensor containing a single scalar.
462+
A tuple containing the per-class panoptic, segmentation and recognition quality followed by the averages
463463
464464
"""
465-
# per category calculation
466-
denominator = (true_positives + 0.5 * false_positives + 0.5 * false_negatives).double()
467-
panoptic_quality = torch.where(denominator > 0.0, iou_sum / denominator, 0.0)
468-
# Reduce across categories. TODO: is it useful to have the option of returning per class metrics?
469-
return torch.mean(panoptic_quality[denominator > 0])
465+
# compute segmentation and recognition quality (per-class)
466+
sq: Tensor = torch.where(true_positives > 0.0, iou_sum / true_positives, 0.0)
467+
denominator: Tensor = true_positives + 0.5 * false_positives + 0.5 * false_negatives
468+
rq: Tensor = torch.where(denominator > 0.0, true_positives / denominator, 0.0)
469+
# compute per-class panoptic quality
470+
pq: Tensor = sq * rq
471+
# compute averages
472+
pq_avg: Tensor = torch.mean(pq[denominator > 0])
473+
sq_avg: Tensor = torch.mean(sq[denominator > 0])
474+
rq_avg: Tensor = torch.mean(rq[denominator > 0])
475+
return pq, sq, rq, pq_avg, sq_avg, rq_avg

src/torchmetrics/functional/detection/panoptic_qualities.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from typing import Collection
1515

16+
import torch
1617
from torch import Tensor
1718

1819
from torchmetrics.functional.detection._panoptic_quality_common import (
@@ -24,6 +25,10 @@
2425
_prepocess_inputs,
2526
_validate_inputs,
2627
)
28+
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
29+
30+
if not _TORCH_GREATER_EQUAL_1_12:
31+
__doctest_skip__ = ["panoptic_quality", "modified_panoptic_quality"]
2732

2833

2934
def panoptic_quality(
@@ -32,6 +37,8 @@ def panoptic_quality(
3237
things: Collection[int],
3338
stuffs: Collection[int],
3439
allow_unknown_preds_category: bool = False,
40+
return_sq_and_rq: bool = False,
41+
return_per_class: bool = False,
3542
) -> Tensor:
3643
r"""Compute `Panoptic Quality`_ for panoptic segmentations.
3744
@@ -61,6 +68,10 @@ def panoptic_quality(
6168
allow_unknown_preds_category:
6269
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
6370
computation or raise an exception when found.
71+
return_sq_and_rq:
72+
Boolean flag to specify if Segmentation Quality and Recognition Quality should be also returned.
73+
return_per_class:
74+
Boolean flag to specify if the per-class values should be returned or the class average.
6475
6576
Raises:
6677
ValueError:
@@ -91,7 +102,59 @@ def panoptic_quality(
91102
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
92103
tensor(0.5463, dtype=torch.float64)
93104
105+
You can also return the segmentation and recognition quality alognside the PQ
106+
>>> from torch import tensor
107+
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
108+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
109+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
110+
... [[0, 0], [7, 0], [6, 0], [1, 0]],
111+
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
112+
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
113+
... [[0, 1], [0, 1], [6, 0], [0, 1]],
114+
... [[0, 1], [0, 1], [6, 0], [1, 0]],
115+
... [[0, 1], [7, 0], [1, 0], [1, 0]],
116+
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
117+
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
118+
tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)
119+
120+
You can also specify to return the per-class metrics
121+
>>> from torch import tensor
122+
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
123+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
124+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
125+
... [[0, 0], [7, 0], [6, 0], [1, 0]],
126+
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
127+
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
128+
... [[0, 1], [0, 1], [6, 0], [0, 1]],
129+
... [[0, 1], [0, 1], [6, 0], [1, 0]],
130+
... [[0, 1], [7, 0], [1, 0], [1, 0]],
131+
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
132+
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
133+
tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)
134+
135+
You can also specify to return the per-class metrics and the segmentation and recognition quality
136+
>>> from torch import tensor
137+
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
138+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
139+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
140+
... [[0, 0], [7, 0], [6, 0], [1, 0]],
141+
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
142+
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
143+
... [[0, 1], [0, 1], [6, 0], [0, 1]],
144+
... [[0, 1], [0, 1], [6, 0], [1, 0]],
145+
... [[0, 1], [7, 0], [1, 0], [1, 0]],
146+
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
147+
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7},
148+
... return_per_class=True, return_sq_and_rq=True)
149+
tensor([[0.5185, 0.7778, 0.6667],
150+
[0.0000, 0.0000, 0.0000],
151+
[0.6667, 0.6667, 1.0000],
152+
[1.0000, 1.0000, 1.0000]], dtype=torch.float64)
153+
94154
"""
155+
if not _TORCH_GREATER_EQUAL_1_12:
156+
raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later")
157+
95158
things, stuffs = _parse_categories(things, stuffs)
96159
_validate_inputs(preds, target)
97160
void_color = _get_void_color(things, stuffs)
@@ -101,7 +164,19 @@ def panoptic_quality(
101164
iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
102165
flatten_preds, flatten_target, cat_id_to_continuous_id, void_color
103166
)
104-
return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
167+
pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
168+
iou_sum,
169+
true_positives,
170+
false_positives,
171+
false_negatives,
172+
)
173+
if return_per_class:
174+
if return_sq_and_rq:
175+
return torch.stack((pq, sq, rq), dim=-1)
176+
return pq.view(1, -1)
177+
if return_sq_and_rq:
178+
return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
179+
return pq_avg
105180

106181

107182
def modified_panoptic_quality(
@@ -177,4 +252,5 @@ def modified_panoptic_quality(
177252
void_color,
178253
modified_metric_stuffs=stuffs,
179254
)
180-
return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
255+
_, _, _, pq_avg, _, _ = _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
256+
return pq_avg

tests/unittests/detection/test_modified_panoptic_quality.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
from torchmetrics.detection import ModifiedPanopticQuality
2020
from torchmetrics.functional.detection import modified_panoptic_quality
21+
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
2122

2223
from unittests import _Input
2324
from unittests._helpers import seed_all
@@ -76,6 +77,7 @@ def _reference_fn_1_2(preds, target) -> np.ndarray:
7677
return np.array([23 / 30])
7778

7879

80+
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
7981
class TestModifiedPanopticQuality(MetricTester):
8082
"""Test class for `ModifiedPanopticQuality` metric."""
8183

@@ -111,6 +113,7 @@ def test_panoptic_quality_functional(self):
111113
)
112114

113115

116+
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
114117
def test_empty_metric():
115118
"""Test empty metric."""
116119
with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"):
@@ -120,6 +123,7 @@ def test_empty_metric():
120123
assert torch.isnan(metric.compute())
121124

122125

126+
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
123127
def test_error_on_wrong_input():
124128
"""Test class input validation."""
125129
with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"):
@@ -162,6 +166,7 @@ def test_error_on_wrong_input():
162166
metric.update(preds, preds)
163167

164168

169+
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
165170
def test_extreme_values():
166171
"""Test that the metric returns expected values in trivial cases."""
167172
# Exact match between preds and target => metric is 1
@@ -170,6 +175,7 @@ def test_extreme_values():
170175
assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0
171176

172177

178+
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
173179
@pytest.mark.parametrize(
174180
("inputs", "args", "cat_dim"),
175181
[

0 commit comments

Comments
 (0)