8
8
from yt .data_objects .construction_data_containers import YTArbitraryGrid
9
9
from yt .data_objects .static_output import Dataset
10
10
11
+ _GridInfo = tuple [
12
+ npt .NDArray , npt .NDArray , unyt .unyt_array , unyt .unyt_array , Any , npt .NDArray
13
+ ]
14
+
11
15
12
16
def _validate_edge (edge : npt .ArrayLike , ds : Dataset ):
13
17
if not isinstance (edge , unyt .unyt_array ):
@@ -62,6 +66,9 @@ def __init__(
62
66
63
67
"""
64
68
69
+ if ds is None :
70
+ raise ValueError ("Please provide a dataset via the ds keyword argument" )
71
+
65
72
self .ds = ds
66
73
self .left_edge = _validate_edge (left_edge , ds )
67
74
self .right_edge = _validate_edge (right_edge , ds )
@@ -86,7 +93,7 @@ def __init__(
86
93
self ._left_cell_center = self .left_edge + self .dds / 2.0
87
94
self ._right_cell_center = self .right_edge - self .dds / 2.0
88
95
89
- def __repr__ (self ):
96
+ def __repr__ (self ) -> str :
90
97
nm = self .__class__ .__name__
91
98
shape = tuple (self .dims )
92
99
n_chunks = tuple (self .nchunks )
@@ -97,13 +104,13 @@ def __repr__(self):
97
104
)
98
105
return msg
99
106
100
- def _get_grid_by_ijk (self , ijk_grid ) :
107
+ def _get_grid_by_ijk (self , ijk_grid : npt . NDArray [ int ]) -> _GridInfo :
101
108
chunksizes = self .chunks
102
109
103
110
le_index = []
104
111
re_index = []
105
- le_val = self .ds .domain_left_edge .copy ()
106
- re_val = self .ds .domain_right_edge .copy ()
112
+ le_val : unyt . unyt_array = self .ds .domain_left_edge .copy ()
113
+ re_val : unyt . unyt_array = self .ds .domain_right_edge .copy ()
107
114
108
115
for idim in range (self ._ndim ):
109
116
chunk_i = ijk_grid [idim ]
@@ -122,29 +129,29 @@ def _get_grid_by_ijk(self, ijk_grid):
122
129
le_index [2 ] : re_index [2 ],
123
130
]
124
131
125
- le_index = np .array (le_index , dtype = int )
126
- re_index = np .array (re_index , dtype = int )
132
+ le_index_ = np .array (le_index , dtype = int )
133
+ re_index_ = np .array (re_index , dtype = int )
127
134
shape = chunksizes
128
135
129
- return le_index , re_index , le_val , re_val , slc , shape
136
+ return le_index_ , re_index_ , le_val , re_val , slc , shape
130
137
131
- def _get_grid (self , igrid : int ):
138
+ def _get_grid (self , igrid : int ) -> _GridInfo :
132
139
# get grid extent of a **single** grid
133
140
ijk_grid = np .unravel_index (igrid , self .nchunks )
134
141
return self ._get_grid_by_ijk (ijk_grid )
135
142
136
- def _coord_array (self , idim ) :
143
+ def _coord_array (self , idim : int ) -> npt . NDArray :
137
144
LE = self ._left_cell_center [idim ]
138
145
RE = self ._right_cell_center [idim ]
139
146
N = self .dims [idim ]
140
147
return np .mgrid [LE : RE : N * 1j ]
141
148
142
- def to_xarray (self , field , * , output_array = None ):
149
+ def to_xarray (
150
+ self , field : tuple [str , str ], * , output_array : npt .ArrayLike | None = None
151
+ ) -> Any :
143
152
144
153
import xarray as xr
145
154
146
- # ToDo: import from on_demand_imports
147
-
148
155
vals = self .to_array (field , output_array = output_array )
149
156
150
157
dims = self .ds .coordinates .axis_order
@@ -162,7 +169,13 @@ def to_xarray(self, field, *, output_array=None):
162
169
)
163
170
return xr_ds
164
171
165
- def single_grid_values (self , igrid , field , * , ops = None ):
172
+ def single_grid_values (
173
+ self ,
174
+ igrid : int ,
175
+ field : tuple [str , str ],
176
+ * ,
177
+ ops : list [Callable [[npt .NDArray ], npt .NDArray ]] | None = None ,
178
+ ) -> tuple [npt .NDArray , Any ]:
166
179
"""
167
180
Get the values for a field for a single grid chunk as in-memory array.
168
181
@@ -308,7 +321,9 @@ def __init__(
308
321
309
322
self .levels : list [YTTiledArbitraryGrid ] = levels
310
323
311
- def _validate_levels (self , levels ):
324
+ def _validate_levels (
325
+ self , levels : Sequence [int | tuple [int , int , int ] | npt .ArrayLike ]
326
+ ):
312
327
313
328
for ilev in range (1 , self .n_levels ):
314
329
res = np .prod (levels [ilev ])
@@ -321,7 +336,7 @@ def _validate_levels(self, levels):
321
336
)
322
337
raise ValueError (msg )
323
338
324
- def __repr__ (self ):
339
+ def __repr__ (self ) -> str :
325
340
return (
326
341
f"{ self .__class__ .__name__ } with { self .n_levels } levels and base resolution "
327
342
f"{ self .base_resolution } "
@@ -330,7 +345,11 @@ def __repr__(self):
330
345
def base_resolution (self ) -> tuple [int , int , int ]:
331
346
return tuple (self [0 ].dims )
332
347
333
- def to_arrays (self , field , output_arrays = None ):
348
+ def to_arrays (
349
+ self ,
350
+ field : tuple [str , str ],
351
+ output_arrays : list [npt .ArrayLike | None ] | None = None ,
352
+ ) -> list [npt .ArrayLike ]:
334
353
if output_arrays is None :
335
354
output_arrays = [None for _ in range (len (self .levels ))]
336
355
@@ -390,7 +409,14 @@ def _validate_factor(
390
409
return np .asarray (input_factor , dtype = int )
391
410
392
411
393
- def _get_filled_grid (le , re , shp , field , ds , field_parameters ):
412
+ def _get_filled_grid (
413
+ le : npt .NDArray ,
414
+ re : npt .NDArray ,
415
+ shp : npt .NDArray ,
416
+ field : tuple [str , str ],
417
+ ds : Dataset ,
418
+ field_parameters : Any ,
419
+ ) -> npt .NDArray :
394
420
grid = YTArbitraryGrid (le , re , shp , ds = ds , field_parameters = field_parameters )
395
421
vals = grid [field ]
396
422
return vals
0 commit comments