Skip to content

Commit 336424a

Browse files
sol1105 and aulemahal authored
Enabling target mask for LocStream input (#445)
* Enabling target masks for LocStream input - Adding function xesmf.smm.post_apply_target_mask_to_weights to zero out weights of masked target grid cells - Adding application of this function to BaseRegridder in the case of LocStream input, Grid output incl. mask - This may lead to unexpected results when not using with nearest_s2d, as the masking is applied *post* weight generation * Restricting automatic application of new masking option to nearest_s2d * Update new tests * Updated changelog. * Updated changelog. * Pin cf_xarray as latest update seems to have introduced an issue to cfxr.bounds_to_vertices * Pin cf_xarray for ci as latest update seems to have introduced an issue to cfxr.bounds_to_vertices * unpin cfxr --------- Co-authored-by: Pascal Bourgault <[email protected]>
1 parent 41e1bcc commit 336424a

File tree

5 files changed

+167
-3
lines changed

5 files changed

+167
-3
lines changed

CHANGES.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
What's new
22
==========
33

4-
0.8.11 (unreleased)
5-
-------------------
6-
* ``xe.util.cf_grid_2d`` returns bounds as coordinates, as ``grid_2d`` does and as usually expected. (:pull:453`). `By `Pascal Bourgault <https://github.com/aulemahal>`_.
4+
0.9.0 (unreleased)
5+
------------------
6+
* Added support for target masks when regridding ``LocStream`` to ``Grid`` with ``nearest_s2d`` (:pull:`445`). By `Martin Schupfner <https://github.com/sol1105>`_.
7+
* ``xesmf.util.cf_grid_2d`` returns bounds as coordinates, as ``grid_2d`` does and as usually expected. (:pull:`453`). By `Pascal Bourgault <https://github.com/aulemahal>`_.
78

89
0.8.10 (2025-04-29)
910
-------------------

xesmf/frontend.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
add_nans_to_weights,
1919
apply_weights,
2020
check_shapes,
21+
post_apply_target_mask_to_weights,
2122
read_weights,
2223
)
2324
from .util import LAT_CF_ATTRS, LON_CF_ATTRS, split_polygons_and_holes
@@ -376,6 +377,22 @@ def __init__(
376377
# Convert weights, whatever their format, to a sparse coo matrix
377378
self.weights = read_weights(weights, self.n_in, self.n_out)
378379

380+
# Optionally post-apply output mask for LocStream input and Grid output
381+
# as xesmf.backend.esmf_regrid_build filters the output mask in that case
382+
# (ESMF does not support output masks for LocStream input and Grid output)
383+
# Only method supported is `nearest_s2d`:
384+
# For other methods the masking approach may lead to unexpected results
385+
# as the mask is applied after weight generation, whereas other methods may have
386+
# the source-target mapping depending on the location of masked cells.
387+
if (
388+
isinstance(grid_in, LocStream)
389+
and isinstance(grid_out, Grid)
390+
and grid_out.mask is not None
391+
and grid_out.mask[0] is not None
392+
and method == 'nearest_s2d'
393+
):
394+
self.weights = post_apply_target_mask_to_weights(self.weights, grid_out.mask[0])
395+
379396
# replace zeros by NaN for weight matrix entries of unmapped target cells if specified or a mask is present
380397
if (
381398
(grid_out.mask is not None) and (grid_out.mask[0] is not None)

xesmf/smm.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,60 @@ def apply_weights(weights, indata, shape_in, shape_out):
201201
return outdata
202202

203203

204+
def post_apply_target_mask_to_weights(weights, target_mask_2d):
205+
"""
206+
Set all contributions to masked target grid cells to zero.
207+
208+
Parameters
209+
----------
210+
weights : DataArray backed by a sparse.COO array
211+
Sparse weights matrix.
212+
target_mask_2d : array-like
213+
Mask array of shape (nx, ny) for the target grid.
214+
False / 0 indicates a masked cell whose weights shall be set to zero.
215+
216+
Returns
217+
-------
218+
DataArray backed by a sparse.COO array
219+
Modified sparse weights matrix with all rows corresponding to masked target cells set to zero.
220+
221+
Notes
222+
-----
223+
This defines a post-processing step applied after ESMF weight generation.
224+
It is useful in cases where ESMF/ESMPy masks cannot be used directly,
225+
which is the case when source or target are LocStream objects / sequences.
226+
"""
227+
# Ensure mask can be converted to array
228+
try:
229+
target_mask_2d = np.asarray(target_mask_2d, dtype=weights.data.data.dtype)
230+
except Exception as e:
231+
raise TypeError(
232+
"Argument 'target_mask_2d' must be array-like and convertible to a numeric/boolean array"
233+
) from e
234+
235+
# Validate dimensionality and shape
236+
if target_mask_2d.ndim != 2:
237+
raise ValueError(f"Argument 'target_mask_2d' must be 2D, got shape {target_mask_2d.shape}")
238+
n_target, n_source = weights.data.shape
239+
if target_mask_2d.size != n_target:
240+
raise ValueError(
241+
f"Mismatch: weight matrix has {n_target} target cells, "
242+
f"but mask has {target_mask_2d.size} elements"
243+
)
244+
245+
# Flatten mask array to align with weight matrix target index (Fortran order for ESMF layout)
246+
target_mask_flat = target_mask_2d.ravel(order='F')
247+
248+
# Multiply row-wise by mask to zero out weights of masked target cells
249+
W = weights.data * target_mask_flat[:, None]
250+
251+
# Create weights DataArray backed by sparse.COO
252+
weights = xr.DataArray(
253+
sps.COO(coords=W.coords, data=W.data, shape=W.shape), dims=('out_dim', 'in_dim')
254+
)
255+
return weights
256+
257+
204258
def add_nans_to_weights(weights):
205259
"""Add nan in empty rows of the regridding weights sparse matrix.
206260

xesmf/tests/test_frontend.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,3 +1018,36 @@ def test_spatial_averager_mask():
10181018
savg = xe.SpatialAverager(dsm, [poly], geom_dim_name='my_geom')
10191019
out = savg(dsm.abc)
10201020
assert_allclose(out, 2, rtol=1e-3)
1021+
1022+
1023+
def test_locstream_input_grid_output_with_target_mask_applied():
1024+
# Create locstream input (6 coordinate points)
1025+
locstream_in = xr.Dataset(
1026+
{'var': xr.DataArray(np.ones((6)), dims=['location'])},
1027+
coords={
1028+
'lat': ('location', np.linspace(0, 5, 6)),
1029+
'lon': ('location', np.linspace(0, 10, 6)),
1030+
},
1031+
)
1032+
1033+
# Create Grid output with target mask (3x3 grid)
1034+
ds_out = xe.util.cf_grid_2d(0, 10, 10.0 / 3.0, 0, 5, 5.0 / 3.0)
1035+
target_mask_2d = np.ones((3, 3), dtype=bool)
1036+
target_mask_2d[-1, :] = False
1037+
ds_out['target_mask'] = xr.DataArray(target_mask_2d, dims=['lat', 'lon'])
1038+
1039+
# Create Grid output with target mask (3x3 grid)
1040+
ds_out = xe.util.cf_grid_2d(0, 10, 10.0 / 3.0, 0, 5, 5.0 / 3.0)
1041+
target_mask_2d = np.ones((3, 3), dtype=bool)
1042+
target_mask_2d[-1, :] = False
1043+
ds_out['mask'] = xr.DataArray(target_mask_2d, dims=['lat', 'lon'])
1044+
1045+
# Generate weights
1046+
regridder = xe.Regridder(
1047+
ds_in=locstream_in, ds_out=ds_out, method='nearest_s2d', locstream_in=True
1048+
)
1049+
1050+
# Apply weights and check results - the northernmost cells should be masked
1051+
da_out = regridder(locstream_in)['var']
1052+
assert np.all(np.isnan(da_out[-1, :]))
1053+
assert np.all(da_out[:-1, :] == 1)

xesmf/tests/test_smm.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
import sparse as sps
34
import xarray as xr
45

@@ -27,3 +28,61 @@ def test_add_nans_to_weights():
2728

2829
Matout = xe.smm.add_nans_to_weights(xr.DataArray(Matin, dims=('in', 'out')))
2930
assert np.allclose(Matin.todense(), Matout.data.todense())
31+
32+
33+
def test_post_apply_target_mask_to_weights():
34+
# Create a small sparse weights matrix with shape (6 target, 4 source)
35+
# coords = [[target_indices], [source_indices]]
36+
coords = np.array([[0, 1, 1, 2, 3, 3, 4, 5], [0, 0, 1, 1, 2, 3, 2, 3]])
37+
data = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.45, 0.7, 0.8])
38+
shape = (6, 4)
39+
W_sparse = sps.COO(coords, data, shape=shape)
40+
weights = xr.DataArray(W_sparse, dims=('out_dim', 'in_dim'))
41+
42+
# Define a 3x2 mask for target (flattened size = 6):
43+
# If all goes to plan, the mask flattens in Fortran order to [T, T, F, F, T, T], so weights of cells 3 and 4 (i.e. index 2 and 3)
44+
# will be set to 0.
45+
target_mask_2d = np.array([[True, False], [True, True], [False, True]])
46+
47+
# Apply mask
48+
masked_weights = xe.smm.post_apply_target_mask_to_weights(weights, target_mask_2d)
49+
50+
# Check results
51+
np.testing.assert_array_equal(masked_weights.data.data, np.array([0.1, 0.2, 0.3, 0.7, 0.8]))
52+
np.testing.assert_array_equal(
53+
masked_weights.data.coords, np.array([[0, 1, 1, 4, 5], [0, 0, 1, 2, 3]])
54+
)
55+
56+
57+
def test_post_apply_target_mask_to_weights_exceptions():
58+
# Create a weights DataArray & mask
59+
coords = np.array([[0, 1], [0, 1]])
60+
data = np.array([0.5, 0.5])
61+
shape = (2, 2)
62+
W_sparse = sps.COO(coords, data, shape=shape)
63+
weights = xr.DataArray(W_sparse, dims=('out_dim', 'in_dim'))
64+
valid_mask = np.array([[True, False]])
65+
66+
# Mask not array-like
67+
with pytest.raises(
68+
TypeError,
69+
match="Argument 'target_mask_2d' must be array-like and convertible to a numeric/boolean array",
70+
):
71+
xe.smm.post_apply_target_mask_to_weights(weights, 'not_array_like')
72+
73+
# Shape mismatch
74+
wrong_shape_mask = np.array([[True, False, True]])
75+
with pytest.raises(
76+
ValueError, match='Mismatch: weight matrix has 2 target cells, but mask has 3 elements'
77+
):
78+
xe.smm.post_apply_target_mask_to_weights(weights, wrong_shape_mask)
79+
80+
# Mask not 2D
81+
wrong_shape_mask = np.array([[[True]], [[True]]])
82+
with pytest.raises(
83+
ValueError, match="Argument 'target_mask_2d' must be 2D, got shape \\(2, 1, 1\\)"
84+
):
85+
xe.smm.post_apply_target_mask_to_weights(weights, wrong_shape_mask)
86+
87+
# That should work
88+
xe.smm.post_apply_target_mask_to_weights(weights, valid_mask)

0 commit comments

Comments
 (0)