Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: fix type-checking incompatibilities with matplotlib 3.8 #4629

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions yt/visualization/_handlers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import weakref
from numbers import Real
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
Expand All @@ -12,6 +13,11 @@
from yt.config import ytcfg
from yt.funcs import get_brewer_cmap, is_sequence, mylog

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


class NormHandler:
"""
Expand Down Expand Up @@ -73,7 +79,7 @@ def __init__(
self._linthresh = linthresh
self.prefer_log = True

if self.has_norm and self.has_constraints:
if self.norm is not None and self.has_constraints:
raise TypeError(
"NormHandler input is malformed. "
"A norm cannot be passed along other constraints."
Expand All @@ -100,12 +106,8 @@ def _reset_constraints(self) -> None:
for name in constraints.keys():
setattr(self, name, None)

@property
def has_norm(self) -> bool:
return self._norm is not None

def _reset_norm(self):
if not self.has_norm:
def _reset_norm(self) -> None:
if self.norm is None:
return
mylog.warning("Dropping norm (%s)", self.norm)
self._norm = None
Expand Down Expand Up @@ -281,7 +283,7 @@ def linthresh(self, newval: Optional[Union[Quantity, float]]) -> None:
self.norm_type = SymLogNorm

def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize:
if self.has_norm:
if self.norm is not None:
return self.norm

dvmin = dvmax = None
Expand Down Expand Up @@ -312,6 +314,7 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize:
dvmax = 1 * getattr(data, "units", 1)
kw.setdefault("vmax", dvmax)

norm_type: Type[Normalize]
if data.ndim == 3:
assert data.shape[-1] == 4
# this is an RGBA array, only linear normalization makes sense here
Expand Down Expand Up @@ -395,6 +398,16 @@ def get_minmax(data):
return linthresh


BackgroundColor: TypeAlias = Union[
Tuple[float, float, float, float],
# np.ndarray is only runtime-subscribtable since numpy 1.22
"np.ndarray[Any, Any]",
str,
None,
]
ColormapInput: TypeAlias = Union[Colormap, str, None]


class ColorbarHandler:
__slots__ = ("_draw_cbar", "_draw_minorticks", "_cmap", "_background_color")

Expand All @@ -403,14 +416,14 @@ def __init__(
*,
draw_cbar: bool = True,
draw_minorticks: bool = True,
cmap: Optional[Union[Colormap, str]] = None,
cmap: ColormapInput = None,
background_color: Optional[str] = None,
):
self._draw_cbar = draw_cbar
self._draw_minorticks = draw_minorticks
self._cmap: Optional[Colormap] = None
self.cmap = cmap
self._background_color = background_color
self._set_cmap(cmap)
self._background_color: BackgroundColor = background_color

@property
def draw_cbar(self) -> bool:
Expand Down Expand Up @@ -441,7 +454,13 @@ def cmap(self) -> Colormap:
return self._cmap or mpl.colormaps[ytcfg.get("yt", "default_colormap")]

@cmap.setter
def cmap(self, newval) -> None:
def cmap(self, newval: ColormapInput) -> None:
self._set_cmap(newval)

def _set_cmap(self, newval: ColormapInput) -> None:
# a separate setter function is better supported by type checkers (mypy)
# than relying purely on a property setter to narrow type
# from ColormapInput to Colormap
if isinstance(newval, Colormap) or newval is None:
self._cmap = newval
elif isinstance(newval, str):
Expand All @@ -456,11 +475,11 @@ def cmap(self, newval) -> None:
)

@property
def background_color(self) -> Any:
def background_color(self) -> BackgroundColor:
return self._background_color or "white"

@background_color.setter
def background_color(self, newval: Any):
def background_color(self, newval: BackgroundColor) -> None:
# not attempting to constrain types here because
# down the line it really depends on matplotlib.axes.Axes.set_faceolor
# which is very type-flexibile
Expand Down
39 changes: 29 additions & 10 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import warnings
from abc import ABC
from io import BytesIO
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, TypedDict, Union

import matplotlib
import numpy as np
from matplotlib.scale import SymmetricalLogTransform
from matplotlib.ticker import LogFormatterMathtext

from yt.funcs import (
Expand All @@ -28,9 +29,18 @@
from ._commons import _MPL38_SymmetricalLogLocator as SymmetricalLogLocator

if TYPE_CHECKING:
from typing import Literal

from matplotlib.axes import Axes
from matplotlib.axis import Axis
from matplotlib.figure import Figure

class FormatKwargs(TypedDict):
style: Literal["scientific"]
scilimits: Tuple[int, int]
useMathText: bool


BACKEND_SPECS = {
"GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"],
"GTKAgg": ["backend_gtkagg", "FigureCanvasGTKAgg", None],
Expand Down Expand Up @@ -96,11 +106,11 @@ class PlotMPL:
def __init__(
self,
fsize,
axrect,
axrect: Tuple[float, float, float, float],
*,
norm_handler: NormHandler,
figure: Optional["Figure"] = None,
axes: Optional["Axis"] = None,
axes: Optional["Axes"] = None,
):
"""Initialize PlotMPL class"""
import matplotlib.figure
Expand All @@ -116,7 +126,7 @@ def __init__(
if axes is None:
self._create_axes(axrect)
else:
axes.cla()
axes.clear()
axes.set_position(axrect)
self.axes = axes
self.interactivity = get_interactivity()
Expand All @@ -132,7 +142,7 @@ def __init__(

self.norm_handler = norm_handler

def _create_axes(self, axrect):
def _create_axes(self, axrect: Tuple[float, float, float, float]) -> None:
self.axes = self.figure.add_axes(axrect)

def _get_canvas_classes(self):
Expand Down Expand Up @@ -231,8 +241,8 @@ def __init__(
norm_handler: NormHandler,
colorbar_handler: ColorbarHandler,
figure: Optional["Figure"] = None,
axes: Optional["Axis"] = None,
cax: Optional["Axis"] = None,
axes: Optional["Axes"] = None,
cax: Optional["Axes"] = None,
):
"""Initialize ImagePlotMPL class object"""
self.colorbar_handler = colorbar_handler
Expand All @@ -252,7 +262,7 @@ def __init__(
if cax is None:
self.cax = self.figure.add_axes(caxrect)
else:
cax.cla()
cax.clear()
cax.set_position(caxrect)
self.cax = cax

Expand Down Expand Up @@ -316,13 +326,18 @@ def _init_image(self, data, extent, aspect):
self._set_axes()

def _set_axes(self) -> None:
fmt_kwargs = {"style": "scientific", "scilimits": (-2, 3), "useMathText": True}
fmt_kwargs: "FormatKwargs" = {
"style": "scientific",
"scilimits": (-2, 3),
"useMathText": True,
}
self.image.axes.ticklabel_format(**fmt_kwargs)
self.image.axes.set_facecolor(self.colorbar_handler.background_color)

self.cax.tick_params(which="both", direction="in")
self.cb = self.figure.colorbar(self.image, self.cax)

cb_axis: "Axis"
if self.cb.orientation == "vertical":
cb_axis = self.cb.ax.yaxis
else:
Expand All @@ -331,6 +346,8 @@ def _set_axes(self) -> None:
cb_scale = cb_axis.get_scale()
if cb_scale == "symlog":
trf = cb_axis.get_transform()
if not isinstance(trf, SymmetricalLogTransform):
raise RuntimeError
cb_axis.set_major_locator(SymmetricalLogLocator(trf))
cb_axis.set_major_formatter(
LogFormatterMathtext(linthresh=trf.linthresh, base=trf.base)
Expand All @@ -343,8 +360,10 @@ def _set_axes(self) -> None:
# no minor ticks are drawn by default in symlog, as of matplotlib 3.7.1
# see https://github.com/matplotlib/matplotlib/issues/25994
trf = cb_axis.get_transform()
if not isinstance(trf, SymmetricalLogTransform):
raise RuntimeError
if float(trf.base).is_integer():
locator = SymmetricalLogLocator(trf, subs=np.arange(1, trf.base))
locator = SymmetricalLogLocator(trf, subs=list(range(1, int(trf.base))))
cb_axis.set_minor_locator(locator)
elif self.colorbar_handler.draw_minorticks:
self.cb.minorticks_on()
Expand Down
2 changes: 1 addition & 1 deletion yt/visualization/volume_rendering/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def save(
fig = Figure((shape[0] / 100.0, shape[1] / 100.0))
canvas = get_canvas(fig, fname)

ax = fig.add_axes([0, 0, 1, 1])
ax = fig.add_axes((0, 0, 1, 1))
ax.set_axis_off()
out = self._last_render
if sigma_clip is not None:
Expand Down