2020import numpy as np
2121from numba .extending import is_jitted , register_jitable
2222from numba .extending import overload as nb_overload
23- from numpy .typing import ArrayLike
23+ from numpy .typing import ArrayLike , NDArray
2424
2525from ..tools .cache import cached_method , cached_property
2626from ..tools .docstrings import fill_in_docstring
2727from ..tools .misc import Number , hybridmethod
2828from ..tools .numba import jit
2929from ..tools .typing import (
3030 CellVolume ,
31+ FloatingArray ,
3132 FloatOrArray ,
3233 NumberOrArray ,
3334 NumericArray ,
@@ -79,7 +80,7 @@ def _check_shape(shape: int | Sequence[int]) -> tuple[int, ...]:
7980
8081def discretize_interval (
8182 x_min : float , x_max : float , num : int
82- ) -> tuple [NumericArray , float ]:
83+ ) -> tuple [FloatingArray , float ]:
8384 r"""Construct a list of equidistantly placed intervals.
8485
8586 The discretization is defined as
@@ -143,8 +144,8 @@ class GridBase(metaclass=ABCMeta):
143144 _axes_symmetric : tuple [int , ...] = ()
144145 _axes_described : tuple [int , ...]
145146 _axes_bounds : tuple [tuple [float , float ], ...]
146- _axes_coords : tuple [NumericArray , ...]
147- _discretization : NumericArray
147+ _axes_coords : tuple [FloatingArray , ...]
148+ _discretization : FloatingArray
148149 _periodic : list [bool ]
149150 _shape : tuple [int , ...]
150151
@@ -230,7 +231,7 @@ def axes_bounds(self) -> tuple[tuple[float, float], ...]:
230231 return self ._axes_bounds
231232
232233 @property
233- def axes_coords (self ) -> tuple [NumericArray , ...]:
234+ def axes_coords (self ) -> tuple [FloatingArray , ...]:
234235 """tuple: coordinates of the cells for each axis"""
235236 return self ._axes_coords
236237
@@ -299,7 +300,7 @@ def _get_boundary_index(self, index: str | tuple[int, bool]) -> tuple[int, bool]
299300 return axis , upper
300301
301302 @property
302- def discretization (self ) -> NumericArray :
303+ def discretization (self ) -> FloatingArray :
303304 """:class:`numpy.array`: the linear size of a cell along each axis."""
304305 return self ._discretization
305306
@@ -514,17 +515,17 @@ def numba_type(self) -> str:
514515 return "f8[" + ", " .join ([":" ] * self .num_axes ) + "]"
515516
516517 @cached_property ()
517- def coordinate_arrays (self ) -> tuple [NumericArray , ...]:
518+ def coordinate_arrays (self ) -> tuple [FloatingArray , ...]:
518519 """tuple: for each axes: coordinate values for all cells"""
519520 return tuple (np .meshgrid (* self .axes_coords , indexing = "ij" ))
520521
521522 @cached_property ()
522- def cell_coords (self ) -> NumericArray :
523+ def cell_coords (self ) -> FloatingArray :
523524 """:class:`~numpy.ndarray`: coordinate values for all axes of each cell."""
524525 return np .moveaxis (self .coordinate_arrays , 0 , - 1 ) # type: ignore
525526
526527 @cached_property ()
527- def cell_volumes (self ) -> NumericArray :
528+ def cell_volumes (self ) -> FloatingArray :
528529 """:class:`~numpy.ndarray`: volume of each cell."""
529530 if self .cell_volume_data is None :
530531 # use the self.c to calculate cell volumes
@@ -548,13 +549,13 @@ def uniform_cell_volumes(self) -> bool:
548549
549550 def _difference_vector (
550551 self ,
551- p1 : NumericArray ,
552- p2 : NumericArray ,
552+ p1 : FloatingArray ,
553+ p2 : FloatingArray ,
553554 * ,
554555 coords : CoordsType ,
555556 periodic : Sequence [bool ],
556557 axes_bounds : tuple [tuple [float , float ], ...] | None ,
557- ) -> NumericArray :
558+ ) -> FloatingArray :
558559 """Return Cartesian vector(s) pointing from p1 to p2.
559560
560561 In case of periodic boundary conditions, the shortest vector is returned.
@@ -590,8 +591,8 @@ def _difference_vector(
590591 return diff # type: ignore
591592
592593 def difference_vector (
593- self , p1 : NumericArray , p2 : NumericArray , * , coords : CoordsType = "grid"
594- ) -> NumericArray :
594+ self , p1 : FloatingArray , p2 : FloatingArray , * , coords : CoordsType = "grid"
595+ ) -> FloatingArray :
595596 """Return Cartesian vector(s) pointing from p1 to p2.
596597
597598 In case of periodic boundary conditions, the shortest vector is returned.
@@ -614,7 +615,7 @@ def difference_vector(
614615 )
615616
616617 def distance (
617- self , p1 : NumericArray , p2 : NumericArray , * , coords : CoordsType = "grid"
618+ self , p1 : FloatingArray , p2 : FloatingArray , * , coords : CoordsType = "grid"
618619 ) -> float :
619620 """Calculate the distance between two points given in real coordinates.
620621
@@ -647,7 +648,7 @@ def _iter_boundaries(self) -> Iterator[tuple[int, bool]]:
647648
648649 def _boundary_coordinates (
649650 self , axis : int , upper : bool , * , offset : float = 0
650- ) -> NumericArray :
651+ ) -> FloatingArray :
651652 """Get coordinates of points on the boundary.
652653
653654 Args:
@@ -686,7 +687,7 @@ def volume(self) -> float:
686687 # this property should be overwritten when the volume can be calculated directly
687688 return self .cell_volumes .sum () # type: ignore
688689
689- def point_to_cartesian (self , points : NumericArray ) -> NumericArray :
690+ def point_to_cartesian (self , points : FloatingArray ) -> FloatingArray :
690691 """Convert coordinates of a point in grid coordinates to Cartesian coordinates.
691692
692693 Args:
@@ -698,7 +699,7 @@ def point_to_cartesian(self, points: NumericArray) -> NumericArray:
698699 """
699700 return self .c .pos_to_cart (self ._coords_full (points ))
700701
701- def point_from_cartesian (self , points : NumericArray ) -> NumericArray :
702+ def point_from_cartesian (self , points : FloatingArray ) -> FloatingArray :
702703 """Convert points given in Cartesian coordinates to grid coordinates.
703704
704705 Args:
@@ -711,7 +712,7 @@ def point_from_cartesian(self, points: NumericArray) -> NumericArray:
711712 return self ._coords_symmetric (self .c .pos_from_cart (points ))
712713
713714 def _vector_to_cartesian (
714- self , points : ArrayLike , components : ArrayLike
715+ self , points : FloatingArray , components : ArrayLike
715716 ) -> NumericArray :
716717 """Convert the vectors at given points into a Cartesian basis.
717718
@@ -745,8 +746,8 @@ def _vector_to_cartesian(
745746 return np .einsum ("j...,ji...->i..." , components , rot_mat ) # type: ignore
746747
747748 def normalize_point (
748- self , point : NumericArray , * , reflect : bool = False
749- ) -> NumericArray :
749+ self , point : FloatingArray , * , reflect : bool = False
750+ ) -> FloatingArray :
750751 """Normalize grid coordinates by applying periodic boundary conditions.
751752
752753 Here, points are assumed to be specified by the physical values along the
@@ -808,7 +809,7 @@ def normalize_point(
808809
809810 return point
810811
811- def _coords_symmetric (self , points : NumericArray ) -> NumericArray :
812+ def _coords_symmetric (self , points : FloatingArray ) -> FloatingArray :
812813 """Return only non-symmetric point coordinates.
813814
814815 Args:
@@ -824,8 +825,8 @@ def _coords_symmetric(self, points: NumericArray) -> NumericArray:
824825 return points [..., self ._axes_described ]
825826
826827 def _coords_full (
827- self , points : NumericArray , * , value : Literal ["min" , "max" ] | float = 0.0
828- ) -> NumericArray :
828+ self , points : FloatingArray , * , value : Literal ["min" , "max" ] | float = 0.0
829+ ) -> FloatingArray :
829830 """Specify point coordinates along symmetric axes on grids.
830831
831832 Args:
@@ -861,8 +862,8 @@ def _coords_full(
861862 return res # type: ignore
862863
863864 def transform (
864- self , coordinates : NumericArray , source : CoordsType , target : CoordsType
865- ) -> NumericArray :
865+ self , coordinates : FloatingArray , source : CoordsType , target : CoordsType
866+ ) -> FloatingArray :
866867 """Converts coordinates from one coordinate system to another.
867868
868869 Supported coordinate systems include the following:
@@ -956,10 +957,10 @@ def transform(
956957
957958 def contains_point (
958959 self ,
959- points : NumericArray ,
960+ points : FloatingArray ,
960961 * ,
961962 coords : Literal ["cartesian" , "cell" , "grid" ] = "cartesian" ,
962- ) -> NumericArray :
963+ ) -> NDArray [ np . bool ] :
963964 """Check whether the point is contained in the grid.
964965
965966 Args:
@@ -973,11 +974,11 @@ def contains_point(
973974 the grid
974975 """
975976 cell_coords = self .transform (points , source = coords , target = "cell" )
976- return np .all ((cell_coords >= 0 ) & (cell_coords <= self .shape ), axis = - 1 ) # type: ignore
977+ return np .all ((cell_coords >= 0 ) & (cell_coords <= self .shape ), axis = - 1 )
977978
978979 def iter_mirror_points (
979- self , point : NumericArray , with_self : bool = False , only_periodic : bool = True
980- ) -> Generator :
980+ self , point : FloatingArray , with_self : bool = False , only_periodic : bool = True
981+ ) -> Iterator [ FloatingArray ] :
981982 """Generates all mirror points corresponding to `point`
982983
983984 Args:
@@ -1105,7 +1106,7 @@ def get_random_point(
11051106 boundary_distance : float = 0 ,
11061107 coords : CoordsType = "cartesian" ,
11071108 rng : np .random .Generator | None = None ,
1108- ) -> NumericArray :
1109+ ) -> FloatingArray :
11091110 """Return a random point within the grid.
11101111
11111112 Args:
@@ -1555,7 +1556,7 @@ def integrate(
15551556 @cached_method ()
15561557 def make_normalize_point_compiled (
15571558 self , reflect : bool = True
1558- ) -> Callable [[NumericArray ], None ]:
1559+ ) -> Callable [[FloatingArray ], None ]:
15591560 """Return a compiled function that normalizes a point.
15601561
15611562 Here, the point is assumed to be specified by the physical values along
@@ -1582,7 +1583,7 @@ def make_normalize_point_compiled(
15821583 size = bounds [:, 1 ] - bounds [:, 0 ]
15831584
15841585 @jit
1585- def normalize_point (point : NumericArray ) -> None :
1586+ def normalize_point (point : FloatingArray ) -> None :
15861587 """Helper function normalizing a single point."""
15871588 assert point .ndim == 1 # only support single points
15881589 for i in range (num_axes ):
@@ -1730,7 +1731,7 @@ def _make_interpolator_compiled(
17301731 fill : Number | None = None ,
17311732 with_ghost_cells : bool = False ,
17321733 cell_coords : bool = False ,
1733- ) -> Callable [[NumericArray , NumericArray ], NumericArray ]:
1734+ ) -> Callable [[NumericArray , FloatingArray ], NumericArray ]:
17341735 """Return a compiled function for linear interpolation on the grid.
17351736
17361737 Args:
@@ -1760,7 +1761,7 @@ def _make_interpolator_compiled(
17601761
17611762 @jit
17621763 def interpolate_single (
1763- data : NumericArray , point : NumericArray
1764+ data : NumericArray , point : FloatingArray
17641765 ) -> NumberOrArray :
17651766 """Obtain interpolated value of data at a point.
17661767
@@ -1792,7 +1793,7 @@ def interpolate_single(
17921793
17931794 @jit
17941795 def interpolate_single (
1795- data : NumericArray , point : NumericArray
1796+ data : NumericArray , point : FloatingArray
17961797 ) -> NumberOrArray :
17971798 """Obtain interpolated value of data at a point.
17981799
@@ -1832,7 +1833,7 @@ def interpolate_single(
18321833
18331834 @jit
18341835 def interpolate_single (
1835- data : NumericArray , point : NumericArray
1836+ data : NumericArray , point : FloatingArray
18361837 ) -> NumberOrArray :
18371838 """Obtain interpolated value of data at a point.
18381839
@@ -1878,7 +1879,7 @@ def interpolate_single(
18781879
18791880 def make_inserter_compiled (
18801881 self , * , with_ghost_cells : bool = False
1881- ) -> Callable [[NumericArray , NumericArray , NumberOrArray ], None ]:
1882+ ) -> Callable [[NumericArray , FloatingArray , NumberOrArray ], None ]:
18821883 """Return a compiled function to insert values at interpolated positions.
18831884
18841885 Args:
@@ -1903,7 +1904,7 @@ def make_inserter_compiled(
19031904
19041905 @jit
19051906 def insert (
1906- data : NumericArray , point : NumericArray , amount : NumberOrArray
1907+ data : NumericArray , point : FloatingArray , amount : NumberOrArray
19071908 ) -> None :
19081909 """Add an amount to a field at an interpolated position.
19091910
@@ -1938,7 +1939,7 @@ def insert(
19381939
19391940 @jit
19401941 def insert (
1941- data : NumericArray , point : NumericArray , amount : NumberOrArray
1942+ data : NumericArray , point : FloatingArray , amount : NumberOrArray
19421943 ) -> None :
19431944 """Add an amount to a field at an interpolated position.
19441945
@@ -1985,7 +1986,7 @@ def insert(
19851986
19861987 @jit
19871988 def insert (
1988- data : NumericArray , point : NumericArray , amount : NumberOrArray
1989+ data : NumericArray , point : FloatingArray , amount : NumberOrArray
19891990 ) -> None :
19901991 """Add an amount to a field at an interpolated position.
19911992
0 commit comments