Skip to content

Commit 198d0ee

Browse files
Correctly guess output chunks for SpatialAverager (#308)
* correctly guess output chunks for Savg * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e5a233a commit 198d0ee

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

CHANGES.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
What's new
22
==========
33

4+
0.8.2 (unreleased)
5+
------------------
6+
7+
Bug fixes
8+
~~~~~~~~~
9+
* Correct guess of output chunks for the :``SpatialAverager``.
10+
411
0.8.1 (2023-09-05)
512
------------------
613

714
Bug fixes
815
~~~~~~~~~
916
* Change import to support shapely 1 and 2.
1017

11-
1218
0.8.0 (2023-09-01)
1319
------------------
1420

xesmf/frontend.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,21 +590,24 @@ def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunk
590590

591591
kwargs.update(skipna=skipna, na_thres=na_thres)
592592

593+
weights = self.weights.data.reshape(self.shape_out + self.shape_in)
593594
if isinstance(indata, dask_array_type): # dask
594595
if output_chunks is None:
595-
output_chunks = indata.chunksize[-2:]
596-
elif output_chunks is not None:
597-
if len(output_chunks) != len(self.shape_out):
596+
output_chunks = tuple(
597+
[min(shp, inchnk) for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:])]
598+
)
599+
if len(output_chunks) != len(self.shape_out):
600+
if len(output_chunks) == 1 and self.sequence_out:
601+
output_chunks = (1, output_chunks[0])
602+
else:
598603
raise ValueError(
599604
f'output_chunks must have same dimension as ds_out,'
600605
f' output_chunks dimension ({len(output_chunks)}) does not '
601606
f'match ds_out dimension ({len(self.shape_out)})'
602607
)
603-
weights = da.from_array(self.w.data, chunks=(output_chunks + indata.chunksize[-2:]))
604-
608+
weights = da.from_array(weights, chunks=(output_chunks + indata.chunksize[-2:]))
605609
outdata = self._regrid(indata, weights, **kwargs)
606610
else: # numpy
607-
weights = self.w.data # 4D weights
608611
outdata = self._regrid(indata, weights, **kwargs)
609612
return outdata
610613

xesmf/tests/test_frontend.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
import xarray as xr
99
from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
10+
from shapely import segmentize
1011
from shapely.geometry import MultiPolygon, Polygon
1112

1213
import xesmf as xe
@@ -780,18 +781,22 @@ def test_ds_to_ESMFlocstream():
780781
locstream, shape, names = ds_to_ESMFlocstream(ds_bogus)
781782

782783

784+
@pytest.mark.parametrize('use_dask', [True, False])
783785
@pytest.mark.parametrize('poly,exp', list(zip(polys, exps_polys)))
784-
def test_spatial_averager(poly, exp):
786+
def test_spatial_averager(poly, exp, use_dask):
785787
if isinstance(poly, (Polygon, MultiPolygon)):
786788
poly = [poly]
787-
savg = xe.SpatialAverager(ds_savg, poly, geom_dim_name='my_geom')
788-
out = savg(ds_savg.abc)
789+
if use_dask:
790+
ds_in = ds_savg.chunk(lat=10)
791+
else:
792+
ds_in = ds_savg
793+
savg = xe.SpatialAverager(ds_in, poly, geom_dim_name='my_geom')
794+
out = savg(ds_in.abc)
789795
assert_allclose(out, exp, rtol=1e-3)
790796

791797
assert 'my_geom' in out.dims
792798

793799

794-
@pytest.mark.xfail
795800
def test_spatial_averager_with_zonal_region():
796801
# We expect the spatial average for all regions to be one
797802
zonal_south = Polygon([(0, -90), (10, 0), (0, 0)])
@@ -800,6 +805,7 @@ def test_spatial_averager_with_zonal_region():
800805
zonal_full = Polygon([(0, -90), (10, 0), (0, 90), (0, 0)]) # This yields 0... why?
801806

802807
polys = [zonal_south, zonal_north, zonal_short, zonal_full]
808+
polys = segmentize(polys, 1)
803809

804810
# Create field of ones on a global grid
805811
ds = xe.util.grid_global(20, 12, cf=True)

0 commit comments

Comments
 (0)