Skip to content

Commit f5c368d

Browse files
committed
further improve typing
1 parent 616c975 commit f5c368d

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,4 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
119119
warn_unreachable = true
120120
disallow_untyped_defs = false
121121
disallow_incomplete_defs = false
122-
disable_error_code = ["import-untyped", "import-not-found", "no-untyped-call"]
122+
disable_error_code = ["import-untyped", "import-not-found"]

yt_experiments/tiled_grid/tests/test_tiled_grid.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
import unyt
34
from numpy.testing import assert_equal
45
from yt.testing import fake_amr_ds, requires_module
@@ -78,6 +79,16 @@ def test_arbitrary_grid_oct():
7879
assert level_arrays[ilev].shape == expected_levels[ilev]
7980

8081

82+
def test_missing_ds():
83+
with pytest.raises(ValueError, match="Please provide a dataset"):
84+
_ = YTTiledArbitraryGrid(
85+
unyt.unyt_array([0, 0, 0], "m"),
86+
unyt.unyt_array([1, 1, 1], "m"),
87+
(20, 20, 20),
88+
5,
89+
)
90+
91+
8192
@requires_module("xarray")
8293
def test_arbitrary_grid_to_xarray():
8394
import xarray as xr

yt_experiments/tiled_grid/tiled_grid.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from yt.data_objects.construction_data_containers import YTArbitraryGrid
99
from yt.data_objects.static_output import Dataset
1010

11+
_GridInfo = tuple[
12+
npt.NDArray, npt.NDArray, unyt.unyt_array, unyt.unyt_array, Any, npt.NDArray
13+
]
14+
1115

1216
def _validate_edge(edge: npt.ArrayLike, ds: Dataset):
1317
if not isinstance(edge, unyt.unyt_array):
@@ -62,6 +66,9 @@ def __init__(
6266
6367
"""
6468

69+
if ds is None:
70+
raise ValueError("Please provide a dataset via the ds keyword argument")
71+
6572
self.ds = ds
6673
self.left_edge = _validate_edge(left_edge, ds)
6774
self.right_edge = _validate_edge(right_edge, ds)
@@ -86,7 +93,7 @@ def __init__(
8693
self._left_cell_center = self.left_edge + self.dds / 2.0
8794
self._right_cell_center = self.right_edge - self.dds / 2.0
8895

89-
def __repr__(self):
96+
def __repr__(self) -> str:
9097
nm = self.__class__.__name__
9198
shape = tuple(self.dims)
9299
n_chunks = tuple(self.nchunks)
@@ -97,13 +104,13 @@ def __repr__(self):
97104
)
98105
return msg
99106

100-
def _get_grid_by_ijk(self, ijk_grid):
107+
def _get_grid_by_ijk(self, ijk_grid: npt.NDArray[int]) -> _GridInfo:
101108
chunksizes = self.chunks
102109

103110
le_index = []
104111
re_index = []
105-
le_val = self.ds.domain_left_edge.copy()
106-
re_val = self.ds.domain_right_edge.copy()
112+
le_val: unyt.unyt_array = self.ds.domain_left_edge.copy()
113+
re_val: unyt.unyt_array = self.ds.domain_right_edge.copy()
107114

108115
for idim in range(self._ndim):
109116
chunk_i = ijk_grid[idim]
@@ -122,29 +129,29 @@ def _get_grid_by_ijk(self, ijk_grid):
122129
le_index[2] : re_index[2],
123130
]
124131

125-
le_index = np.array(le_index, dtype=int)
126-
re_index = np.array(re_index, dtype=int)
132+
le_index_ = np.array(le_index, dtype=int)
133+
re_index_ = np.array(re_index, dtype=int)
127134
shape = chunksizes
128135

129-
return le_index, re_index, le_val, re_val, slc, shape
136+
return le_index_, re_index_, le_val, re_val, slc, shape
130137

131-
def _get_grid(self, igrid: int):
138+
def _get_grid(self, igrid: int) -> _GridInfo:
132139
# get grid extent of a **single** grid
133140
ijk_grid = np.unravel_index(igrid, self.nchunks)
134141
return self._get_grid_by_ijk(ijk_grid)
135142

136-
def _coord_array(self, idim):
143+
def _coord_array(self, idim: int) -> npt.NDArray:
137144
LE = self._left_cell_center[idim]
138145
RE = self._right_cell_center[idim]
139146
N = self.dims[idim]
140147
return np.mgrid[LE : RE : N * 1j]
141148

142-
def to_xarray(self, field, *, output_array=None):
149+
def to_xarray(
150+
self, field: tuple[str, str], *, output_array: npt.ArrayLike | None = None
151+
) -> Any:
143152

144153
import xarray as xr
145154

146-
# ToDo: import from on_demand_imports
147-
148155
vals = self.to_array(field, output_array=output_array)
149156

150157
dims = self.ds.coordinates.axis_order
@@ -162,7 +169,13 @@ def to_xarray(self, field, *, output_array=None):
162169
)
163170
return xr_ds
164171

165-
def single_grid_values(self, igrid, field, *, ops=None):
172+
def single_grid_values(
173+
self,
174+
igrid: int,
175+
field: tuple[str, str],
176+
*,
177+
ops: list[Callable[[npt.NDArray], npt.NDArray]] | None = None,
178+
) -> tuple[npt.NDArray, Any]:
166179
"""
167180
Get the values for a field for a single grid chunk as in-memory array.
168181
@@ -308,7 +321,9 @@ def __init__(
308321

309322
self.levels: list[YTTiledArbitraryGrid] = levels
310323

311-
def _validate_levels(self, levels):
324+
def _validate_levels(
325+
self, levels: Sequence[int | tuple[int, int, int] | npt.ArrayLike]
326+
):
312327

313328
for ilev in range(1, self.n_levels):
314329
res = np.prod(levels[ilev])
@@ -321,7 +336,7 @@ def _validate_levels(self, levels):
321336
)
322337
raise ValueError(msg)
323338

324-
def __repr__(self):
339+
def __repr__(self) -> str:
325340
return (
326341
f"{self.__class__.__name__} with {self.n_levels} levels and base resolution "
327342
f"{self.base_resolution}"
@@ -330,7 +345,11 @@ def __repr__(self):
330345
def base_resolution(self) -> tuple[int, int, int]:
331346
return tuple(self[0].dims)
332347

333-
def to_arrays(self, field, output_arrays=None):
348+
def to_arrays(
349+
self,
350+
field: tuple[str, str],
351+
output_arrays: list[npt.ArrayLike | None] | None = None,
352+
) -> list[npt.ArrayLike]:
334353
if output_arrays is None:
335354
output_arrays = [None for _ in range(len(self.levels))]
336355

@@ -390,7 +409,14 @@ def _validate_factor(
390409
return np.asarray(input_factor, dtype=int)
391410

392411

393-
def _get_filled_grid(le, re, shp, field, ds, field_parameters):
412+
def _get_filled_grid(
413+
le: npt.NDArray,
414+
re: npt.NDArray,
415+
shp: npt.NDArray,
416+
field: tuple[str, str],
417+
ds: Dataset,
418+
field_parameters: Any,
419+
) -> npt.NDArray:
394420
grid = YTArbitraryGrid(le, re, shp, ds=ds, field_parameters=field_parameters)
395421
vals = grid[field]
396422
return vals

0 commit comments

Comments
 (0)