Skip to content

Commit d38ba97

Browse files
Improve typechecking (#2541)
Co-authored-by: Claude <[email protected]>
1 parent 831c43d commit d38ba97

File tree

15 files changed

+109
-88
lines changed

15 files changed

+109
-88
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,6 @@ known-first-party = ["parcels"]
159159

160160
[tool.ty.src]
161161
include = ["./src/"]
162+
exclude = [
163+
"./src/parcels/interpolators/", # ignore for now
164+
]

src/parcels/_compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
KMeans: Any | None = None
99

1010
try:
11-
from mpi4py import MPI # type: ignore[no-redef]
11+
from mpi4py import MPI # type: ignore[import-untyped,no-redef]
1212
except ModuleNotFoundError:
1313
pass
1414

1515
# KMeans is used in MPI. sklearn not installed by default
1616
try:
17-
from sklearn.cluster import KMeans # type: ignore[no-redef]
17+
from sklearn.cluster import KMeans
1818
except ModuleNotFoundError:
1919
pass
2020

src/parcels/_core/field.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from collections.abc import Callable
4+
from collections.abc import Callable, Sequence
55
from datetime import datetime
66

77
import numpy as np
@@ -428,7 +428,7 @@ def _assert_valid_uxdataarray(data: ux.UxDataArray):
428428
)
429429

430430

431-
def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux.Grid | XGrid):
431+
def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: UxGrid | XGrid):
432432
if isinstance(data, ux.UxDataArray):
433433
if not isinstance(grid, UxGrid):
434434
raise ValueError(
@@ -448,7 +448,7 @@ def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | No
448448
return TimeInterval(data.time.values[0], data.time.values[-1])
449449

450450

451-
def _assert_same_time_interval(fields: list[Field]) -> None:
451+
def _assert_same_time_interval(fields: Sequence[Field]) -> None:
452452
if len(fields) == 0:
453453
return
454454

src/parcels/_core/fieldset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
223223
)
224224

225225
for varname in set(ds.data_vars) - set(fields.keys()):
226-
fields[varname] = Field(varname, ds[varname], grid, _select_uxinterpolator(ds[varname]))
226+
fields[varname] = Field(str(varname), ds[varname], grid, _select_uxinterpolator(ds[varname]))
227227

228228
return cls(list(fields.values()))
229229

@@ -319,7 +319,7 @@ def from_sgrid_conventions(
319319
)
320320

321321
for varname in set(ds.data_vars) - set(fields.keys()) - skip_vars:
322-
fields[varname] = Field(varname, ds[varname], grid, XLinear)
322+
fields[varname] = Field(str(varname), ds[varname], grid, XLinear)
323323

324324
return cls(list(fields.values()))
325325

@@ -353,7 +353,7 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str:
353353
return msg
354354

355355

356-
def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str:
356+
def _format_calendar_error_message(field: Field | VectorField, reference_datetime: TimeLike) -> str:
357357
return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"
358358

359359

src/parcels/_core/index_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from parcels._core.field import Field
13-
from parcels.xgrid import XGrid
13+
from parcels._core.xgrid import XGrid
1414

1515

1616
GRID_SEARCH_ERROR = -3
@@ -19,7 +19,7 @@
1919

2020

2121
def _search_1d_array(
22-
arr: np.array,
22+
arr: np.ndarray,
2323
x: float,
2424
) -> tuple[int, int]:
2525
"""

src/parcels/_core/particle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import operator
4-
from typing import Literal
4+
from typing import Any, Literal
55

66
import numpy as np
77

@@ -37,7 +37,7 @@ class Variable:
3737
def __init__(
3838
self,
3939
name,
40-
dtype: np.dtype = np.float32,
40+
dtype: np.dtype[Any] | type[np.generic] = np.float32,
4141
initial=0,
4242
to_write: bool | Literal["once"] = True,
4343
attrs: dict | None = None,
@@ -122,7 +122,7 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
122122
raise ValueError(f"Variable name already exists: {var.name}")
123123

124124

125-
def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
125+
def get_default_particle(spatial_dtype: type[np.float32] | type[np.float64]) -> ParticleClass:
126126
if spatial_dtype not in [np.float32, np.float64]:
127127
raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")
128128

@@ -177,7 +177,7 @@ def create_particle_data(
177177
nparticles: int,
178178
ngrids: int,
179179
time_interval: TimeInterval,
180-
initial: dict[str, np.array] | None = None,
180+
initial: dict[str, np.ndarray] | None = None,
181181
):
182182
if initial is None:
183183
initial = {}

src/parcels/_core/particlefile.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,10 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
196196
if self.create_new_zarrfile:
197197
if self.chunks is None:
198198
self._chunks = (nparticles, 1)
199-
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index]
200-
arrsize = (self._maxids, self.chunks[1]) # type: ignore[index]
199+
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
200+
arrsize = (self._maxids, self.chunks[1])
201201
else:
202-
arrsize = (len(ids), self.chunks[1]) # type: ignore[index]
202+
arrsize = (len(ids), self.chunks[1])
203203
ds = xr.Dataset(
204204
attrs=self.metadata,
205205
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
@@ -221,7 +221,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
221221
data[ids, 0] = particle_data[var.name][indices_to_write]
222222
dims = ["trajectory", "obs"]
223223
ds[var.name] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name])
224-
ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index]
224+
ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks
225225
ds.to_zarr(store, mode="w")
226226
self._create_new_zarrfile = False
227227
else:
@@ -234,7 +234,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
234234
if len(once_ids) > 0:
235235
Z[var.name].vindex[ids_once] = particle_data[var.name][indices_to_write_once]
236236
else:
237-
if max(obs) >= Z[var.name].shape[1]: # type: ignore[type-var]
237+
if max(obs) >= Z[var.name].shape[1]:
238238
self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=1)
239239
Z[var.name].vindex[ids, obs] = particle_data[var.name][indices_to_write]
240240

src/parcels/_core/utils/interpolation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from parcels._typing import Mesh
77

8-
__all__ = [] # type: ignore
8+
__all__ = []
99

1010

1111
def phi1D_lin(xsi: float) -> list[float]:

src/parcels/_core/utils/sgrid.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import re
1616
from collections.abc import Hashable, Iterable
1717
from dataclasses import dataclass
18-
from typing import Any, Literal, Protocol, Self, overload
18+
from typing import Any, Literal, Protocol, Self, cast, overload
1919

2020
import xarray as xr
2121

@@ -149,13 +149,13 @@ def __eq__(self, other: Any) -> bool:
149149
return self.to_attrs() == other.to_attrs()
150150

151151
@classmethod
152-
def from_attrs(cls, attrs):
152+
def from_attrs(cls, attrs): # type: ignore[override]
153153
try:
154154
return cls(
155155
cf_role=attrs["cf_role"],
156156
topology_dimension=attrs["topology_dimension"],
157-
node_dimensions=load_mappings(attrs["node_dimensions"]),
158-
face_dimensions=load_mappings(attrs["face_dimensions"]),
157+
node_dimensions=cast(tuple[Dim, Dim], load_mappings(attrs["node_dimensions"])),
158+
face_dimensions=cast(tuple[DimDimPadding, DimDimPadding], load_mappings(attrs["face_dimensions"])),
159159
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
160160
vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")),
161161
)
@@ -176,7 +176,7 @@ def to_attrs(self) -> dict[str, str | int]:
176176
return d
177177

178178
def rename(self, names_dict: dict[str, str]) -> Self:
179-
return _metadata_rename(self, names_dict)
179+
return cast(Self, _metadata_rename(self, names_dict))
180180

181181
def get_value_by_id(self, id: str) -> str:
182182
"""In the SGRID specification for 2D grids, different parts of the spec are identified by different "ID"s.
@@ -262,13 +262,15 @@ def __eq__(self, other: Any) -> bool:
262262
return self.to_attrs() == other.to_attrs()
263263

264264
@classmethod
265-
def from_attrs(cls, attrs):
265+
def from_attrs(cls, attrs): # type: ignore[override]
266266
try:
267267
return cls(
268268
cf_role=attrs["cf_role"],
269269
topology_dimension=attrs["topology_dimension"],
270-
node_dimensions=load_mappings(attrs["node_dimensions"]),
271-
volume_dimensions=load_mappings(attrs["volume_dimensions"]),
270+
node_dimensions=cast(tuple[Dim, Dim, Dim], load_mappings(attrs["node_dimensions"])),
271+
volume_dimensions=cast(
272+
tuple[DimDimPadding, DimDimPadding, DimDimPadding], load_mappings(attrs["volume_dimensions"])
273+
),
272274
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
273275
)
274276
except Exception as e:
@@ -286,7 +288,7 @@ def to_attrs(self) -> dict[str, str | int]:
286288
return d
287289

288290
def rename(self, dims_dict: dict[str, str]) -> Self:
289-
return _metadata_rename(self, dims_dict)
291+
return cast(Self, _metadata_rename(self, dims_dict))
290292

291293
def get_value_by_id(self, id: str) -> str:
292294
"""In the SGRID specification for 3D grids, different parts of the spec are identified by different "ID"s.

src/parcels/_core/uxgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
1818
for interpolation on unstructured grids.
1919
"""
2020

21-
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid:
21+
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> None:
2222
"""
2323
Initializes the UxGrid with a uxarray grid and vertical coordinate array.
2424

0 commit comments

Comments (0)