Skip to content

Commit 9e483ce

Browse files
Merge pull request #5051 from neutrinoceros/rfc/match_geometry
RFC: refactor geometry matching clauses with Python 3.10's pattern matching
2 parents 8a23215 + ad90a6b commit 9e483ce

File tree

10 files changed

+355
-359
lines changed

10 files changed

+355
-359
lines changed

yt/data_objects/selection_objects/data_selection_objects.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -609,23 +609,22 @@ def to_frb(self, width, resolution, center=None, height=None, periodic=False):
609609
>>> write_image(np.log10(frb["gas", "density"]), "density_100kpc.png")
610610
"""
611611

612-
if (self.ds.geometry is Geometry.CYLINDRICAL and self.axis == 1) or (
613-
self.ds.geometry is Geometry.POLAR and self.axis == 2
614-
):
615-
if center is not None and center != (0.0, 0.0):
616-
raise NotImplementedError(
617-
"Currently we only support images centered at R=0. "
618-
+ "We plan to generalize this in the near future"
612+
match (self.ds.geometry, self.axis):
613+
case (Geometry.CYLINDRICAL, 1) | (Geometry.POLAR, 2):
614+
if center is not None and center != (0.0, 0.0):
615+
raise NotImplementedError(
616+
"Currently we only support images centered at R=0. "
617+
+ "We plan to generalize this in the near future"
618+
)
619+
from yt.visualization.fixed_resolution import (
620+
CylindricalFixedResolutionBuffer,
619621
)
620-
from yt.visualization.fixed_resolution import (
621-
CylindricalFixedResolutionBuffer,
622-
)
623622

624-
validate_width_tuple(width)
625-
if is_sequence(resolution):
626-
resolution = max(resolution)
627-
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
628-
return frb
623+
validate_width_tuple(width)
624+
if is_sequence(resolution):
625+
resolution = max(resolution)
626+
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
627+
return frb
629628

630629
if center is None:
631630
center = self.center
@@ -1401,25 +1400,26 @@ def get_bbox(self) -> tuple[unyt_array, unyt_array]:
14011400
"""
14021401
Return the bounding box for this data container.
14031402
"""
1404-
geometry: Geometry = self.ds.geometry
1405-
if geometry is Geometry.CARTESIAN:
1406-
le, re = self._get_bbox()
1407-
le.convert_to_units("code_length")
1408-
re.convert_to_units("code_length")
1409-
return le, re
1410-
elif (
1411-
geometry is Geometry.CYLINDRICAL
1412-
or geometry is Geometry.POLAR
1413-
or geometry is Geometry.SPHERICAL
1414-
or geometry is Geometry.GEOGRAPHIC
1415-
or geometry is Geometry.INTERNAL_GEOGRAPHIC
1416-
or geometry is Geometry.SPECTRAL_CUBE
1417-
):
1418-
raise NotImplementedError(
1419-
f"get_bbox is currently not implemented for {geometry=}!"
1420-
)
1421-
else:
1422-
assert_never(geometry)
1403+
match self.ds.geometry:
1404+
case Geometry.CARTESIAN:
1405+
le, re = self._get_bbox()
1406+
le.convert_to_units("code_length")
1407+
re.convert_to_units("code_length")
1408+
return le, re
1409+
case (
1410+
Geometry.CYLINDRICAL
1411+
| Geometry.POLAR
1412+
| Geometry.SPHERICAL
1413+
| Geometry.GEOGRAPHIC
1414+
| Geometry.INTERNAL_GEOGRAPHIC
1415+
| Geometry.SPECTRAL_CUBE
1416+
):
1417+
geometry = self.ds.geometry
1418+
raise NotImplementedError(
1419+
f"get_bbox is currently not implemented for {geometry=}!"
1420+
)
1421+
case _:
1422+
assert_never(self.ds.geometry)
14231423

14241424
def volume(self):
14251425
"""

yt/data_objects/static_output.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -798,28 +798,29 @@ def _setup_coordinate_handler(self, axis_order: AxisOrder | None) -> None:
798798
f"Got {self.geometry=} with type {type(self.geometry)}"
799799
)
800800

801-
if self.geometry is Geometry.CARTESIAN:
802-
cls = CartesianCoordinateHandler
803-
elif self.geometry is Geometry.CYLINDRICAL:
804-
cls = CylindricalCoordinateHandler
805-
elif self.geometry is Geometry.POLAR:
806-
cls = PolarCoordinateHandler
807-
elif self.geometry is Geometry.SPHERICAL:
808-
cls = SphericalCoordinateHandler
809-
# It shouldn't be required to reset self.no_cgs_equiv_length
810-
# to the default value (False) here, but it's still necessary
811-
# see https://github.com/yt-project/yt/pull/3618
812-
self.no_cgs_equiv_length = False
813-
elif self.geometry is Geometry.GEOGRAPHIC:
814-
cls = GeographicCoordinateHandler
815-
self.no_cgs_equiv_length = True
816-
elif self.geometry is Geometry.INTERNAL_GEOGRAPHIC:
817-
cls = InternalGeographicCoordinateHandler
818-
self.no_cgs_equiv_length = True
819-
elif self.geometry is Geometry.SPECTRAL_CUBE:
820-
cls = SpectralCubeCoordinateHandler
821-
else:
822-
assert_never(self.geometry)
801+
match self.geometry:
802+
case Geometry.CARTESIAN:
803+
cls = CartesianCoordinateHandler
804+
case Geometry.CYLINDRICAL:
805+
cls = CylindricalCoordinateHandler
806+
case Geometry.POLAR:
807+
cls = PolarCoordinateHandler
808+
case Geometry.SPHERICAL:
809+
cls = SphericalCoordinateHandler
810+
# It shouldn't be required to reset self.no_cgs_equiv_length
811+
# to the default value (False) here, but it's still necessary
812+
# see https://github.com/yt-project/yt/pull/3618
813+
self.no_cgs_equiv_length = False
814+
case Geometry.GEOGRAPHIC:
815+
cls = GeographicCoordinateHandler
816+
self.no_cgs_equiv_length = True
817+
case Geometry.INTERNAL_GEOGRAPHIC:
818+
cls = InternalGeographicCoordinateHandler
819+
self.no_cgs_equiv_length = True
820+
case Geometry.SPECTRAL_CUBE:
821+
cls = SpectralCubeCoordinateHandler
822+
case _:
823+
assert_never(self.geometry)
823824

824825
self.coordinates = cls(self, ordering=axis_order)
825826

@@ -1948,9 +1949,9 @@ def add_gradient_fields(self, fields=None):
19481949
... ("gas", "density_gradient_magnitude"),
19491950
... ]
19501951
1951-
Note that the above example assumes ds.geometry == 'cartesian'. In general,
1952-
the function will create gradient components along the axes of the dataset
1953-
coordinate system.
1952+
Note that the above example assumes ds.geometry is Geometry.CARTESIAN.
1953+
In general, the function will create gradient components along the axes
1954+
of the dataset coordinate system.
19541955
For instance, with cylindrical data, one gets 'density_gradient_<r,theta,z>'
19551956
19561957
"""

yt/fields/field_info_container.py

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -222,29 +222,26 @@ def get_aliases_gallery(self) -> list[FieldName]:
222222
if self.ds is None:
223223
return aliases_gallery
224224

225-
geometry: Geometry = self.ds.geometry
226-
if (
227-
geometry is Geometry.POLAR
228-
or geometry is Geometry.CYLINDRICAL
229-
or geometry is Geometry.SPHERICAL
230-
):
231-
aliases: list[FieldName]
232-
for field in sorted(self.field_list):
233-
if field[0] in self.ds.particle_types:
234-
continue
235-
args = known_other_fields.get(field[1], ("", [], None))
236-
units, aliases, display_name = args
237-
aliases_gallery.extend(aliases)
238-
elif (
239-
geometry is Geometry.CARTESIAN
240-
or geometry is Geometry.GEOGRAPHIC
241-
or geometry is Geometry.INTERNAL_GEOGRAPHIC
242-
or geometry is Geometry.SPECTRAL_CUBE
243-
):
244-
# nothing to do
245-
pass
246-
else:
247-
assert_never(geometry)
225+
match self.ds.geometry:
226+
case Geometry.POLAR | Geometry.CYLINDRICAL | Geometry.SPHERICAL:
227+
aliases: list[FieldName]
228+
for field in sorted(self.field_list):
229+
if field[0] in self.ds.particle_types:
230+
continue
231+
args = known_other_fields.get(field[1], ("", [], None))
232+
units, aliases, display_name = args
233+
aliases_gallery.extend(aliases)
234+
case (
235+
Geometry.CARTESIAN
236+
| Geometry.GEOGRAPHIC
237+
| Geometry.INTERNAL_GEOGRAPHIC
238+
| Geometry.SPECTRAL_CUBE
239+
):
240+
# nothing to do
241+
pass
242+
case _:
243+
assert_never(self.ds.geometry)
244+
248245
return aliases_gallery
249246

250247
def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
@@ -280,38 +277,34 @@ def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
280277
field, sampling_type="cell", units=units, display_name=display_name
281278
)
282279
axis_names = self.ds.coordinates.axis_order
283-
geometry: Geometry = self.ds.geometry
284280
for alias in aliases:
285-
if (
286-
geometry is Geometry.POLAR
287-
or geometry is Geometry.CYLINDRICAL
288-
or geometry is Geometry.SPHERICAL
289-
):
290-
if alias[-2:] not in ["_x", "_y", "_z"]:
291-
to_convert = False
292-
else:
293-
for suffix in ["x", "y", "z"]:
294-
if f"{alias[:-2]}_{suffix}" not in aliases_gallery:
295-
to_convert = False
296-
break
297-
to_convert = True
298-
if to_convert:
299-
if alias[-2:] == "_x":
300-
alias = f"{alias[:-2]}_{axis_names[0]}"
301-
elif alias[-2:] == "_y":
302-
alias = f"{alias[:-2]}_{axis_names[1]}"
303-
elif alias[-2:] == "_z":
304-
alias = f"{alias[:-2]}_{axis_names[2]}"
305-
elif (
306-
geometry is Geometry.CARTESIAN
307-
or geometry is Geometry.GEOGRAPHIC
308-
or geometry is Geometry.INTERNAL_GEOGRAPHIC
309-
or geometry is Geometry.SPECTRAL_CUBE
310-
):
311-
# nothing to do
312-
pass
313-
else:
314-
assert_never(geometry)
281+
match self.ds.geometry:
282+
case Geometry.POLAR | Geometry.CYLINDRICAL | Geometry.SPHERICAL:
283+
if alias[-2:] not in ["_x", "_y", "_z"]:
284+
to_convert = False
285+
else:
286+
for suffix in ["x", "y", "z"]:
287+
if f"{alias[:-2]}_{suffix}" not in aliases_gallery:
288+
to_convert = False
289+
break
290+
to_convert = True
291+
if to_convert:
292+
if alias[-2:] == "_x":
293+
alias = f"{alias[:-2]}_{axis_names[0]}"
294+
elif alias[-2:] == "_y":
295+
alias = f"{alias[:-2]}_{axis_names[1]}"
296+
elif alias[-2:] == "_z":
297+
alias = f"{alias[:-2]}_{axis_names[2]}"
298+
case (
299+
Geometry.CARTESIAN
300+
| Geometry.GEOGRAPHIC
301+
| Geometry.INTERNAL_GEOGRAPHIC
302+
| Geometry.SPECTRAL_CUBE
303+
):
304+
# nothing to do
305+
pass
306+
case _:
307+
assert_never(self.ds.geometry)
315308
self.alias((ftype, alias), field)
316309

317310
@staticmethod

0 commit comments

Comments
 (0)