diff --git a/xesmf/backend.py b/xesmf/backend.py index db1c10ba..6b891fb2 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -58,7 +58,7 @@ def warn_lat_range(lat): class Grid(ESMF.Grid): @classmethod - def from_xarray(cls, lon, lat, periodic=False, mask=None): + def from_xarray(cls, lon, lat, periodic=False, mask=None, pole_kind=None): """ Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid. @@ -83,6 +83,17 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None): Shape should be ``(Nlon, Nlat)`` for rectilinear grid, or ``(Nx, Ny)`` for general quadrilateral grid. + pole_kind : [int, int] or None + Two item list which specifies the type of connection which occurs at the pole. + The first value specifies the connection that occurs at the minimum end of the + pole dimension. The second value specifies the connection that occurs at the + maximum end of the pole dimension. Options are 0 (no connections at pole), + 1 (monopole, this edge is connected to itself. Given that the edge is n elements long, + then element i is connected to element i+n/2), and 2 (bipole, this edge is connected + to itself. Given that the edge is n elements long, element i is connected to element n-i-1. + If None, defaults to [1,1] for monopole connections. See :attr:`ESMF.api.constants.PoleKind`. + Requires ESMF >= 8.0.1 + Returns ------- grid : ESMF.Grid object @@ -107,6 +118,12 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None): else: num_peri_dims = None + # `pole_kind` option supported since 8.0.1 + if ESMF.__version__ < '8.0.1': + if pole_kind is not None: + raise ValueError('The `pole_kind` option requires esmpy >= 8.0.1') + pole_kind = None + # ESMPy documentation claims that if staggerloc and coord_sys are None, # they will be set to default values (CENTER and SPH_DEG). # However, they actually need to be set explicitly, @@ -116,6 +133,7 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None): staggerloc=staggerloc, coord_sys=ESMF.CoordSys.SPH_DEG, num_peri_dims=num_peri_dims, + pole_kind=pole_kind, ) # The grid object points to the underlying Fortran arrays in ESMF. diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 989662a9..27632aae 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -157,11 +157,16 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): else: mask = None + if 'pole_kind' in ds: + pole_kind = np.asarray(ds['pole_kind']) + else: + pole_kind = None + # tranpose the arrays so they become Fortran-ordered if mask is not None: - grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T) + grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T, pole_kind=pole_kind) else: - grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None) + grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None, pole_kind=pole_kind) if need_bounds: lon_b, lat_b = _get_lon_lat_bounds(ds) diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 4a5a595c..aa109751 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -1007,3 +1007,66 @@ def test_densify_polys(): poly = Polygon([(-80, -40), (80, -40), (80, 40), (-80, 40)]) # Large poly with pytest.warns(UserWarning): xe.SpatialAverager(ds_in, [poly]) + + +def test_regrid_polekind(): + + import pathlib + import tarfile + import urllib.request + + # get file from figshare with sample OM4 (tripolar) data + url_om4_sample = 'https://figshare.com/ndownloader/files/52228691' + fname = 'xesmf_testing_OM4.tar.gz' + urllib.request.urlretrieve(url_om4_sample, fname) + + # extract and read data + tar = tarfile.open(fname, 'r:gz') + tar.extractall() + tar.close() + + ds_in = xr.open_dataset('OM4_sample_sst.nc') + + # Open output grid specification + ds_out = xe.util.grid_global(1, 1, cf=False, lon1=180) + + # Create regridder without specifying pole kind + base_regrid = xe.Regridder(ds_in, ds_out, 'bilinear', ignore_degenerate=True, periodic=True) + base_result = base_regrid(ds_in['sst']) + + # Add monopole grid information. 1 denotes monopole, 2 bipole + ds_in['pole_kind'] = np.array([1, 1]) + ds_out['pole_kind'] = np.array([1, 1]) + + monopole_regrid = xe.Regridder(ds_in, ds_out, 'bilinear', ignore_degenerate=True, periodic=True) + monopole_result = monopole_regrid(ds_in['sst']) + + # Check behavior unchanged + assert monopole_result.equals(base_result) + + # Add bipole grid information + ds_in['pole_kind'] = np.array([1, 2], np.int32) + + bipole_regrid = xe.Regridder(ds_in, ds_out, 'bilinear', ignore_degenerate=True, periodic=True) + bipole_result = bipole_regrid(ds_in['sst']) + + # Confirm results have changed + assert not bipole_result.equals(monopole_result) + + # Confirm results are better with bipolar option + # without proper pole_kinds there are discontinuities close to the northfold + # easily visible on the northernmost rows + grad_monopole = np.gradient(monopole_result.isel(y=-2)) + grad_bipole = np.gradient(bipole_result.isel(y=-2)) + + rms_grad_monopole = (grad_monopole * grad_monopole).sum() + rms_grad_bipole = (grad_bipole * grad_bipole).sum() + + assert rms_grad_bipole < rms_grad_monopole + + # Clean up files + file_to_rem = pathlib.Path('OM4_sample_sst.nc') + file_to_rem.unlink() + file_to_rem = pathlib.Path('xesmf_testing_OM4.tar.gz') + file_to_rem.unlink() + return None