Skip to content

Commit 5dc1c55

Browse files
Vectorized version of XLinearInvdistLandTracer interpolator
With help from ChatGPT for the vectorization
1 parent 6392d0e commit 5dc1c55

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

src/parcels/interpolators.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,6 @@ def XLinearInvdistLandTracer(
579579
"""Linear spatial interpolation on a regular grid, where points on land are not used."""
580580
values = XLinear(particle_positions, grid_positions, field)
581581

582-
on_land = np.argwhere(np.isnan(values))
583-
584582
xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"]
585583
yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"]
586584
zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"]
@@ -592,27 +590,29 @@ def XLinearInvdistLandTracer(
592590

593591
corner_data = _get_corner_data_Agrid(field.data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim)
594592

595-
def is_land(p: int):
596-
value = corner_data[:, :, :, :, p]
597-
return np.where(np.isnan(value), True, False)
593+
land_mask = np.isnan(corner_data)
594+
nb_land = np.sum(land_mask, axis=(0, 1, 2, 3))
598595

599-
for p in on_land:
600-
land = is_land(p)
601-
nb_land = np.sum(land)
602-
if nb_land == 4 * lenZ * lenT:
603-
values[p] = 0.0
604-
else:
605-
val = 0
606-
w_sum = 0
607-
for t in range(lenT):
608-
for k in range(lenZ):
609-
for j in range(2):
610-
for i in range(2):
611-
if land[t][k][j][i] == 0:
612-
distance = pow((eta[p] - j), 2) + pow((xsi[p] - i), 2)
613-
val += corner_data[t, k, j, i, p] / distance
614-
w_sum += 1 / distance
615-
values[p] = val / w_sum
596+
if np.any(nb_land):
597+
all_land_mask = nb_land == 4 * lenZ * lenT
598+
values[all_land_mask] = 0.0
599+
600+
not_all_land = ~all_land_mask
601+
if np.any(not_all_land):
602+
i_grid = np.arange(2)[None, None, None, :, None]
603+
j_grid = np.arange(2)[None, None, :, None, None]
604+
eta_b = eta[None, None, None, None, :]
605+
xsi_b = xsi[None, None, None, None, :]
606+
607+
inv_dist = 1.0 / ((eta_b - j_grid) ** 2 + (xsi_b - i_grid) ** 2)
608+
609+
valid_mask = ~land_mask
610+
weighted = np.where(valid_mask, corner_data * inv_dist, 0.0)
611+
612+
val = np.sum(weighted, axis=(0, 1, 2, 3))
613+
w_sum = np.sum(np.where(valid_mask, inv_dist, 0.0), axis=(0, 1, 2, 3))
614+
615+
values[not_all_land] = val[not_all_land] / w_sum[not_all_land]
616616

617617
return values.compute() if is_dask_collection(values) else values
618618

0 commit comments

Comments
 (0)