Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
else:
_ei = particles.ei[:, self.igrid]

tau, ti = _search_time_index(self, time)
position = self.grid.search(z, y, x, ei=_ei)
_update_particles_ei(particles, position, self)
_update_particle_states_position(particles, position)
particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei)

value = self._interp_method(self, ti, position, tau, time, z, y, x)
value = self._interp_method(particle_positions, grid_positions, self)

_update_particle_states_interp_value(particles, value)

Expand Down Expand Up @@ -304,20 +301,17 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
else:
_ei = particles.ei[:, self.igrid]

tau, ti = _search_time_index(self.U, time)
position = self.grid.search(z, y, x, ei=_ei)
_update_particles_ei(particles, position, self)
_update_particle_states_position(particles, position)
particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei)

if self._vector_interp_method is None:
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
u = self.U._interp_method(particle_positions, grid_positions, self.U)
v = self.V._interp_method(particle_positions, grid_positions, self.V)
if "3D" in self.vector_type:
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
w = self.W._interp_method(particle_positions, grid_positions, self.W)
else:
w = 0.0
else:
(u, v, w) = self._vector_interp_method(self, ti, position, tau, time, z, y, x)
(u, v, w) = self._vector_interp_method(particle_positions, grid_positions, self)

if applyConversion:
u = self.U.units.to_target(u, z, y, x)
Expand All @@ -343,45 +337,54 @@ def __getitem__(self, key):
return _deal_with_errors(error, key, vector_type=self.vector_type)


def _update_particles_ei(particles, grid_positions: dict, field: Field):
    """Update the element index (ei) of the particles.

    Writes the raveled element index for ``field``'s grid into
    ``particles.ei[:, field.igrid]``. Does nothing when ``particles`` is None.

    Parameters
    ----------
    particles : particle data or None
        Object exposing an ``ei`` array indexed as ``ei[:, igrid]``.
    grid_positions : dict
        Mapping of dimension name ("X", "Y", "Z", "FACE") to a dict with an
        "index" entry, as returned by ``grid.search``.
    field : Field
        Field whose grid type determines which dimensions are raveled.
    """
    if particles is not None:
        if isinstance(field.grid, XGrid):
            # Structured grid: element id is raveled from (X, Y, Z) indices.
            particles.ei[:, field.igrid] = field.grid.ravel_index(
                {
                    "X": grid_positions["X"]["index"],
                    "Y": grid_positions["Y"]["index"],
                    "Z": grid_positions["Z"]["index"],
                }
            )
        elif isinstance(field.grid, UxGrid):
            # Unstructured grid: element id is raveled from (Z, FACE) indices.
            particles.ei[:, field.igrid] = field.grid.ravel_index(
                {
                    "Z": grid_positions["Z"]["index"],
                    "FACE": grid_positions["FACE"]["index"],
                }
            )


def _update_particle_states_position(particles, grid_positions: dict):
    """Update the particle states based on the grid-search position dictionary.

    Maps search sentinel indices to error states: a horizontal index of -1 or
    GRID_SEARCH_ERROR sets ErrorOutOfBounds / ErrorGridSearching, and a
    vertical index of RIGHT_OUT_OF_BOUNDS / LEFT_OUT_OF_BOUNDS sets
    ErrorOutOfBounds / ErrorThroughSurface. States are combined with
    ``np.maximum`` so an already-higher error state is never downgraded.
    Does nothing when ``particles`` is falsy.
    """
    if particles:  # TODO also support uxgrid search
        for dim in ["X", "Y"]:
            if dim in grid_positions:
                # -1 marks a horizontal search that found no containing cell.
                particles.state = np.maximum(
                    np.where(grid_positions[dim]["index"] == -1, StatusCode.ErrorOutOfBounds, particles.state),
                    particles.state,
                )
                particles.state = np.maximum(
                    np.where(
                        grid_positions[dim]["index"] == GRID_SEARCH_ERROR,
                        StatusCode.ErrorGridSearching,
                        particles.state,
                    ),
                    particles.state,
                )
        if "Z" in grid_positions:
            # Vertical sentinels distinguish the two exit directions
            # (presumably below-bottom vs. above-surface — confirm with the
            # index_search module).
            particles.state = np.maximum(
                np.where(
                    grid_positions["Z"]["index"] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particles.state
                ),
                particles.state,
            )
            particles.state = np.maximum(
                np.where(
                    grid_positions["Z"]["index"] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particles.state
                ),
                particles.state,
            )

Expand Down Expand Up @@ -469,3 +472,14 @@ def _assert_same_time_interval(fields: list[Field]) -> None:
raise ValueError(
f"Fields must have the same time domain. {fields[0].name}: {reference_time_interval}, {field.name}: {field.time_interval}"
)


def _get_positions(field: Field, time, z, y, x, particles, _ei) -> tuple[dict, dict]:
    """Build the particle- and grid-position dictionaries for one field lookup.

    Returns a ``(particle_positions, grid_positions)`` tuple: the first maps
    the query coordinates by name ("time", "z", "lat", "lon"); the second
    merges the temporal index from ``_search_time_index`` with the spatial
    indices from ``field.grid.search``. As a side effect, the particles'
    element indices and error states are refreshed from the search result.
    """
    particle_positions = {"time": time, "z": z, "lat": y, "lon": x}
    grid_positions = {
        **_search_time_index(field, time),
        **field.grid.search(z, y, x, ei=_ei),
    }
    _update_particles_ei(particles, grid_positions, field)
    _update_particle_states_position(particles, grid_positions)
    return particle_positions, grid_positions
9 changes: 7 additions & 2 deletions src/parcels/_core/index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,19 @@ def _search_time_index(field: Field, time: datetime):
if the sampled value is outside the time value range.
"""
if field.time_interval is None:
return np.zeros(shape=time.shape, dtype=np.float32), np.zeros(shape=time.shape, dtype=np.int32)
return {
"T": {
"index": np.zeros(shape=time.shape, dtype=np.int32),
"bcoord": np.zeros(shape=time.shape, dtype=np.float32),
}
}

if not field.time_interval.is_all_time_in_interval(time):
_raise_outside_time_interval_error(time, field=None)

ti = np.searchsorted(field.data.time.data, time, side="right") - 1
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
return np.atleast_1d(tau), np.atleast_1d(ti)
return {"T": {"index": np.atleast_1d(ti), "bcoord": np.atleast_1d(tau)}}


def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):
Expand Down
8 changes: 4 additions & 4 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,12 @@ def _neighbors_by_coor(self, coor):
def populate_indices(self):
    """Pre-populate guesses of particle ei (element id) indices.

    For every grid in the fieldset, searches the particles' current
    (z, lat, lon) coordinates and stores the raveled element index in the
    particle data, so subsequent field evaluations start from a good
    initial guess instead of a cold search.
    """
    for i, grid in enumerate(self.fieldset.gridset):
        grid_positions = grid.search(self.z, self.lat, self.lon)
        self._data["ei"][:, i] = grid.ravel_index(
            {
                "X": grid_positions["X"]["index"],
                "Y": grid_positions["Y"]["index"],
                "Z": grid_positions["Z"]["index"],
            }
        )

Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_core/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@ def search(self, z, y, x, ei=None, tol=1e-6):
coords[zero_indices, :] = coords_q
fi[zero_indices] = face_ids_q

return {"Z": (zi, zeta), "FACE": (fi, coords)}
return {"Z": {"index": zi, "bcoord": zeta}, "FACE": {"index": fi, "bcoord": coords}}
12 changes: 10 additions & 2 deletions src/parcels/_core/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,11 @@ def search(self, z, y, x, ei=None):
if ds.lon.ndim == 1:
yi, eta = _search_1d_array(ds.lat.values, y)
xi, xsi = _search_1d_array(ds.lon.values, x)
return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)}
return {
"Z": {"index": zi, "bcoord": zeta},
"Y": {"index": yi, "bcoord": eta},
"X": {"index": xi, "bcoord": xsi},
}

yi, xi = None, None
if ei is not None:
Expand All @@ -300,7 +304,11 @@ def search(self, z, y, x, ei=None):
if ds.lon.ndim == 2:
yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi)

return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)}
return {
"Z": {"index": zi, "bcoord": zeta},
"Y": {"index": yi, "bcoord": eta},
"X": {"index": xi, "bcoord": xsi},
}

raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")

Expand Down
Loading