Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix
from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = [
Expand Down Expand Up @@ -151,6 +151,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Expand All @@ -160,6 +161,8 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html

Returns:
Figure and Axes object
Expand All @@ -181,7 +184,7 @@ def plot(
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


Expand Down Expand Up @@ -292,6 +295,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Expand All @@ -301,6 +305,8 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html

Returns:
Figure and Axes object
Expand All @@ -322,7 +328,7 @@ def plot(
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


Expand Down Expand Up @@ -436,6 +442,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Expand All @@ -445,6 +452,8 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html

Returns:
Figure and Axes object
Expand All @@ -466,7 +475,7 @@ def plot(
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


Expand Down
7 changes: 6 additions & 1 deletion src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@

_PLOT_OUT_TYPE = Tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
_AX_TYPE = matplotlib.axes.Axes
_CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

style_change = plt.style.context
else:
_PLOT_OUT_TYPE = Tuple[object, object] # type: ignore[misc]
_AX_TYPE = object
_CMAP_TYPE = object # type: ignore[misc]

from contextlib import contextmanager

Expand Down Expand Up @@ -201,6 +203,7 @@ def plot_confusion_matrix(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[Union[int, str]]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot an confusion matrix.

Expand All @@ -213,6 +216,8 @@ def plot_confusion_matrix(
ax: Axis from a figure. If not provided, a new figure and axis will be created
add_text: if text should be added to each cell with the given value
labels: labels to add the x- and y-axis
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html

Returns:
A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
Expand Down Expand Up @@ -248,7 +253,7 @@ def plot_confusion_matrix(
ax = axs[i] if rows != 1 and cols != 1 else axs
if fig_label is not None:
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap)
if i // cols == rows - 1: # bottom row only
ax.set_xlabel("Predicted class", fontsize=15)
if i % cols == 0: # leftmost column only
Expand Down