Skip to content

Commit 691999b

Browse files
authored
Selection for GeoDataFrame engine in plotting routines (#987)
* add engine, fix project * add test case and comments * initial updates to poly and lc * fix comment and update tests * update to_geodataframe docstring * update edge plot * add default clabel for edge plot * update docstrings in geometry functions * update to_geodataframe docstring to warn about split polygon projections * remove unused parameter * update call after removed unused argument * remove commented out bit
1 parent e491f57 commit 691999b

File tree

5 files changed

+83
-56
lines changed

5 files changed

+83
-56
lines changed

Diff for: test/test_plot.py

+22-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
import uxarray as ux
3+
import holoviews as hv
4+
35

46
from unittest import TestCase
57
from pathlib import Path
@@ -44,42 +46,33 @@ def test_face_centered_data(self):
4446
uxds = ux.open_dataset(gridfile_mpas, gridfile_mpas)
4547

4648
for backend in ['matplotlib', 'bokeh']:
47-
48-
uxds['bottomDepth'].plot(backend=backend)
49-
50-
uxds['bottomDepth'].plot.polygons(backend=backend)
51-
52-
uxds['bottomDepth'].plot.points(backend=backend)
53-
54-
uxds['bottomDepth'].plot.rasterize(method='polygon',
55-
backend=backend)
49+
assert(isinstance(uxds['bottomDepth'].plot(backend=backend), hv.DynamicMap))
50+
assert(isinstance(uxds['bottomDepth'].plot.polygons(backend=backend), hv.DynamicMap))
51+
assert(isinstance(uxds['bottomDepth'].plot.points(backend=backend), hv.Points))
5652

5753
def test_face_centered_remapped_dim(self):
5854
"""Tests execution of plotting method on a data variable whose
5955
dimension needed to be re-mapped."""
6056
uxds = ux.open_dataset(gridfile_ne30, datafile_ne30)
6157

6258
for backend in ['matplotlib', 'bokeh']:
59+
assert(isinstance(uxds['psi'].plot(backend=backend), hv.DynamicMap))
60+
assert(isinstance(uxds['psi'].plot.polygons(backend=backend), hv.DynamicMap))
61+
assert(isinstance(uxds['psi'].plot.points(backend=backend), hv.Points))
6362

64-
uxds['psi'].plot(backend=backend)
65-
66-
uxds['psi'].plot.polygons(backend=backend)
67-
68-
uxds['psi'].plot.points(backend=backend)
69-
70-
uxds['psi'].plot.rasterize(method='polygon', backend=backend)
7163

7264
def test_node_centered_data(self):
7365
"""Tests execution of plotting methods on node-centered data."""
7466

7567
uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow)
7668

7769
for backend in ['matplotlib', 'bokeh']:
78-
uxds['v1'][0][0].plot(backend=backend)
70+
assert(isinstance(uxds['v1'][0][0].plot(backend=backend), hv.Points))
7971

80-
uxds['v1'][0][0].plot.points(backend=backend)
72+
assert(isinstance(uxds['v1'][0][0].plot.points(backend=backend), hv.Points))
73+
74+
assert(isinstance(uxds['v1'][0][0].topological_mean(destination='face').plot.polygons(backend=backend), hv.DynamicMap))
8175

82-
uxds['v1'][0][0].topological_mean(destination='face').plot.polygons(backend=backend)
8376

8477

8578
def test_clabel(self):
@@ -88,9 +81,18 @@ def test_clabel(self):
8881
uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow)
8982

9083
raster_no_clabel = uxds['v1'][0][0].plot.rasterize(method='point')
91-
9284
raster_with_clabel = uxds['v1'][0][0].plot.rasterize(method='point', clabel='Foo')
9385

86+
def test_engine(self):
87+
uxds = ux.open_dataset(gridfile_mpas, gridfile_mpas)
88+
_plot_sp = uxds['bottomDepth'].plot.polygons(rasterize=True, engine='spatialpandas')
89+
_plot_gp = uxds['bottomDepth'].plot.polygons(rasterize=True, engine='geopandas')
90+
91+
assert isinstance(_plot_sp, hv.DynamicMap)
92+
assert isinstance(_plot_gp, hv.DynamicMap)
93+
94+
95+
9496
class TestXarrayMethods(TestCase):
9597

9698
def test_dataset(self):

Diff for: uxarray/core/dataarray.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,11 @@ def to_geodataframe(
154154
self,
155155
periodic_elements: Optional[str] = "exclude",
156156
projection: Optional[ccrs.Projection] = None,
157-
project: Optional[bool] = False,
158157
cache: Optional[bool] = True,
159158
override: Optional[bool] = False,
160159
engine: Optional[str] = "spatialpandas",
161160
exclude_antimeridian: Optional[bool] = None,
161+
**kwargs,
162162
):
163163
"""Constructs a ``GeoDataFrame`` consisting of polygons representing
164164
the faces of the current ``Grid`` with a face-centered data variable
@@ -178,7 +178,8 @@ def to_geodataframe(
178178
- 'split': Periodic elements will be identified and split using the ``antimeridian`` package
179179
- 'ignore': No processing will be applied to periodic elements.
180180
projection: ccrs.Projection, optional
181-
Geographic projection used to transform polygons
181+
Geographic projection used to transform polygons. Only supported when periodic_elements is set to
182+
'ignore' or 'exclude'
182183
cache: bool, optional
183184
Flag used to select whether to cache the computed GeoDataFrame
184185
override: bool, optional
@@ -191,7 +192,7 @@ def to_geodataframe(
191192
192193
Returns
193194
-------
194-
gdf : spatialpandas.GeoDataFrame
195+
gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame
195196
The output ``GeoDataFrame`` with a filled out "geometry" column of polygons and a data column with the
196197
same name as the ``UxDataArray`` (or named ``var`` if no name exists)
197198
"""
@@ -207,7 +208,7 @@ def to_geodataframe(
207208
gdf, non_nan_polygon_indices = self.uxgrid.to_geodataframe(
208209
periodic_elements=periodic_elements,
209210
projection=projection,
210-
project=project,
211+
project=kwargs.get("project", True),
211212
cache=cache,
212213
override=override,
213214
exclude_antimeridian=exclude_antimeridian,

Diff for: uxarray/grid/geometry.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def _correct_central_longitude(node_lon, node_lat, projection):
162162

163163
def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project, engine):
164164
"""Converts the faces of a ``Grid`` into a ``spatialpandas.GeoDataFrame``
165-
with a geometry column of polygons."""
165+
or ``geopandas.GeoDataFrame`` with a geometry column of polygons."""
166166

167167
node_lon, node_lat, central_longitude = _correct_central_longitude(
168168
grid.node_lon.values, grid.node_lat.values, projection
@@ -214,9 +214,8 @@ def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project,
214214
gdf = _build_geodataframe_with_antimeridian(
215215
polygon_shells,
216216
projected_polygon_shells,
217-
projection,
218217
antimeridian_face_indices,
219-
engine=geopandas,
218+
engine=engine,
220219
)
221220
elif periodic_elements == "ignore":
222221
if engine == "geopandas":
@@ -248,8 +247,9 @@ def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project,
248247
def _build_geodataframe_without_antimeridian(
249248
polygon_shells, projected_polygon_shells, antimeridian_face_indices, engine
250249
):
251-
"""Builds a ``spatialpandas.GeoDataFrame`` excluding any faces that cross
252-
the antimeridian."""
250+
"""Builds a ``spatialpandas.GeoDataFrame`` or
251+
``geopandas.GeoDataFrame``excluding any faces that cross the
252+
antimeridian."""
253253
if projected_polygon_shells is not None:
254254
# use projected shells if a projection is applied
255255
shells_without_antimeridian = np.delete(
@@ -276,12 +276,11 @@ def _build_geodataframe_without_antimeridian(
276276
def _build_geodataframe_with_antimeridian(
277277
polygon_shells,
278278
projected_polygon_shells,
279-
projection,
280279
antimeridian_face_indices,
281280
engine,
282281
):
283-
"""Builds a ``spatialpandas.GeoDataFrame`` including any faces that cross
284-
the antimeridian."""
282+
"""Builds a ``spatialpandas.GeoDataFrame`` or ``geopandas.GeoDataFrame``
283+
including any faces that cross the antimeridian."""
285284
polygons = _build_corrected_shapely_polygons(
286285
polygon_shells, projected_polygon_shells, antimeridian_face_indices
287286
)
@@ -425,7 +424,8 @@ def _grid_to_matplotlib_polycollection(
425424
# Handle unsupported configuration: splitting periodic elements with projection
426425
if periodic_elements == "split" and projection is not None:
427426
raise ValueError(
428-
"Projections are not supported when splitting periodic elements.'"
427+
"Explicitly projecting lines is not supported. Please pass in your projection"
428+
"using the 'transform' parameter"
429429
)
430430

431431
# Correct the central longitude and build polygon shells
@@ -533,7 +533,7 @@ def _grid_to_matplotlib_polycollection(
533533
return PolyCollection(polygon_shells, **kwargs), []
534534

535535

536-
def _get_polygons(grid, periodic_elements, projection=None):
536+
def _get_polygons(grid, periodic_elements, projection=None, apply_projection=True):
537537
# Correct the central longitude if projection is provided
538538
node_lon, node_lat, central_longitude = _correct_central_longitude(
539539
grid.node_lon.values, grid.node_lat.values, projection
@@ -552,7 +552,7 @@ def _get_polygons(grid, periodic_elements, projection=None):
552552
)
553553

554554
# If projection is provided, create the projected polygon shells
555-
if projection:
555+
if projection and apply_projection:
556556
projected_polygon_shells = _build_polygon_shells(
557557
node_lon,
558558
node_lat,
@@ -625,8 +625,14 @@ def _grid_to_matplotlib_linecollection(
625625
):
626626
"""Constructs and returns a ``matplotlib.collections.LineCollection``"""
627627

628+
if periodic_elements == "split" and projection is not None:
629+
apply_projection = False
630+
else:
631+
apply_projection = True
632+
633+
# do not explicitly project when splitting elements
628634
polygons, central_longitude, _, _ = _get_polygons(
629-
grid, periodic_elements, projection
635+
grid, periodic_elements, projection, apply_projection
630636
)
631637

632638
# Convert polygons to line segments for the LineCollection
@@ -639,14 +645,13 @@ def _grid_to_matplotlib_linecollection(
639645
else:
640646
lines.append(np.array(boundary.coords))
641647

642-
# Set default transform if not provided
643648
if "transform" not in kwargs:
644-
if projection is None:
649+
# Set default transform if one is not provided not provided
650+
if projection is None or not apply_projection:
645651
kwargs["transform"] = ccrs.PlateCarree(central_longitude=central_longitude)
646652
else:
647653
kwargs["transform"] = projection
648654

649-
# Return a LineCollection of the line segments
650655
return LineCollection(lines, **kwargs)
651656

652657

Diff for: uxarray/grid/grid.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -1635,13 +1635,13 @@ def to_geodataframe(
16351635
self,
16361636
periodic_elements: Optional[str] = "exclude",
16371637
projection: Optional[ccrs.Projection] = None,
1638-
project: Optional[bool] = False,
16391638
cache: Optional[bool] = True,
16401639
override: Optional[bool] = False,
16411640
engine: Optional[str] = "spatialpandas",
16421641
exclude_antimeridian: Optional[bool] = None,
16431642
return_non_nan_polygon_indices: Optional[bool] = False,
16441643
exclude_nan_polygons: Optional[bool] = True,
1644+
**kwargs,
16451645
):
16461646
"""Constructs a ``GeoDataFrame`` consisting of polygons representing
16471647
the faces of the current ``Grid``
@@ -1661,7 +1661,8 @@ def to_geodataframe(
16611661
- 'split': Periodic elements will be identified and split using the ``antimeridian`` package
16621662
- 'ignore': No processing will be applied to periodic elements.
16631663
projection: ccrs.Projection, optional
1664-
Geographic projection used to transform polygons
1664+
Geographic projection used to transform polygons. Only supported when periodic_elements is set to
1665+
'ignore' or 'exclude'
16651666
cache: bool, optional
16661667
Flag used to select whether to cache the computed GeoDataFrame
16671668
override: bool, optional
@@ -1679,7 +1680,7 @@ def to_geodataframe(
16791680
16801681
Returns
16811682
-------
1682-
gdf : spatialpandas.GeoDataFrame
1683+
gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame
16831684
The output ``GeoDataFrame`` with a filled out "geometry" column of polygons.
16841685
"""
16851686

@@ -1688,6 +1689,9 @@ def to_geodataframe(
16881689
f"Invalid engine. Expected one of ['spatialpandas', 'geopandas'] but received {engine}"
16891690
)
16901691

1692+
# if project is false, projection is only used for determining central coordinates
1693+
project = kwargs.get("project", True)
1694+
16911695
if projection and project:
16921696
if periodic_elements == "split":
16931697
raise ValueError(
@@ -1871,13 +1875,6 @@ def to_linecollection(
18711875
f"Invalid value for 'periodic_elements'. Expected one of ['ignore', 'exclude', 'split'] but received: {periodic_elements}"
18721876
)
18731877

1874-
if projection is not None:
1875-
if periodic_elements == "split":
1876-
raise ValueError(
1877-
"Setting ``periodic_elements='split'`` is not supported when a "
1878-
"projection is provided."
1879-
)
1880-
18811878
if self._line_collection_cached_parameters["line_collection"] is not None:
18821879
if (
18831880
self._line_collection_cached_parameters["periodic_elements"]

Diff for: uxarray/plot/accessor.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,13 @@ def face_centers(self, backend=None, **kwargs):
163163

164164
face_centers.__doc__ = face_coords.__doc__
165165

166-
def edges(self, periodic_elements="exclude", backend=None, **kwargs):
166+
def edges(
167+
self,
168+
periodic_elements="exclude",
169+
backend=None,
170+
engine="spatialpandas",
171+
**kwargs,
172+
):
167173
"""Plots the edges of a Grid.
168174
169175
This function plots the edges of the grid as geographical paths using `hvplot`.
@@ -182,6 +188,8 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs):
182188
- "split": Split periodic elements.
183189
backend : str or None, optional
184190
Plotting backend to use. One of ['matplotlib', 'bokeh']. Equivalent to running holoviews.extension(backend)
191+
engine: str, optional
192+
Engine to use for GeoDataFrame construction. One of ['spatialpandas', 'geopandas']
185193
**kwargs : dict
186194
Additional keyword arguments passed to `hvplot.paths`. These can include:
187195
- "rasterize" (bool): Whether to rasterize the plot (default: False),
@@ -195,7 +203,6 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs):
195203
gdf.hvplot.paths : hvplot.paths
196204
A paths plot of the edges of the unstructured grid
197205
"""
198-
199206
uxarray.plot.utils.backend.assign(backend)
200207

201208
if "rasterize" not in kwargs:
@@ -212,8 +219,11 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs):
212219
kwargs["crs"] = ccrs.PlateCarree(central_longitude=central_longitude)
213220

214221
gdf = self._uxgrid.to_geodataframe(
215-
periodic_elements=periodic_elements, projection=kwargs.get("projection")
216-
)[["geometry"]]
222+
periodic_elements=periodic_elements,
223+
projection=kwargs.get("projection"),
224+
engine=engine,
225+
project=False,
226+
)
217227

218228
return gdf.hvplot.paths(geo=True, **kwargs)
219229

@@ -260,8 +270,15 @@ def __getattr__(self, name: str) -> Any:
260270
else:
261271
raise AttributeError(f"Unsupported Plotting Method: '{name}'")
262272

263-
def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs):
264-
"""Generate a shaded polygon plot of a face-centered data variable.
273+
def polygons(
274+
self,
275+
periodic_elements="exclude",
276+
backend=None,
277+
engine="spatialpandas",
278+
*args,
279+
**kwargs,
280+
):
281+
"""Generated a shaded polygon plot.
265282
266283
This function plots the faces of an unstructured grid shaded with a face-centered data variable using hvplot.
267284
It allows for rasterization, projection settings, and labeling of the data variable to be
@@ -278,6 +295,8 @@ def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs):
278295
- "ignore": Include periodic elements without any corrections
279296
backend : str or None, optional
280297
Plotting backend to use. One of ['matplotlib', 'bokeh']. Equivalent to running holoviews.extension(backend)
298+
engine: str, optional
299+
Engine to use for GeoDataFrame construction. One of ['spatialpandas', 'geopandas']
281300
*args : tuple
282301
Additional positional arguments to be passed to `hvplot.polygons`.
283302
**kwargs : dict
@@ -309,7 +328,10 @@ def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs):
309328
kwargs["crs"] = ccrs.PlateCarree(central_longitude=central_longitude)
310329

311330
gdf = self._uxda.to_geodataframe(
312-
periodic_elements=periodic_elements, projection=kwargs.get("projection")
331+
periodic_elements=periodic_elements,
332+
projection=kwargs.get("projection"),
333+
engine=engine,
334+
project=False,
313335
)
314336

315337
return gdf.hvplot.polygons(

0 commit comments

Comments
 (0)