Skip to content
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ What's new
0.8.4 (2024-02-26)
------------------
* Fix regression from :pull:`332` that made ``Regridder`` fail with rectilinear datasets and ``parallel=True``. (:issue:`343`, :pull:`344`).
* Allow Python 3.12 (and higher) again. (:pull:`345).
* Allow Python 3.12 (and higher) again. (:pull:`345`).

0.8.3 (2024-02-20)
------------------
Expand Down
1,631 changes: 1,253 additions & 378 deletions doc/notebooks/Curvilinear_grid.ipynb

Large diffs are not rendered by default.

118 changes: 56 additions & 62 deletions doc/notebooks/Masking.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):
grid_mask = mask.astype(np.int32)
if not (grid_mask.shape == lon.shape):
raise ValueError(
'mask must have the same shape as the latitude/longitude'
'mask must have the same shape as the latitude/longitude '
'coordinates, got: mask.shape = %s, lon.shape = %s' % (mask.shape, lon.shape)
)
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False)
Expand Down
2 changes: 1 addition & 1 deletion xesmf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def wave_smooth(lon, lat):
2D wave field

Notes
-------
-----
Equation from [1]_ [2]_:

.. math:: Y_2^2 = 2 + \cos^2(\\theta) \cos(2 \phi)
Expand Down
106 changes: 102 additions & 4 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
add_nans_to_weights,
apply_weights,
check_shapes,
mask_source_indices,
read_weights,
)
from .util import LAT_CF_ATTRS, LON_CF_ATTRS, split_polygons_and_holes
from .util import LAT_CF_ATTRS, LON_CF_ATTRS, _get_edge_indices_2d, split_polygons_and_holes

try:
import dask.array as da
Expand Down Expand Up @@ -248,6 +249,7 @@ def __init__(
output_dims=None,
unmapped_to_nan=False,
parallel=False,
post_mask_source=None,
):
"""
Base xESMF regridding class supporting ESMF objects: `Grid`, `Mesh` and `LocStream`.
Expand Down Expand Up @@ -329,6 +331,28 @@ def __init__(
generation in the BaseRegridder is skipped and weights are generated in parallel in the
subset_regridder instead.

post_mask_source : str or array-like, optional
Optionally applies a post-processing step to remove selected source grid cells from
contributing to the regridding weight matrix.

Note: This differs from the typical masking approach, which prevents source cells from
being used during weight generation. Here, the regridding weights are modified *after*
creation to remove the contribution of specified source grid cells.

Options:

- If set to ``"domain_edge"``, the outermost edge cells of the source grid are
automatically detected and their contribution to the regridding weights is removed.
This is useful to avoid extrapolation beyond the domain boundary when using the
nearest-neighbor method ``'nearest_s2d'``, particularly when remapping from a smaller
to a larger domain (as is common with regional source grids like CORDEX).
Only supported for ``Grid`` type ESMF objects as source grid.
- If an array-like of integers is provided, it is interpreted as flat indices
(i.e., 1D indices of the flattened source grid) identifying source cells
whose contribution to the regridding weights should be removed.

Default is ``None``, meaning no post-weight-generation source grid cell masking is applied.

Returns
-------
baseregridder : xESMF BaseRegridder object
Expand Down Expand Up @@ -361,10 +385,41 @@ def __init__(
self.n_in = self.shape_in[0] * self.shape_in[1]
self.n_out = self.shape_out[0] * self.shape_out[1]

# Validate post_mask_source
self.post_mask_source = None
if isinstance(post_mask_source, str) and post_mask_source == 'domain_edge':
if self.sequence_in:
raise ValueError(
"post_mask_source='domain_edge' is only supported for 'Grid' type ESMF objects "
'as source grid (i.e. structured - rectilinear or curvilinear - grids. '
f"Grid type detected: {type(grid_in)}"
)
self.post_mask_source = _get_edge_indices_2d(self.shape_in[1], self.shape_in[0])
elif post_mask_source is not None:
try:
self.post_mask_source = np.asarray(post_mask_source)
except Exception as e:
raise TypeError(
f"`post_mask_source` must be array-like of integers. Got: {type(post_mask_source)}"
) from e
if self.post_mask_source is not None and not np.issubdtype(
self.post_mask_source.dtype, np.integer
):
raise TypeError(
f"`post_mask_source` must be of integer type. Got dtype: {self.post_mask_source.dtype}"
)

# some logic about reusing weights with either filename or weights args
if reuse_weights and (filename is None) and (weights is None):
raise ValueError('To reuse weights, you need to provide either filename or weights.')

# decide whether unmapped cells should be mapped to NaN
self.unmapped_to_nan = False
if (
(grid_out.mask is not None) and (grid_out.mask[0] is not None)
) or unmapped_to_nan is True:
self.unmapped_to_nan = True

if not parallel:
if not reuse_weights and weights is None:
weights = self._compute_weights(grid_in, grid_out) # Dictionary of weights
Expand All @@ -376,10 +431,13 @@ def __init__(
# Convert weights, whatever their format, to a sparse coo matrix
self.weights = read_weights(weights, self.n_in, self.n_out)

# Optionally apply post_mask_source to modify the weights, removing
# the contribution of the specified source cells
if self.post_mask_source is not None:
self.weights = mask_source_indices(self.weights, self.post_mask_source)

# replace zeros by NaN for weight matrix entries of unmapped target cells if specified or a mask is present
if (
(grid_out.mask is not None) and (grid_out.mask[0] is not None)
) or unmapped_to_nan is True:
if self.unmapped_to_nan:
self.weights = add_nans_to_weights(self.weights)

# follows legacy logic of writing weights if filename is provided
Expand Down Expand Up @@ -870,6 +928,29 @@ def __init__(
If an output mask is defined, or regridding method is `nearest_s2d` or `nearest_d2s`,
this option has no effect.

post_mask_source : str or array-like, optional
Optionally applies a post-processing step to remove selected source grid cells from
contributing to the regridding weight matrix.

Note: This differs from the typical masking approach, which prevents source cells from
being used during weight generation. Here, the regridding weights are modified *after*
creation to remove the contribution of specified source grid cells.

Options:

- If set to ``"domain_edge"``, the outermost edge cells of the source grid are
automatically detected and their contribution to the regridding weights is removed.
This is useful to avoid extrapolation beyond the domain boundary when using the
nearest-neighbor method ``'nearest_s2d'``, particularly when remapping from a smaller
to a larger domain (as is common with regional source grids like CORDEX).
Only supported for ``Grid`` type ESMF objects as source grid.

- If an array-like of integers is provided, it is interpreted as flat indices
(i.e., 1D indices of the flattened source grid) identifying source cells
whose contribution to the regridding weights should be removed.

Default is ``None``, meaning no post-weight-generation source grid cell masking is applied.

Returns
-------
regridder : xESMF regridder object
Expand Down Expand Up @@ -932,6 +1013,7 @@ def __init__(
parallel=parallel,
**kwargs,
)

# Weights are computed, we do not need the grids anymore
grid_in.destroy()
grid_out.destroy()
Expand Down Expand Up @@ -983,6 +1065,7 @@ def __init__(
self.out_coords = xr.Dataset(coords={lat_out.name: lat_out, lon_out.name: lon_out})

if parallel:
# Generate the weights in parallel
self._init_para_regrid(ds_in, ds_out, kwargs)

def _init_para_regrid(self, ds_in, ds_out, kwargs):
Expand Down Expand Up @@ -1075,6 +1158,12 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs):
chunks
) # template has same chunks as ds_out

# If post_mask_source is specified, set it to None, as it needs to be dealt with
# for the final weights only
if 'post_mask_source' in kwargs:
kwargs['post_mask_source'] = None

# Compute weights in parallel
w = xr.map_blocks(
subset_regridder,
ds_out,
Expand All @@ -1095,6 +1184,15 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs):
weights.name = 'weights'
self.weights = weights

# Optionally apply post_mask_source to modify the weights, removing
# the contribution of the specified source cells
if self.post_mask_source is not None:
self.weights = mask_source_indices(self.weights, self.post_mask_source)
# replace zeros by NaN for weight matrix entries of unmapped target cells
# if specified or a mask is present
if self.unmapped_to_nan:
self.weights = add_nans_to_weights(self.weights)

# follows legacy logic of writing weights if filename is provided
if 'filename' in kwargs:
filename = kwargs['filename']
Expand Down
80 changes: 80 additions & 0 deletions xesmf/smm.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,86 @@ def add_nans_to_weights(weights):
return weights


def mask_source_indices(weights, source_indices_to_mask):
    """
    Remove entries in a sparse.COO weight matrix that map from masked source indices.

    Parameters
    ----------
    weights : DataArray backed by a sparse.COO array
        Sparse weights matrix of shape (n_target, n_source).
    source_indices_to_mask : array-like
        Flat indices of source grid cells whose contribution should be removed
        (e.g. output of xesmf.util._get_edge_indices_2d).

    Returns
    -------
    DataArray backed by a sparse.COO array
        New weight matrix, with dims ``('out_dim', 'in_dim')`` and the same shape
        as the input, with masked source contributions removed.

    Raises
    ------
    ValueError
        If any provided index is negative or not smaller than the number of
        source cells (second dimension of the weight matrix).
    """
    # Extract the sparse.COO remapping weight matrix (ntarget, nsource)
    W = weights.data

    # W.coords is a 2D array with shape (2, N), holding the positions of the stored
    # entries of the sparse matrix. Rows being [0] target_idx, [1] source_idx.
    tgt_idx = W.coords[0]
    src_idx = W.coords[1]
    # Each W.data[i] represents the weight from source_idx[i] to target_idx[i]
    data = W.data

    # Validate source_indices_to_mask. Convert once and flag BOTH kinds of
    # out-of-range values so the error message lists every offending index
    # (negative ones included), not only those that are too large.
    n_source = W.shape[1]
    indices = np.asarray(source_indices_to_mask)
    invalid = (indices < 0) | (indices >= n_source)
    if np.any(invalid):
        raise ValueError(
            f"Some of the provided source indices are out of valid range [0, {n_source}). "
            f"Invalid indices: {indices[invalid]}"
        )

    # Boolean mask over the stored entries - False where the source index is masked
    keep = ~np.isin(src_idx, indices)

    # Keep only the entries that do not originate from a masked source cell
    data_masked = data[keep]
    # New coordinates array: row-wise stack of the remaining target and source indices
    coords_masked = np.vstack([tgt_idx[keep], src_idx[keep]])

    # Build a new sparse weight matrix of unchanged shape and wrap it in a DataArray
    weights = xr.DataArray(
        sps.COO(coords_masked, data_masked, shape=W.shape), dims=('out_dim', 'in_dim')
    )

    return weights


def gen_mask_from_weights(weights, nlat, nlon):
    """Generate a 2D binary mask from the regridding weights sparse matrix.

    Row ``k`` of the weight matrix is matched to flat element ``k`` of the mask:
    the element is set to 0 when that row stores at least one NaN weight and
    is left at 1 otherwise.

    Parameters
    ----------
    weights : DataArray backed by a sparse.COO array
        Sparse weights matrix.
    nlat : int
        Size of the first dimension of the returned mask.
    nlon : int
        Size of the second dimension of the returned mask.

    Returns
    -------
    numpy.ndarray of type numpy.int32 and of shape (nlat, nlon)
        Binary mask.
    """
    # Taken from @trondkr and adapted by @raphaeldussin to use `lil`.
    # lil matrix is better than CSR when changing sparsity
    lil = weights.data.to_scipy_sparse().tolil()

    # Start from an all-ones flat mask; rows holding a NaN weight are zeroed below
    flat_mask = np.ones((nlat, nlon), dtype=np.int32).ravel()
    for row_number, row_values in enumerate(lil.data):
        if any(np.isnan(value) for value in row_values):
            flat_mask[row_number] = 0

    # Restore the 2D (nlat, nlon) layout
    return flat_mask.reshape((nlat, nlon))


def _combine_weight_multipoly(weights, areas, indexes):
"""Reduce a weight sparse matrix (csc format) by combining (adding) columns.

Expand Down
Loading
Loading