Skip to content

Commit ccb3b52

Browse files
Merge pull request #2348 from Parcels-code/updates_from_virtualship_dev
Fixes/tests from virtualship development at EDITO hackathon
2 parents 3a5a2da + a0fb6fa commit ccb3b52

File tree

5 files changed

+72
-6
lines changed

5 files changed

+72
-6
lines changed

docs/community/v4-migration.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,12 @@ Version 4 of Parcels is unreleased at the moment. The information in this migrat
3434

3535
- Particlefiles should be created by `ParticleFile(...)` instead of `pset.ParticleFile(...)`
3636
- The `name` argument in `ParticleFile` has been replaced by `store` and can now be a string, a Path or a zarr store.
37+
38+
## Field
39+
40+
- `Field.eval()` returns an array of floats instead of a single float (related to the vectorization)
41+
- `Field.eval()` does not throw OutOfBounds or other errors
42+
43+
## GridSet
44+
45+
- `GridSet` is now a list, so change `fieldset.gridset.grids[0]` to `fieldset.gridset[0]`.

src/parcels/_core/index_search.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _search_1d_array(
4444
# TODO v4: We probably rework this to deal with 0D arrays before this point (as we already know field dimensionality)
4545
if len(arr) < 2:
4646
return np.zeros(shape=x.shape, dtype=np.int32), np.zeros_like(x)
47-
index = np.searchsorted(arr, x, side="right") - 1
47+
index = np.clip(np.searchsorted(arr, x, side="right") - 1, 0, len(arr) - 2)
4848
# Use broadcasting to avoid repeated array access
4949
arr_index = arr[index]
5050
arr_next = arr[np.clip(index + 1, 1, len(arr) - 1)] # Ensure we don't go out of bounds
@@ -57,7 +57,7 @@ def _search_1d_array(
5757
# bcoord = (x - arr[index]) / dx
5858

5959
index = np.where(x < arr[0], LEFT_OUT_OF_BOUNDS, index)
60-
index = np.where(x >= arr[-1], RIGHT_OUT_OF_BOUNDS, index)
60+
index = np.where(x > arr[-1], RIGHT_OUT_OF_BOUNDS, index)
6161

6262
return np.atleast_1d(index), np.atleast_1d(bcoord)
6363

@@ -85,8 +85,8 @@ def _search_time_index(field: Field, time: datetime):
8585
if not field.time_interval.is_all_time_in_interval(time):
8686
_raise_outside_time_interval_error(time, field=None)
8787

88-
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
89-
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
88+
ti, tau = _search_1d_array(field.data.time.data, time)
89+
9090
return {"T": {"index": np.atleast_1d(ti), "bcoord": np.atleast_1d(tau)}}
9191

9292

src/parcels/_core/xgrid.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,12 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
104104
if "lat" in ds:
105105
ds.set_coords("lat")
106106

107-
if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
107+
if len(set(grid.axes) & {"X", "Y"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
108108
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)
109109

110+
if "Z" in grid.axes:
111+
assert_valid_depth(ds["depth"])
112+
110113
assert_valid_mesh(mesh)
111114
self._ds = ds
112115

@@ -482,6 +485,14 @@ def assert_valid_lat_lon(da_lat, da_lon, axes: _XGCM_AXES):
482485
)
483486

484487

488+
def assert_valid_depth(da_depth):
489+
if not np.all(np.diff(da_depth.values) > 0):
490+
raise ValueError(
491+
f"Depth DataArray {da_depth.name!r} with dims {da_depth.dims} must be strictly increasing. "
492+
f'HINT: you may be able to use ds.reindex to flip depth - e.g., ds = ds.reindex({da_depth.name}=ds["{da_depth.name}"][::-1])'
493+
)
494+
495+
485496
def _convert_center_pos_to_fpoint(
486497
*, index: int, bcoord: float, xgcm_position: _XGCM_AXIS_POSITION, f_points_xgcm_position: _XGCM_AXIS_POSITION
487498
) -> tuple[int, float]:

tests/test_particleset_execute.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,45 @@ def test_particleset_endtime_type(fieldset, endtime, expectation):
150150
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)
151151

152152

153+
def test_particleset_run_to_endtime(fieldset):
154+
starttime = fieldset.time_interval.left
155+
endtime = fieldset.time_interval.right
156+
157+
def SampleU(particles, fieldset): # pragma: no cover
158+
_ = fieldset.U[particles]
159+
160+
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], time=[starttime])
161+
pset.execute(SampleU, endtime=endtime, dt=np.timedelta64(1, "D"))
162+
assert pset[0].time == endtime
163+
164+
165+
def test_particleset_interpolate_on_domainedge(zonal_flow_fieldset):
166+
fieldset = zonal_flow_fieldset
167+
168+
MyParticle = Particle.add_variable(Variable("var"))
169+
170+
def SampleU(particles, fieldset): # pragma: no cover
171+
particles.var = fieldset.U[particles]
172+
173+
print(fieldset.U.grid.lon)
174+
pset = ParticleSet(fieldset, pclass=MyParticle, lon=fieldset.U.grid.lon[-1], lat=fieldset.U.grid.lat[-1])
175+
pset.execute(SampleU, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(1, "D"))
176+
np.testing.assert_equal(pset[0].var, 1)
177+
178+
179+
def test_particleset_interpolate_outside_domainedge(zonal_flow_fieldset):
180+
fieldset = zonal_flow_fieldset
181+
182+
def SampleU(particles, fieldset): # pragma: no cover
183+
particles.dlon = fieldset.U[particles]
184+
185+
dlat = 1e-3
186+
pset = ParticleSet(fieldset, lon=fieldset.U.grid.lon[-1], lat=fieldset.U.grid.lat[-1] + dlat)
187+
188+
with pytest.raises(FieldOutOfBoundError):
189+
pset.execute(SampleU, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(1, "D"))
190+
191+
153192
@pytest.mark.parametrize(
154193
"dt", [np.timedelta64(1, "s"), np.timedelta64(1, "ms"), np.timedelta64(10, "ms"), np.timedelta64(1, "ns")]
155194
)
@@ -329,7 +368,6 @@ def MoveRight(particles, fieldset): # pragma: no cover
329368

330369
def MoveLeft(particles, fieldset): # pragma: no cover
331370
inds = np.where(particles.state == StatusCode.ErrorOutOfBounds)
332-
print(inds, particles.state)
333371
particles[inds].dlon -= 1.0
334372
particles[inds].state = StatusCode.Success
335373

tests/test_xgrid.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ def test_invalid_lon_lat():
126126
XGrid.from_dataset(ds)
127127

128128

129+
def test_invalid_depth():
130+
ds = datasets["ds_2d_left"].copy()
131+
ds = ds.reindex({"ZG": ds.ZG[::-1]})
132+
133+
with pytest.raises(ValueError, match="Depth DataArray .* must be strictly increasing*"):
134+
XGrid.from_dataset(ds)
135+
136+
129137
@pytest.mark.parametrize(
130138
"ds",
131139
[

0 commit comments

Comments
 (0)