Skip to content

Commit

Permalink
Merge pull request #5051 from neutrinoceros/rfc/match_geometry
Browse files Browse the repository at this point in the history
RFC: refactor geometry matching clauses with Python 3.10's pattern matching
  • Loading branch information
neutrinoceros authored Nov 14, 2024
2 parents 8a23215 + ad90a6b commit 9e483ce
Show file tree
Hide file tree
Showing 10 changed files with 355 additions and 359 deletions.
68 changes: 34 additions & 34 deletions yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,23 +609,22 @@ def to_frb(self, width, resolution, center=None, height=None, periodic=False):
>>> write_image(np.log10(frb["gas", "density"]), "density_100kpc.png")
"""

if (self.ds.geometry is Geometry.CYLINDRICAL and self.axis == 1) or (
self.ds.geometry is Geometry.POLAR and self.axis == 2
):
if center is not None and center != (0.0, 0.0):
raise NotImplementedError(
"Currently we only support images centered at R=0. "
+ "We plan to generalize this in the near future"
match (self.ds.geometry, self.axis):
case (Geometry.CYLINDRICAL, 1) | (Geometry.POLAR, 2):
if center is not None and center != (0.0, 0.0):
raise NotImplementedError(
"Currently we only support images centered at R=0. "
+ "We plan to generalize this in the near future"
)
from yt.visualization.fixed_resolution import (
CylindricalFixedResolutionBuffer,
)
from yt.visualization.fixed_resolution import (
CylindricalFixedResolutionBuffer,
)

validate_width_tuple(width)
if is_sequence(resolution):
resolution = max(resolution)
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
return frb
validate_width_tuple(width)
if is_sequence(resolution):
resolution = max(resolution)
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
return frb

if center is None:
center = self.center
Expand Down Expand Up @@ -1401,25 +1400,26 @@ def get_bbox(self) -> tuple[unyt_array, unyt_array]:
"""
Return the bounding box for this data container.
"""
geometry: Geometry = self.ds.geometry
if geometry is Geometry.CARTESIAN:
le, re = self._get_bbox()
le.convert_to_units("code_length")
re.convert_to_units("code_length")
return le, re
elif (
geometry is Geometry.CYLINDRICAL
or geometry is Geometry.POLAR
or geometry is Geometry.SPHERICAL
or geometry is Geometry.GEOGRAPHIC
or geometry is Geometry.INTERNAL_GEOGRAPHIC
or geometry is Geometry.SPECTRAL_CUBE
):
raise NotImplementedError(
f"get_bbox is currently not implemented for {geometry=}!"
)
else:
assert_never(geometry)
match self.ds.geometry:
case Geometry.CARTESIAN:
le, re = self._get_bbox()
le.convert_to_units("code_length")
re.convert_to_units("code_length")
return le, re
case (
Geometry.CYLINDRICAL
| Geometry.POLAR
| Geometry.SPHERICAL
| Geometry.GEOGRAPHIC
| Geometry.INTERNAL_GEOGRAPHIC
| Geometry.SPECTRAL_CUBE
):
geometry = self.ds.geometry
raise NotImplementedError(
f"get_bbox is currently not implemented for {geometry=}!"
)
case _:
assert_never(self.ds.geometry)

def volume(self):
"""
Expand Down
51 changes: 26 additions & 25 deletions yt/data_objects/static_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,28 +798,29 @@ def _setup_coordinate_handler(self, axis_order: AxisOrder | None) -> None:
f"Got {self.geometry=} with type {type(self.geometry)}"
)

if self.geometry is Geometry.CARTESIAN:
cls = CartesianCoordinateHandler
elif self.geometry is Geometry.CYLINDRICAL:
cls = CylindricalCoordinateHandler
elif self.geometry is Geometry.POLAR:
cls = PolarCoordinateHandler
elif self.geometry is Geometry.SPHERICAL:
cls = SphericalCoordinateHandler
# It shouldn't be required to reset self.no_cgs_equiv_length
# to the default value (False) here, but it's still necessary
# see https://github.com/yt-project/yt/pull/3618
self.no_cgs_equiv_length = False
elif self.geometry is Geometry.GEOGRAPHIC:
cls = GeographicCoordinateHandler
self.no_cgs_equiv_length = True
elif self.geometry is Geometry.INTERNAL_GEOGRAPHIC:
cls = InternalGeographicCoordinateHandler
self.no_cgs_equiv_length = True
elif self.geometry is Geometry.SPECTRAL_CUBE:
cls = SpectralCubeCoordinateHandler
else:
assert_never(self.geometry)
match self.geometry:
case Geometry.CARTESIAN:
cls = CartesianCoordinateHandler
case Geometry.CYLINDRICAL:
cls = CylindricalCoordinateHandler
case Geometry.POLAR:
cls = PolarCoordinateHandler
case Geometry.SPHERICAL:
cls = SphericalCoordinateHandler
# It shouldn't be required to reset self.no_cgs_equiv_length
# to the default value (False) here, but it's still necessary
# see https://github.com/yt-project/yt/pull/3618
self.no_cgs_equiv_length = False
case Geometry.GEOGRAPHIC:
cls = GeographicCoordinateHandler
self.no_cgs_equiv_length = True
case Geometry.INTERNAL_GEOGRAPHIC:
cls = InternalGeographicCoordinateHandler
self.no_cgs_equiv_length = True
case Geometry.SPECTRAL_CUBE:
cls = SpectralCubeCoordinateHandler
case _:
assert_never(self.geometry)

self.coordinates = cls(self, ordering=axis_order)

Expand Down Expand Up @@ -1948,9 +1949,9 @@ def add_gradient_fields(self, fields=None):
... ("gas", "density_gradient_magnitude"),
... ]
Note that the above example assumes ds.geometry == 'cartesian'. In general,
the function will create gradient components along the axes of the dataset
coordinate system.
Note that the above example assumes ds.geometry is Geometry.CARTESIAN.
In general, the function will create gradient components along the axes
of the dataset coordinate system.
For instance, with cylindrical data, one gets 'density_gradient_<r,theta,z>'
"""
Expand Down
101 changes: 47 additions & 54 deletions yt/fields/field_info_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,29 +222,26 @@ def get_aliases_gallery(self) -> list[FieldName]:
if self.ds is None:
return aliases_gallery

geometry: Geometry = self.ds.geometry
if (
geometry is Geometry.POLAR
or geometry is Geometry.CYLINDRICAL
or geometry is Geometry.SPHERICAL
):
aliases: list[FieldName]
for field in sorted(self.field_list):
if field[0] in self.ds.particle_types:
continue
args = known_other_fields.get(field[1], ("", [], None))
units, aliases, display_name = args
aliases_gallery.extend(aliases)
elif (
geometry is Geometry.CARTESIAN
or geometry is Geometry.GEOGRAPHIC
or geometry is Geometry.INTERNAL_GEOGRAPHIC
or geometry is Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
else:
assert_never(geometry)
match self.ds.geometry:
case Geometry.POLAR | Geometry.CYLINDRICAL | Geometry.SPHERICAL:
aliases: list[FieldName]
for field in sorted(self.field_list):
if field[0] in self.ds.particle_types:
continue
args = known_other_fields.get(field[1], ("", [], None))
units, aliases, display_name = args
aliases_gallery.extend(aliases)
case (
Geometry.CARTESIAN
| Geometry.GEOGRAPHIC
| Geometry.INTERNAL_GEOGRAPHIC
| Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
case _:
assert_never(self.ds.geometry)

return aliases_gallery

def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
Expand Down Expand Up @@ -280,38 +277,34 @@ def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
field, sampling_type="cell", units=units, display_name=display_name
)
axis_names = self.ds.coordinates.axis_order
geometry: Geometry = self.ds.geometry
for alias in aliases:
if (
geometry is Geometry.POLAR
or geometry is Geometry.CYLINDRICAL
or geometry is Geometry.SPHERICAL
):
if alias[-2:] not in ["_x", "_y", "_z"]:
to_convert = False
else:
for suffix in ["x", "y", "z"]:
if f"{alias[:-2]}_{suffix}" not in aliases_gallery:
to_convert = False
break
to_convert = True
if to_convert:
if alias[-2:] == "_x":
alias = f"{alias[:-2]}_{axis_names[0]}"
elif alias[-2:] == "_y":
alias = f"{alias[:-2]}_{axis_names[1]}"
elif alias[-2:] == "_z":
alias = f"{alias[:-2]}_{axis_names[2]}"
elif (
geometry is Geometry.CARTESIAN
or geometry is Geometry.GEOGRAPHIC
or geometry is Geometry.INTERNAL_GEOGRAPHIC
or geometry is Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
else:
assert_never(geometry)
match self.ds.geometry:
case Geometry.POLAR | Geometry.CYLINDRICAL | Geometry.SPHERICAL:
if alias[-2:] not in ["_x", "_y", "_z"]:
to_convert = False
else:
for suffix in ["x", "y", "z"]:
if f"{alias[:-2]}_{suffix}" not in aliases_gallery:
to_convert = False
break
to_convert = True
if to_convert:
if alias[-2:] == "_x":
alias = f"{alias[:-2]}_{axis_names[0]}"
elif alias[-2:] == "_y":
alias = f"{alias[:-2]}_{axis_names[1]}"
elif alias[-2:] == "_z":
alias = f"{alias[:-2]}_{axis_names[2]}"
case (
Geometry.CARTESIAN
| Geometry.GEOGRAPHIC
| Geometry.INTERNAL_GEOGRAPHIC
| Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
case _:
assert_never(self.ds.geometry)
self.alias((ftype, alias), field)

@staticmethod
Expand Down
Loading

0 comments on commit 9e483ce

Please sign in to comment.