2525from ..tools .numba import get_common_numba_dtype , jit , make_array_constructor
2626from ..tools .plotting import PlotReference , plot_on_axes
2727from ..tools .spectral import CorrelationType , make_correlated_noise
28- from ..tools .typing import ArrayLike , NumberOrArray
28+ from ..tools .typing import ArrayLike , NumberOrArray , NumericArray
2929from .base import FieldBase , RankError
3030
3131if TYPE_CHECKING :
@@ -182,7 +182,7 @@ def random_uniform(
182182 # create complex random numbers for the field
183183 real_part = rng .uniform (np .real (vmin ), np .real (vmax ), size = shape )
184184 imag_part = rng .uniform (np .imag (vmin ), np .imag (vmax ), size = shape )
185- data : np . ndarray = real_part + 1j * imag_part
185+ data : NumericArray = real_part + 1j * imag_part
186186 else :
187187 # create real random numbers for the field
188188 data = rng .uniform (vmin , vmax , size = shape )
@@ -293,7 +293,7 @@ def random_normal(
293293 else :
294294 tensor_shape = (grid .dim ,) * cls .rank
295295
296- def make_random_field () -> np . ndarray :
296+ def make_random_field () -> NumericArray :
297297 """Helper function that creates a single tensor field."""
298298 out = np .empty (tensor_shape + grid .shape )
299299 print (out .shape , tensor_shape , grid )
@@ -313,7 +313,7 @@ def make_random_field() -> np.ndarray:
313313 # create complex random numbers for the field
314314 real_part = np .real (mean ) + np .real (std ) * scale * make_random_field ()
315315 imag_part = np .imag (mean ) + np .imag (std ) * scale * make_random_field ()
316- data : np . ndarray = real_part + 1j * imag_part
316+ data : NumericArray = real_part + 1j * imag_part
317317 else :
318318 # create real random numbers for the field
319319 data = mean + std * scale * make_random_field ()
@@ -481,7 +481,7 @@ def get_class_by_rank(cls, rank: int) -> type[DataFieldBase]:
481481 def from_state (
482482 cls : type [TDataField ],
483483 attributes : dict [str , Any ],
484- data : np . ndarray | None = None ,
484+ data : NumericArray | None = None ,
485485 ) -> TDataField :
486486 """Create a field from given state.
487487
@@ -593,7 +593,7 @@ def make_interpolator(
593593 * ,
594594 fill : Number | None = None ,
595595 with_ghost_cells : bool = False ,
596- ) -> Callable [[np . ndarray , np . ndarray ], NumberOrArray ]:
596+ ) -> Callable [[NumericArray , NumericArray ], NumberOrArray ]:
597597 r"""Returns a function that can be used to interpolate values.
598598
599599 Args:
@@ -617,7 +617,7 @@ def make_interpolator(
617617 # convert `fill` to dtype of data
618618 if fill is not None :
619619 if self .rank == 0 :
620- fill = self .data .dtype .type (fill )
620+ fill = self .data .dtype .type (fill ) # type: ignore
621621 else :
622622 fill = np .broadcast_to (fill , self .data_shape ).astype (self .data .dtype )
623623
@@ -636,8 +636,8 @@ def make_interpolator(
636636
637637 @jit
638638 def interpolator (
639- point : np . ndarray , data : np . ndarray | None = None
640- ) -> np . ndarray :
639+ point : NumericArray , data : NumericArray | None = None
640+ ) -> NumericArray :
641641 """Return the interpolated value at the position `point`
642642
643643 Args:
@@ -680,11 +680,11 @@ def interpolator(
680680 @fill_in_docstring
681681 def interpolate (
682682 self ,
683- point : np . ndarray ,
683+ point : NumericArray ,
684684 * ,
685685 bc : BoundariesData | None = None ,
686686 fill : Number | None = None ,
687- ) -> np . ndarray :
687+ ) -> NumericArray :
688688 r"""Interpolate the field to points between support points.
689689
690690 Args:
@@ -748,7 +748,7 @@ def interpolate_to_grid(
748748 """
749749 raise NotImplementedError (f"Cannot interpolate { self .__class__ .__name__ } " )
750750
751- def insert (self , point : np . ndarray , amount : ArrayLike ) -> None :
751+ def insert (self , point : NumericArray , amount : ArrayLike ) -> None :
752752 """Adds an (integrated) value to the field at an interpolated position.
753753
754754 Args:
@@ -965,7 +965,7 @@ def apply_operator(
965965
966966 def make_dot_operator (
967967 self , backend : Literal ["numpy" , "numba" ] = "numba" , * , conjugate : bool = True
968- ) -> Callable [[np . ndarray , np . ndarray , np . ndarray | None ], np . ndarray ]:
968+ ) -> Callable [[NumericArray , NumericArray , NumericArray | None ], NumericArray ]:
969969 """Return operator calculating the dot product between two fields.
970970
971971 This supports both products between two vectors as well as products
@@ -986,13 +986,13 @@ def make_dot_operator(
986986 num_axes = self .grid .num_axes
987987
988988 @register_jitable
989- def maybe_conj (arr : np . ndarray ) -> np . ndarray :
989+ def maybe_conj (arr : NumericArray ) -> NumericArray :
990990 """Helper function implementing optional conjugation."""
991991 return arr .conjugate () if conjugate else arr
992992
993993 def dot (
994- a : np . ndarray , b : np . ndarray , out : np . ndarray | None = None
995- ) -> np . ndarray :
994+ a : NumericArray , b : NumericArray , out : NumericArray | None = None
995+ ) -> NumericArray :
996996 """Numpy implementation to calculate dot product between two fields."""
997997 rank_a = a .ndim - num_axes
998998 rank_b = b .ndim - num_axes
@@ -1040,8 +1040,8 @@ def get_rank(arr: nb.types.Type | nb.types.Optional) -> int:
10401040
10411041 @overload (dot , inline = "always" )
10421042 def dot_ol (
1043- a : np . ndarray , b : np . ndarray , out : np . ndarray | None = None
1044- ) -> np . ndarray :
1043+ a : NumericArray , b : NumericArray , out : NumericArray | None = None
1044+ ) -> NumericArray :
10451045 """Numba implementation to calculate dot product between two fields."""
10461046 # get (and check) rank of the input arrays
10471047 rank_a = get_rank (a )
@@ -1050,15 +1050,19 @@ def dot_ol(
10501050 if rank_a == 1 and rank_b == 1 : # result is scalar field
10511051
10521052 @register_jitable
1053- def calc (a : np .ndarray , b : np .ndarray , out : np .ndarray ) -> None :
1053+ def calc (
1054+ a : NumericArray , b : NumericArray , out : NumericArray
1055+ ) -> None :
10541056 out [:] = a [0 ] * maybe_conj (b [0 ])
10551057 for j in range (1 , dim ):
10561058 out [:] += a [j ] * maybe_conj (b [j ])
10571059
10581060 elif rank_a == 2 and rank_b == 1 : # result is vector field
10591061
10601062 @register_jitable
1061- def calc (a : np .ndarray , b : np .ndarray , out : np .ndarray ) -> None :
1063+ def calc (
1064+ a : NumericArray , b : NumericArray , out : NumericArray
1065+ ) -> None :
10621066 for i in range (dim ):
10631067 out [i ] = a [i , 0 ] * maybe_conj (b [0 ])
10641068 for j in range (1 , dim ):
@@ -1067,7 +1071,9 @@ def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
10671071 elif rank_a == 1 and rank_b == 2 : # result is vector field
10681072
10691073 @register_jitable
1070- def calc (a : np .ndarray , b : np .ndarray , out : np .ndarray ) -> None :
1074+ def calc (
1075+ a : NumericArray , b : NumericArray , out : NumericArray
1076+ ) -> None :
10711077 for i in range (dim ):
10721078 out [i ] = a [0 ] * maybe_conj (b [0 , i ])
10731079 for j in range (1 , dim ):
@@ -1076,7 +1082,9 @@ def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
10761082 elif rank_a == 2 and rank_b == 2 : # result is tensor-2 field
10771083
10781084 @register_jitable
1079- def calc (a : np .ndarray , b : np .ndarray , out : np .ndarray ) -> None :
1085+ def calc (
1086+ a : NumericArray , b : NumericArray , out : NumericArray
1087+ ) -> None :
10801088 for i in range (dim ):
10811089 for j in range (dim ):
10821090 out [i , j ] = a [i , 0 ] * maybe_conj (b [0 , j ])
@@ -1095,8 +1103,10 @@ def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
10951103 dtype = get_common_numba_dtype (a , b )
10961104
10971105 def dot_impl (
1098- a : np .ndarray , b : np .ndarray , out : np .ndarray | None = None
1099- ) -> np .ndarray :
1106+ a : NumericArray ,
1107+ b : NumericArray ,
1108+ out : NumericArray | None = None ,
1109+ ) -> NumericArray :
11001110 """Helper function allocating output array."""
11011111 assert a .shape == a_shape
11021112 assert b .shape == b_shape
@@ -1108,8 +1118,10 @@ def dot_impl(
11081118 # function is called with `out` argument -> reuse `out` array
11091119
11101120 def dot_impl (
1111- a : np .ndarray , b : np .ndarray , out : np .ndarray | None = None
1112- ) -> np .ndarray :
1121+ a : NumericArray ,
1122+ b : NumericArray ,
1123+ out : NumericArray | None = None ,
1124+ ) -> NumericArray :
11131125 """Helper function without allocating output array."""
11141126 assert a .shape == a_shape
11151127 assert b .shape == b_shape
@@ -1121,8 +1133,8 @@ def dot_impl(
11211133
11221134 @jit
11231135 def dot_compiled (
1124- a : np . ndarray , b : np . ndarray , out : np . ndarray | None = None
1125- ) -> np . ndarray :
1136+ a : NumericArray , b : NumericArray , out : NumericArray | None = None
1137+ ) -> NumericArray :
11261138 """Numba implementation to calculate dot product between two fields."""
11271139 return dot (a , b , out )
11281140
0 commit comments