Skip to content

Commit 5acda53

Browse files
Merge branch 'v4-dev' into reuse-cache-file
2 parents 376acda + d4a0a52 commit 5acda53

File tree

13 files changed

+167
-73
lines changed

13 files changed

+167
-73
lines changed

src/parcels/_core/field.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
AllParcelsErrorCodes,
1919
StatusCode,
2020
)
21+
from parcels._core.utils.string import _assert_str_and_python_varname
2122
from parcels._core.utils.time import TimeInterval
2223
from parcels._core.uxgrid import UxGrid
2324
from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx
@@ -101,8 +102,9 @@ def __init__(
101102
raise ValueError(
102103
f"Expected `data` to be a uxarray.UxDataArray or xarray.DataArray object, got {type(data)}."
103104
)
104-
if not isinstance(name, str):
105-
raise ValueError(f"Expected `name` to be a string, got {type(name)}.")
105+
106+
_assert_str_and_python_varname(name)
107+
106108
if not isinstance(grid, (UxGrid, XGrid)):
107109
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")
108110

@@ -246,6 +248,8 @@ class VectorField:
246248
def __init__(
247249
self, name: str, U: Field, V: Field, W: Field | None = None, vector_interp_method: Callable | None = None
248250
):
251+
_assert_str_and_python_varname(name)
252+
249253
self.name = name
250254
self.U = U
251255
self.V = V

src/parcels/_core/fieldset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from parcels._core.converters import Geographic, GeographicPolar
1313
from parcels._core.field import Field, VectorField
14+
from parcels._core.utils.string import _assert_str_and_python_varname
1415
from parcels._core.utils.time import get_datetime_type_calendar
1516
from parcels._core.utils.time import is_compatible as datetime_is_compatible
1617
from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid
@@ -163,6 +164,8 @@ def add_constant(self, name, value):
163164
`Diffusion <../examples/tutorial_diffusion.ipynb>`__
164165
`Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__
165166
"""
167+
_assert_str_and_python_varname(name)
168+
166169
if name in self.constants:
167170
raise ValueError(f"FieldSet already has a constant with name '{name}'")
168171
if not isinstance(value, (float, np.floating, int, np.integer)):

src/parcels/_core/particle.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import enum
44
import operator
5-
from keyword import iskeyword
65
from typing import Literal
76

87
import numpy as np
98

109
from parcels._compat import _attrgetter_helper
1110
from parcels._core.statuscodes import StatusCode
11+
from parcels._core.utils.string import _assert_str_and_python_varname
1212
from parcels._core.utils.time import TimeInterval
1313
from parcels._reprs import _format_list_items_multiline
1414

@@ -45,9 +45,7 @@ def __init__(
4545
to_write: bool | Literal["once"] = True,
4646
attrs: dict | None = None,
4747
):
48-
if not isinstance(name, str):
49-
raise TypeError(f"Variable name must be a string. Got {name=!r}")
50-
_assert_valid_python_varname(name)
48+
_assert_str_and_python_varname(name)
5149

5250
try:
5351
dtype = np.dtype(dtype)
@@ -153,12 +151,6 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
153151
raise ValueError(f"Variable name already exists: {var.name}")
154152

155153

156-
def _assert_valid_python_varname(name):
157-
if name.isidentifier() and not iskeyword(name):
158-
return
159-
raise ValueError(f"Particle variable has to be a valid Python variable name. Got {name=!r}")
160-
161-
162154
def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
163155
if spatial_dtype not in [np.float32, np.float64]:
164156
raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")

src/parcels/_core/utils/string.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from keyword import iskeyword, kwlist
2+
3+
4+
def _assert_str_and_python_varname(name):
5+
if not isinstance(name, str):
6+
raise TypeError(f"Expected a string for variable name, got {type(name).__name__} instead.")
7+
8+
msg = f"Received invalid Python variable name {name!r}: "
9+
10+
if not name.isidentifier():
11+
msg += "not a valid identifier. HINT: avoid using spaces, special characters, and starting with a number."
12+
raise ValueError(msg)
13+
14+
if iskeyword(name):
15+
msg += f"it is a reserved keyword. HINT: avoid using the following names: {', '.join(kwlist)}"
16+
raise ValueError(msg)

src/parcels/_datasets/structured/generic.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def _rotated_curvilinear_grid():
2323
{
2424
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
2525
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
26-
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
27-
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
28-
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
29-
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
26+
"U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
27+
"V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
28+
"U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
29+
"V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
3030
},
3131
coords={
3232
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
@@ -92,16 +92,19 @@ def _unrolled_cone_curvilinear_grid():
9292
new_lon_lat.append((lon + pivot[0], lat + pivot[1]))
9393

9494
new_lon, new_lat = zip(*new_lon_lat, strict=True)
95-
LON, LAT = np.array(new_lon).reshape(LON.shape), np.array(new_lat).reshape(LAT.shape)
95+
LON, LAT = (
96+
np.array(new_lon).reshape(LON.shape),
97+
np.array(new_lat).reshape(LAT.shape),
98+
)
9699

97100
return xr.Dataset(
98101
{
99102
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
100103
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
101-
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
102-
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
103-
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
104-
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
104+
"U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
105+
"V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
106+
"U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
107+
"V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
105108
},
106109
coords={
107110
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
@@ -140,10 +143,10 @@ def _unrolled_cone_curvilinear_grid():
140143
{
141144
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
142145
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
143-
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
144-
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
145-
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
146-
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
146+
"U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
147+
"V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
148+
"U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
149+
"V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
147150
},
148151
coords={
149152
"XG": (
@@ -182,10 +185,10 @@ def _unrolled_cone_curvilinear_grid():
182185
{
183186
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
184187
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
185-
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
186-
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
187-
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
188-
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
188+
"U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
189+
"V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
190+
"U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
191+
"V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
189192
},
190193
coords={
191194
"XG": (

tests/test_field.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,27 @@
1515
def test_field_init_param_types():
1616
data = datasets_structured["ds_2d_left"]
1717
grid = XGrid.from_dataset(data)
18-
with pytest.raises(ValueError, match="Expected `name` to be a string"):
18+
19+
with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
1920
Field(name=123, data=data["data_g"], grid=grid)
2021

21-
with pytest.raises(ValueError, match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray"):
22+
for name in ["a b", "123"]:
23+
with pytest.raises(
24+
ValueError,
25+
match=r"Received invalid Python variable name.*: not a valid identifier. HINT: avoid using spaces, special characters, and starting with a number.",
26+
):
27+
Field(name=name, data=data["data_g"], grid=grid)
28+
29+
with pytest.raises(
30+
ValueError,
31+
match=r"Received invalid Python variable name.*: it is a reserved keyword. HINT: avoid using the following names:.*",
32+
):
33+
Field(name="while", data=data["data_g"], grid=grid)
34+
35+
with pytest.raises(
36+
ValueError,
37+
match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray",
38+
):
2239
Field(name="test", data=123, grid=grid)
2340

2441
with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"):
@@ -28,7 +45,11 @@ def test_field_init_param_types():
2845
@pytest.mark.parametrize(
2946
"data,grid",
3047
[
31-
pytest.param(ux.UxDataArray(), XGrid.from_dataset(datasets_structured["ds_2d_left"]), id="uxdata-grid"),
48+
pytest.param(
49+
ux.UxDataArray(),
50+
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
51+
id="uxdata-grid",
52+
),
3253
pytest.param(
3354
xr.DataArray(),
3455
UxGrid(
@@ -76,7 +97,11 @@ def test_field_init_fail_on_float_time_dim():
7697
(users are expected to use timedelta64 or datetime).
7798
"""
7899
ds = datasets_structured["ds_2d_left"].copy()
79-
ds["time"] = (ds["time"].dims, np.arange(0, T_structured, dtype="float64"), ds["time"].attrs)
100+
ds["time"] = (
101+
ds["time"].dims,
102+
np.arange(0, T_structured, dtype="float64"),
103+
ds["time"].attrs,
104+
)
80105

81106
data = ds["data_g"]
82107
grid = XGrid.from_dataset(ds)
@@ -122,7 +147,12 @@ def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, inval
122147

123148
# Test invalid interpolator with wrong signature
124149
with pytest.raises(ValueError, match=".*incorrect name.*"):
125-
Field(name="test", data=ds["data_g"], grid=grid, interp_method=invalid_interpolator_wrong_signature)
150+
Field(
151+
name="test",
152+
data=ds["data_g"],
153+
grid=grid,
154+
interp_method=invalid_interpolator_wrong_signature,
155+
)
126156

127157

128158
def test_vectorfield_invalid_interpolator():
@@ -138,7 +168,12 @@ def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, apply
138168

139169
# Test invalid interpolator with wrong signature
140170
with pytest.raises(ValueError, match=".*incorrect name.*"):
141-
VectorField(name="UV", U=U, V=V, vector_interp_method=invalid_interpolator_wrong_signature)
171+
VectorField(
172+
name="UV",
173+
U=U,
174+
V=V,
175+
vector_interp_method=invalid_interpolator_wrong_signature,
176+
)
142177

143178

144179
def test_field_unstructured_z_linear():
@@ -161,18 +196,34 @@ def test_field_unstructured_z_linear():
161196
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)
162197

163198
# Test above first cell center - for piecewise constant, should return the depth of the first cell center
164-
assert np.isclose(P.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False), 55.555557)
199+
assert np.isclose(
200+
P.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False),
201+
55.555557,
202+
)
165203
# Test below first cell center, but in the first layer - for piecewise constant, should return the depth of the first cell center
166-
assert np.isclose(P.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False), 55.555557)
204+
assert np.isclose(
205+
P.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False),
206+
55.555557,
207+
)
167208
# Test bottom layer - for piecewise constant, should return the depth of the of the bottom layer cell center
168209
assert np.isclose(
169-
P.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False), 944.44445801
210+
P.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False),
211+
944.44445801,
170212
)
171213

172214
W = Field(name="W", data=ds.W, grid=grid, interp_method=UXPiecewiseLinearNode)
173-
assert np.isclose(W.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False), 10.0)
174-
assert np.isclose(W.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False), 65.0)
175-
assert np.isclose(W.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False), 900.0)
215+
assert np.isclose(
216+
W.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False),
217+
10.0,
218+
)
219+
assert np.isclose(
220+
W.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False),
221+
65.0,
222+
)
223+
assert np.isclose(
224+
W.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False),
225+
900.0,
226+
)
176227

177228

178229
def test_field_constant_in_time():
@@ -185,7 +236,13 @@ def test_field_constant_in_time():
185236
# Assert that the field can be evaluated at any time, and returns the same value
186237
time = np.datetime64("2000-01-01T00:00:00")
187238
P1 = P.eval(time=time, z=[10.0], y=[30.0], x=[30.0], applyConversion=False)
188-
P2 = P.eval(time=time + np.timedelta64(1, "D"), z=[10.0], y=[30.0], x=[30.0], applyConversion=False)
239+
P2 = P.eval(
240+
time=time + np.timedelta64(1, "D"),
241+
z=[10.0],
242+
y=[30.0],
243+
x=[30.0],
244+
applyConversion=False,
245+
)
189246
assert np.isclose(P1, P2)
190247

191248

tests/test_fieldset.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
def fieldset() -> FieldSet:
2121
"""Fixture to create a FieldSet object for testing."""
2222
grid = XGrid.from_dataset(ds, mesh="flat")
23-
U = Field("U", ds["U (A grid)"], grid)
24-
V = Field("V", ds["V (A grid)"], grid)
23+
U = Field("U", ds["U_A_grid"], grid)
24+
V = Field("V", ds["V_A_grid"], grid)
2525
UV = VectorField("UV", U, V)
2626

2727
return FieldSet(
@@ -39,6 +39,17 @@ def test_fieldset_add_constant(fieldset):
3939
assert fieldset.test_constant == 1.0
4040

4141

42+
def test_fieldset_add_constant_int_name(fieldset):
43+
with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
44+
fieldset.add_constant(123, 1.0)
45+
46+
47+
@pytest.mark.parametrize("name", ["a b", "123", "while"])
48+
def test_fieldset_add_constant_invalid_name(fieldset, name):
49+
with pytest.raises(ValueError, match=r"Received invalid Python variable name.*"):
50+
fieldset.add_constant(name, 1.0)
51+
52+
4253
def test_fieldset_add_constant_field(fieldset):
4354
fieldset.add_constant_field("test_constant_field", 1.0)
4455

@@ -54,7 +65,7 @@ def test_fieldset_add_constant_field(fieldset):
5465

5566
def test_fieldset_add_field(fieldset):
5667
grid = XGrid.from_dataset(ds, mesh="flat")
57-
field = Field("test_field", ds["U (A grid)"], grid)
68+
field = Field("test_field", ds["U_A_grid"], grid)
5869
fieldset.add_field(field)
5970
assert fieldset.test_field == field
6071

@@ -67,7 +78,7 @@ def test_fieldset_add_field_wrong_type(fieldset):
6778

6879
def test_fieldset_add_field_already_exists(fieldset):
6980
grid = XGrid.from_dataset(ds, mesh="flat")
70-
field = Field("test_field", ds["U (A grid)"], grid)
81+
field = Field("test_field", ds["U_A_grid"], grid)
7182
fieldset.add_field(field, "test_field")
7283
with pytest.raises(ValueError, match="FieldSet already has a Field with name 'test_field'"):
7384
fieldset.add_field(field, "test_field")
@@ -104,12 +115,12 @@ def test_fieldset_gridset_multiple_grids(): ...
104115

105116
def test_fieldset_time_interval():
106117
grid1 = XGrid.from_dataset(ds, mesh="flat")
107-
field1 = Field("field1", ds["U (A grid)"], grid1)
118+
field1 = Field("field1", ds["U_A_grid"], grid1)
108119

109120
ds2 = ds.copy()
110121
ds2["time"] = (ds2["time"].dims, ds2["time"].data + np.timedelta64(timedelta(days=1)), ds2["time"].attrs)
111122
grid2 = XGrid.from_dataset(ds2, mesh="flat")
112-
field2 = Field("field2", ds2["U (A grid)"], grid2)
123+
field2 = Field("field2", ds2["U_A_grid"], grid2)
113124

114125
fieldset = FieldSet([field1, field2])
115126
fieldset.add_constant_field("constant_field", 1.0)
@@ -135,8 +146,8 @@ def test_fieldset_init_incompatible_calendars():
135146
)
136147

137148
grid = XGrid.from_dataset(ds1, mesh="flat")
138-
U = Field("U", ds1["U (A grid)"], grid)
139-
V = Field("V", ds1["V (A grid)"], grid)
149+
U = Field("U", ds1["U_A_grid"], grid)
150+
V = Field("V", ds1["V_A_grid"], grid)
140151
UV = VectorField("UV", U, V)
141152

142153
ds2 = ds.copy()

tests/test_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
def fieldset() -> FieldSet:
1919
ds = datasets_structured["ds_2d_left"]
2020
grid = XGrid.from_dataset(ds, mesh="flat")
21-
U = Field("U", ds["U (A grid)"], grid)
22-
V = Field("V", ds["V (A grid)"], grid)
21+
U = Field("U", ds["U_A_grid"], grid)
22+
V = Field("V", ds["V_A_grid"], grid)
2323
return FieldSet([U, V])
2424

2525

0 commit comments

Comments
 (0)