Skip to content

Commit 30b0dd4

Browse files
Merge pull request #2350 from Parcels-code/interpolator_invdistland
Adding invdistland interpolator
2 parents 1cb2a45 + 60dd863 commit 30b0dd4

File tree

2 files changed

+106
-16
lines changed

2 files changed

+106
-16
lines changed

src/parcels/interpolators.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"UXPiecewiseLinearNode",
2323
"XFreeslip",
2424
"XLinear",
25+
"XLinearInvdistLandTracer",
2526
"XNearest",
2627
"XPartialslip",
2728
"ZeroInterpolator",
@@ -75,11 +76,11 @@ def _get_corner_data_Agrid(
7576

7677
# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/z
7778
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
78-
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))
79+
yi = np.tile(np.array([yi, yi, yi_1, yi_1]).flatten(), lenT * lenZ)
7980

8081
# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/z
8182
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
82-
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))
83+
xi = np.tile(np.array([xi, xi_1]).flatten(), lenT * lenZ * 2)
8384

8485
# Create DataArrays for indexing
8586
selection_dict = {
@@ -91,7 +92,7 @@ def _get_corner_data_Agrid(
9192
if "time" in data.dims:
9293
selection_dict["time"] = xr.DataArray(ti, dims=("points"))
9394

94-
return data.isel(selection_dict).data.reshape(lenT, lenZ, npart, 4)
95+
return data.isel(selection_dict).data.reshape(lenT, lenZ, 2, 2, npart)
9596

9697

9798
def XLinear(
@@ -114,22 +115,22 @@ def XLinear(
114115
corner_data = _get_corner_data_Agrid(data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim)
115116

116117
if lenT == 2:
117-
tau = tau[np.newaxis, :, np.newaxis]
118-
corner_data = corner_data[0, :, :, :] * (1 - tau) + corner_data[1, :, :, :] * tau
118+
tau = tau[np.newaxis, :]
119+
corner_data = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau
119120
else:
120-
corner_data = corner_data[0, :, :, :]
121+
corner_data = corner_data[0, :]
121122

122123
if lenZ == 2:
123-
zeta = zeta[:, np.newaxis]
124-
corner_data = corner_data[0, :, :] * (1 - zeta) + corner_data[1, :, :] * zeta
124+
zeta = zeta[np.newaxis, :]
125+
corner_data = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta
125126
else:
126-
corner_data = corner_data[0, :, :]
127+
corner_data = corner_data[0, :]
127128

128129
value = (
129-
(1 - xsi) * (1 - eta) * corner_data[:, 0]
130-
+ xsi * (1 - eta) * corner_data[:, 1]
131-
+ (1 - xsi) * eta * corner_data[:, 2]
132-
+ xsi * eta * corner_data[:, 3]
130+
(1 - xsi) * (1 - eta) * corner_data[0, 0, :]
131+
+ xsi * (1 - eta) * corner_data[0, 1, :]
132+
+ (1 - xsi) * eta * corner_data[1, 0, :]
133+
+ xsi * eta * corner_data[1, 1, :]
133134
)
134135
return value.compute() if is_dask_collection(value) else value
135136

@@ -409,8 +410,8 @@ def _Spatialslip(
409410
corner_dataV = _get_corner_data_Agrid(vectorfield.V.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)
410411

411412
def is_land(ti: int, zi: int, yi: int, xi: int):
412-
uval = corner_dataU[ti, zi, :, xi + 2 * yi]
413-
vval = corner_dataV[ti, zi, :, xi + 2 * yi]
413+
uval = corner_dataU[ti, zi, yi, xi, :]
414+
vval = corner_dataV[ti, zi, yi, xi, :]
414415
return np.where(np.isclose(uval, 0.0) & np.isclose(vval, 0.0), True, False)
415416

416417
f_u = np.ones_like(xsi)
@@ -571,6 +572,52 @@ def XNearest(
571572
return value.compute() if is_dask_collection(value) else value
572573

573574

575+
def XLinearInvdistLandTracer(
576+
particle_positions: dict[str, float | np.ndarray],
577+
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
578+
field: Field,
579+
):
580+
"""Linear spatial interpolation on a regular grid, where points on land are not used."""
581+
values = XLinear(particle_positions, grid_positions, field)
582+
583+
xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"]
584+
yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"]
585+
zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"]
586+
ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"]
587+
588+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
589+
lenT = 2 if np.any(tau > 0) else 1
590+
lenZ = 2 if np.any(zeta > 0) else 1
591+
592+
corner_data = _get_corner_data_Agrid(field.data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim)
593+
594+
land_mask = np.isnan(corner_data)
595+
nb_land = np.sum(land_mask, axis=(0, 1, 2, 3))
596+
597+
if np.any(nb_land):
598+
all_land_mask = nb_land == 4 * lenZ * lenT
599+
values[all_land_mask] = 0.0
600+
601+
not_all_land = ~all_land_mask
602+
if np.any(not_all_land):
603+
i_grid = np.arange(2)[None, None, None, :, None]
604+
j_grid = np.arange(2)[None, None, :, None, None]
605+
eta_b = eta[None, None, None, None, :]
606+
xsi_b = xsi[None, None, None, None, :]
607+
608+
inv_dist = 1.0 / ((eta_b - j_grid) ** 2 + (xsi_b - i_grid) ** 2)
609+
610+
valid_mask = ~land_mask
611+
weighted = np.where(valid_mask, corner_data * inv_dist, 0.0)
612+
613+
val = np.sum(weighted, axis=(0, 1, 2, 3))
614+
w_sum = np.sum(np.where(valid_mask, inv_dist, 0.0), axis=(0, 1, 2, 3))
615+
616+
values[not_all_land] = val[not_all_land] / w_sum[not_all_land]
617+
618+
return values.compute() if is_dask_collection(values) else values
619+
620+
574621
def UXPiecewiseConstantFace(
575622
particle_positions: dict[str, float | np.ndarray],
576623
grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]],

tests/test_interpolation.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
UXPiecewiseLinearNode,
2222
XFreeslip,
2323
XLinear,
24+
XLinearInvdistLandTracer,
2425
XNearest,
2526
XPartialslip,
2627
ZeroInterpolator,
@@ -68,9 +69,19 @@ def field():
6869
[0.49, 0.49],
6970
[0.51, 0.51],
7071
[1.49, 6.49],
71-
id="Linear",
72+
id="Linear-1",
7273
),
7374
pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"),
75+
pytest.param(
76+
XLinear,
77+
[np.timedelta64(0, "s"), np.timedelta64(1, "s"), np.timedelta64(1, "s")],
78+
[0, 0, 2.5],
79+
[0.49, 0.49, 0.49],
80+
[0.51, 0.51, 0.51],
81+
[1.49, 6.49, 13.99],
82+
id="Linear-3",
83+
),
84+
pytest.param(XLinearInvdistLandTracer, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="LinearInvDistLand"),
7485
pytest.param(
7586
XNearest,
7687
[np.timedelta64(0, "s"), np.timedelta64(3, "s")],
@@ -122,6 +133,38 @@ def test_spatial_slip_interpolation(field, func, t, z, y, x, expected):
122133
np.testing.assert_array_almost_equal(velocities, expected)
123134

124135

136+
@pytest.mark.parametrize(
137+
"func, t, z, y, x, expected",
138+
[
139+
(XLinearInvdistLandTracer, np.timedelta64(1, "s"), 0, 0.5, 0.5, 1.0),
140+
(XLinearInvdistLandTracer, np.timedelta64(1, "s"), 0, 1.5, 1.5, 0.0),
141+
(
142+
XLinearInvdistLandTracer,
143+
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
144+
[0, 2],
145+
[0.5, 0.5],
146+
[0.5, 0.5],
147+
1.0,
148+
),
149+
(
150+
XLinearInvdistLandTracer,
151+
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
152+
[0, 2],
153+
[0.5, 1.5],
154+
[0.5, 1.5],
155+
[1.0, 0.0],
156+
),
157+
],
158+
)
159+
def test_invdistland_interpolation(field, func, t, z, y, x, expected):
160+
field.data[:] = 1.0
161+
field.data[:, :, 1:3, 1:3] = np.nan # Set NaN land value to test inv_dist
162+
field.interp_method = func
163+
164+
value = field[t, z, y, x]
165+
np.testing.assert_array_almost_equal(value, expected)
166+
167+
125168
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
126169
def test_interpolation_mesh_type(mesh, npart=10):
127170
ds = simple_UV_dataset(mesh=mesh)

0 commit comments

Comments
 (0)