Skip to content

Commit

Permalink
TYP: fix some type-check incompatibility with matplotlib 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Aug 11, 2023
1 parent e8fc609 commit 715f441
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
41 changes: 27 additions & 14 deletions yt/visualization/_handlers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import weakref
from numbers import Real
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
Expand All @@ -12,6 +13,11 @@
from yt.config import ytcfg
from yt.funcs import get_brewer_cmap, is_sequence, mylog

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


class NormHandler:
"""
Expand Down Expand Up @@ -73,7 +79,7 @@ def __init__(
self._linthresh = linthresh
self.prefer_log = True

if self.has_norm and self.has_constraints:
if self.norm is not None and self.has_constraints:
raise TypeError(
"NormHandler input is malformed. "
"A norm cannot be passed along other constraints."
Expand All @@ -100,12 +106,8 @@ def _reset_constraints(self) -> None:
for name in constraints.keys():
setattr(self, name, None)

@property
def has_norm(self) -> bool:
return self._norm is not None

def _reset_norm(self):
if not self.has_norm:
def _reset_norm(self) -> None:
if self.norm is None:
return
mylog.warning("Dropping norm (%s)", self.norm)
self._norm = None
Expand Down Expand Up @@ -281,7 +283,7 @@ def linthresh(self, newval: Optional[Union[Quantity, float]]) -> None:
self.norm_type = SymLogNorm

def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize:
if self.has_norm:
if self.norm is not None:
return self.norm

dvmin = dvmax = None
Expand Down Expand Up @@ -312,6 +314,7 @@ def get_norm(self, data: np.ndarray, *args, **kw) -> Normalize:
dvmax = 1 * getattr(data, "units", 1)
kw.setdefault("vmax", dvmax)

norm_type: Type[Normalize]
if data.ndim == 3:
assert data.shape[-1] == 4
# this is an RGBA array, only linear normalization makes sense here
Expand Down Expand Up @@ -395,6 +398,16 @@ def get_minmax(data):
return linthresh


BackgroundColor: TypeAlias = Union[
Tuple[float, float, float, float],
# np.ndarray is only runtime-subscribtable since numpy 1.22
"np.ndarray[Any, Any]",
str,
None,
]
ColormapInput: TypeAlias = Union[Colormap, str, None]


class ColorbarHandler:
__slots__ = ("_draw_cbar", "_draw_minorticks", "_cmap", "_background_color")

Expand All @@ -403,14 +416,14 @@ def __init__(
*,
draw_cbar: bool = True,
draw_minorticks: bool = True,
cmap: Optional[Union[Colormap, str]] = None,
cmap: ColormapInput = None,
background_color: Optional[str] = None,
):
self._draw_cbar = draw_cbar
self._draw_minorticks = draw_minorticks
self._cmap: Optional[Colormap] = None
self.cmap = cmap
self._background_color = background_color
self._background_color: BackgroundColor = background_color

@property
def draw_cbar(self) -> bool:
Expand Down Expand Up @@ -441,12 +454,12 @@ def cmap(self) -> Colormap:
return self._cmap or mpl.colormaps[ytcfg.get("yt", "default_colormap")]

@cmap.setter
def cmap(self, newval) -> None:
def cmap(self, newval: ColormapInput) -> None:
if isinstance(newval, Colormap) or newval is None:
self._cmap = newval
elif isinstance(newval, str):
self._cmap = mpl.colormaps[newval]
elif is_sequence(newval):
elif is_sequence(newval): # type: ignore [unreachable]
# tuple colormaps are from palettable (or brewer2mpl)
self._cmap = get_brewer_cmap(newval)
else:
Expand All @@ -456,11 +469,11 @@ def cmap(self, newval) -> None:
)

@property
def background_color(self) -> Any:
def background_color(self) -> BackgroundColor:
return self._background_color or "white"

@background_color.setter
def background_color(self, newval: Any):
def background_color(self, newval: BackgroundColor) -> None:
# not attempting to constrain types here because
# down the line it really depends on matplotlib.axes.Axes.set_faceolor
# which is very type-flexibile
Expand Down
17 changes: 12 additions & 5 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import matplotlib
import numpy as np
from matplotlib.scale import SymmetricalLogTransform
from matplotlib.ticker import LogFormatterMathtext

from yt.funcs import (
Expand All @@ -28,6 +29,7 @@
from ._commons import _MPL38_SymmetricalLogLocator as SymmetricalLogLocator

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.axis import Axis
from matplotlib.figure import Figure

Expand Down Expand Up @@ -96,11 +98,11 @@ class PlotMPL:
def __init__(
self,
fsize,
axrect,
axrect: Tuple[float, float, float, float],
*,
norm_handler: NormHandler,
figure: Optional["Figure"] = None,
axes: Optional["Axis"] = None,
axes: Optional["Axes"] = None,
):
"""Initialize PlotMPL class"""
import matplotlib.figure
Expand Down Expand Up @@ -132,7 +134,7 @@ def __init__(

self.norm_handler = norm_handler

def _create_axes(self, axrect):
def _create_axes(self, axrect: Tuple[float, float, float, float]) -> None:
self.axes = self.figure.add_axes(axrect)

def _get_canvas_classes(self):
Expand Down Expand Up @@ -231,8 +233,8 @@ def __init__(
norm_handler: NormHandler,
colorbar_handler: ColorbarHandler,
figure: Optional["Figure"] = None,
axes: Optional["Axis"] = None,
cax: Optional["Axis"] = None,
axes: Optional["Axes"] = None,
cax: Optional["Axes"] = None,
):
"""Initialize ImagePlotMPL class object"""
self.colorbar_handler = colorbar_handler
Expand Down Expand Up @@ -323,6 +325,7 @@ def _set_axes(self) -> None:
self.cax.tick_params(which="both", direction="in")
self.cb = self.figure.colorbar(self.image, self.cax)

cb_axis: "Axis"
if self.cb.orientation == "vertical":
cb_axis = self.cb.ax.yaxis
else:
Expand All @@ -331,6 +334,8 @@ def _set_axes(self) -> None:
cb_scale = cb_axis.get_scale()
if cb_scale == "symlog":
trf = cb_axis.get_transform()
if not isinstance(trf, SymmetricalLogTransform):
raise RuntimeError
cb_axis.set_major_locator(SymmetricalLogLocator(trf))
cb_axis.set_major_formatter(
LogFormatterMathtext(linthresh=trf.linthresh, base=trf.base)
Expand All @@ -343,6 +348,8 @@ def _set_axes(self) -> None:
# no minor ticks are drawn by default in symlog, as of matplotlib 3.7.1
# see https://github.com/matplotlib/matplotlib/issues/25994
trf = cb_axis.get_transform()
if not isinstance(trf, SymmetricalLogTransform):
raise RuntimeError
if float(trf.base).is_integer():
locator = SymmetricalLogLocator(trf, subs=np.arange(1, trf.base))
cb_axis.set_minor_locator(locator)
Expand Down

0 comments on commit 715f441

Please sign in to comment.