Skip to content

Commit 1cb2a45

Browse files
Merge pull request #2332 from Parcels-code/implementing_field_interpolation_api
Implementing Field interpolation API
2 parents 698de64 + 76212f6 commit 1cb2a45

File tree

10 files changed

+135
-150
lines changed

10 files changed

+135
-150
lines changed

src/parcels/_core/field.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
218218
else:
219219
_ei = particles.ei[:, self.igrid]
220220

221-
tau, ti = _search_time_index(self, time)
222-
position = self.grid.search(z, y, x, ei=_ei)
223-
_update_particles_ei(particles, position, self)
224-
_update_particle_states_position(particles, position)
221+
particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei)
225222

226-
value = self._interp_method(self, ti, position, tau, time, z, y, x)
223+
value = self._interp_method(particle_positions, grid_positions, self)
227224

228225
_update_particle_states_interp_value(particles, value)
229226

@@ -304,20 +301,17 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
304301
else:
305302
_ei = particles.ei[:, self.igrid]
306303

307-
tau, ti = _search_time_index(self.U, time)
308-
position = self.grid.search(z, y, x, ei=_ei)
309-
_update_particles_ei(particles, position, self)
310-
_update_particle_states_position(particles, position)
304+
particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei)
311305

312306
if self._vector_interp_method is None:
313-
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
314-
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
307+
u = self.U._interp_method(particle_positions, grid_positions, self.U)
308+
v = self.V._interp_method(particle_positions, grid_positions, self.V)
315309
if "3D" in self.vector_type:
316-
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
310+
w = self.W._interp_method(particle_positions, grid_positions, self.W)
317311
else:
318312
w = 0.0
319313
else:
320-
(u, v, w) = self._vector_interp_method(self, ti, position, tau, time, z, y, x)
314+
(u, v, w) = self._vector_interp_method(particle_positions, grid_positions, self)
321315

322316
if applyConversion:
323317
u = self.U.units.to_target(u, z, y, x)
@@ -343,45 +337,54 @@ def __getitem__(self, key):
343337
return _deal_with_errors(error, key, vector_type=self.vector_type)
344338

345339

346-
def _update_particles_ei(particles, position, field):
340+
def _update_particles_ei(particles, grid_positions: dict, field: Field):
347341
"""Update the element index (ei) of the particles"""
348342
if particles is not None:
349343
if isinstance(field.grid, XGrid):
350344
particles.ei[:, field.igrid] = field.grid.ravel_index(
351345
{
352-
"X": position["X"][0],
353-
"Y": position["Y"][0],
354-
"Z": position["Z"][0],
346+
"X": grid_positions["X"]["index"],
347+
"Y": grid_positions["Y"]["index"],
348+
"Z": grid_positions["Z"]["index"],
355349
}
356350
)
357351
elif isinstance(field.grid, UxGrid):
358352
particles.ei[:, field.igrid] = field.grid.ravel_index(
359353
{
360-
"Z": position["Z"][0],
361-
"FACE": position["FACE"][0],
354+
"Z": grid_positions["Z"]["index"],
355+
"FACE": grid_positions["FACE"]["index"],
362356
}
363357
)
364358

365359

366-
def _update_particle_states_position(particles, position):
360+
def _update_particle_states_position(particles, grid_positions: dict):
367361
"""Update the particle states based on the position dictionary."""
368362
if particles: # TODO also support uxgrid search
369363
for dim in ["X", "Y"]:
370-
if dim in position:
364+
if dim in grid_positions:
371365
particles.state = np.maximum(
372-
np.where(position[dim][0] == -1, StatusCode.ErrorOutOfBounds, particles.state), particles.state
366+
np.where(grid_positions[dim]["index"] == -1, StatusCode.ErrorOutOfBounds, particles.state),
367+
particles.state,
373368
)
374369
particles.state = np.maximum(
375-
np.where(position[dim][0] == GRID_SEARCH_ERROR, StatusCode.ErrorGridSearching, particles.state),
370+
np.where(
371+
grid_positions[dim]["index"] == GRID_SEARCH_ERROR,
372+
StatusCode.ErrorGridSearching,
373+
particles.state,
374+
),
376375
particles.state,
377376
)
378-
if "Z" in position:
377+
if "Z" in grid_positions:
379378
particles.state = np.maximum(
380-
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particles.state),
379+
np.where(
380+
grid_positions["Z"]["index"] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particles.state
381+
),
381382
particles.state,
382383
)
383384
particles.state = np.maximum(
384-
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particles.state),
385+
np.where(
386+
grid_positions["Z"]["index"] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particles.state
387+
),
385388
particles.state,
386389
)
387390

@@ -469,3 +472,14 @@ def _assert_same_time_interval(fields: list[Field]) -> None:
469472
raise ValueError(
470473
f"Fields must have the same time domain. {fields[0].name}: {reference_time_interval}, {field.name}: {field.time_interval}"
471474
)
475+
476+
477+
def _get_positions(field: Field, time, z, y, x, particles, _ei) -> tuple[dict, dict]:
478+
"""Initialize and populate particle_positions and grid_positions dictionaries"""
479+
particle_positions = {"time": time, "z": z, "lat": y, "lon": x}
480+
grid_positions = {}
481+
grid_positions.update(_search_time_index(field, time))
482+
grid_positions.update(field.grid.search(z, y, x, ei=_ei))
483+
_update_particles_ei(particles, grid_positions, field)
484+
_update_particle_states_position(particles, grid_positions)
485+
return particle_positions, grid_positions

src/parcels/_core/index_search.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,19 @@ def _search_time_index(field: Field, time: datetime):
7575
if the sampled value is outside the time value range.
7676
"""
7777
if field.time_interval is None:
78-
return np.zeros(shape=time.shape, dtype=np.float32), np.zeros(shape=time.shape, dtype=np.int32)
78+
return {
79+
"T": {
80+
"index": np.zeros(shape=time.shape, dtype=np.int32),
81+
"bcoord": np.zeros(shape=time.shape, dtype=np.float32),
82+
}
83+
}
7984

8085
if not field.time_interval.is_all_time_in_interval(time):
8186
_raise_outside_time_interval_error(time, field=None)
8287

8388
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
8489
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
85-
return np.atleast_1d(tau), np.atleast_1d(ti)
90+
return {"T": {"index": np.atleast_1d(ti), "bcoord": np.atleast_1d(tau)}}
8691

8792

8893
def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):

src/parcels/_core/particleset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,12 @@ def _neighbors_by_coor(self, coor):
294294
def populate_indices(self):
295295
"""Pre-populate guesses of particle ei (element id) indices"""
296296
for i, grid in enumerate(self.fieldset.gridset):
297-
position = grid.search(self.z, self.lat, self.lon)
297+
grid_positions = grid.search(self.z, self.lat, self.lon)
298298
self._data["ei"][:, i] = grid.ravel_index(
299299
{
300-
"X": position["X"][0],
301-
"Y": position["Y"][0],
302-
"Z": position["Z"][0],
300+
"X": grid_positions["X"]["index"],
301+
"Y": grid_positions["Y"]["index"],
302+
"Z": grid_positions["Z"]["index"],
303303
}
304304
)
305305

src/parcels/_core/uxgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,4 @@ def search(self, z, y, x, ei=None, tol=1e-6):
118118
coords[zero_indices, :] = coords_q
119119
fi[zero_indices] = face_ids_q
120120

121-
return {"Z": (zi, zeta), "FACE": (fi, coords)}
121+
return {"Z": {"index": zi, "bcoord": zeta}, "FACE": {"index": fi, "bcoord": coords}}

src/parcels/_core/xgrid.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ def search(self, z, y, x, ei=None):
289289
if ds.lon.ndim == 1:
290290
yi, eta = _search_1d_array(ds.lat.values, y)
291291
xi, xsi = _search_1d_array(ds.lon.values, x)
292-
return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)}
292+
return {
293+
"Z": {"index": zi, "bcoord": zeta},
294+
"Y": {"index": yi, "bcoord": eta},
295+
"X": {"index": xi, "bcoord": xsi},
296+
}
293297

294298
yi, xi = None, None
295299
if ei is not None:
@@ -300,7 +304,11 @@ def search(self, z, y, x, ei=None):
300304
if ds.lon.ndim == 2:
301305
yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi)
302306

303-
return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)}
307+
return {
308+
"Z": {"index": zi, "bcoord": zeta},
309+
"Y": {"index": yi, "bcoord": eta},
310+
"X": {"index": xi, "bcoord": xsi},
311+
}
304312

305313
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
306314

0 commit comments

Comments
 (0)