Skip to content

Commit 901f1c4

Browse files
Merge pull request #724 from zwicker-group/typing
Defined more precise array type
2 parents 1ee6371 + 7f4629f commit 901f1c4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+629
-540
lines changed

pde/fields/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from ..grids.base import GridBase
1919
from ..tools.plotting import napari_add_layers, napari_viewer
20-
from ..tools.typing import NumberOrArray
20+
from ..tools.typing import NumberOrArray, NumericArray
2121

2222
if TYPE_CHECKING:
2323
from .scalar import ScalarField
@@ -38,15 +38,15 @@ class FieldBase(metaclass=ABCMeta):
3838

3939
_subclasses: dict[str, type[FieldBase]] = {} # all classes inheriting from this
4040
_grid: GridBase # the grid on which the field is defined
41-
__data_full: np.ndarray # the data on the grid including ghost points
42-
_data_valid: np.ndarray # the valid data without ghost points
41+
__data_full: NumericArray # the data on the grid including ghost points
42+
_data_valid: NumericArray # the valid data without ghost points
4343
_label: str | None # name of the field
4444
_logger: logging.Logger # logger instance to output information
4545

4646
def __init__(
4747
self,
4848
grid: GridBase,
49-
data: np.ndarray,
49+
data: NumericArray,
5050
*,
5151
label: str | None = None,
5252
):
@@ -82,7 +82,7 @@ def __getstate__(self) -> dict[str, Any]:
8282
return state
8383

8484
@property
85-
def data(self) -> np.ndarray:
85+
def data(self) -> NumericArray:
8686
""":class:`~numpy.ndarray`: discretized data at the support points."""
8787
return self._data_valid
8888

@@ -109,7 +109,7 @@ def _idx_valid(self) -> tuple[slice, ...]:
109109
return idx_comp + self.grid._idx_valid
110110

111111
@property
112-
def _data_full(self) -> np.ndarray:
112+
def _data_full(self) -> NumericArray:
113113
""":class:`~numpy.ndarray`: the full data including ghost cells."""
114114
return self.__data_full
115115

@@ -146,17 +146,17 @@ def _data_full(self, value: NumberOrArray) -> None:
146146
self._data_valid = self.__data_full[self._idx_valid]
147147

148148
@property
149-
def _data_flat(self) -> np.ndarray:
149+
def _data_flat(self) -> NumericArray:
150150
""":class:`~numpy.ndarray`: flat version of discretized data with ghost
151151
cells."""
152152
# flatten the first dimension of the internal data by creating a view and then
153153
# setting the new shape. This disallows accidental copying of the data
154154
data_flat = self._data_full.view()
155155
data_flat.shape = (-1, *self.grid._shape_full)
156-
return data_flat
156+
return data_flat # type: ignore
157157

158158
@_data_flat.setter
159-
def _data_flat(self, value: np.ndarray) -> None:
159+
def _data_flat(self, value: NumericArray) -> None:
160160
"""Set the full data including ghost cells from a flattened array."""
161161
# simply set the data -> this might need to be overwritten
162162
self._data_full = value
@@ -187,7 +187,7 @@ def label(self, value: str | None = None):
187187

188188
@classmethod
189189
def from_state(
190-
cls, attributes: dict[str, Any], data: np.ndarray | None = None
190+
cls, attributes: dict[str, Any], data: NumericArray | None = None
191191
) -> FieldBase:
192192
"""Create a field from given state.
193193

pde/fields/collection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..tools.docstrings import fill_in_docstring
2727
from ..tools.misc import Number, number_array
2828
from ..tools.plotting import PlotReference, plot_on_axes, plot_on_figure
29-
from ..tools.typing import NumberOrArray
29+
from ..tools.typing import NumberOrArray, NumericArray
3030
from .base import FieldBase
3131
from .datafield_base import DataFieldBase
3232
from .scalar import ScalarField
@@ -103,7 +103,7 @@ def __init__(
103103
self._fields = fields # type: ignore
104104

105105
# extract data from individual fields
106-
fields_data: list[np.ndarray] = []
106+
fields_data: list[NumericArray] = []
107107
self._slices: list[slice] = []
108108
dof = 0 # count local degrees of freedom
109109
for field in self.fields:
@@ -253,7 +253,7 @@ def __eq__(self, other):
253253

254254
@classmethod
255255
def from_state(
256-
cls, attributes: dict[str, Any], data: np.ndarray | None = None
256+
cls, attributes: dict[str, Any], data: NumericArray | None = None
257257
) -> FieldCollection:
258258
"""Create a field collection from given state.
259259
@@ -286,7 +286,7 @@ def from_data(
286286
cls,
287287
field_classes,
288288
grid: GridBase,
289-
data: np.ndarray,
289+
data: NumericArray,
290290
*,
291291
with_ghost_cells: bool = True,
292292
label: str | None = None,
@@ -701,7 +701,7 @@ def averages(self) -> list:
701701
return [field.average for field in self]
702702

703703
@property
704-
def magnitudes(self) -> np.ndarray:
704+
def magnitudes(self) -> NumericArray:
705705
""":class:`~numpy.ndarray`: scalar magnitudes of all fields."""
706706
return np.array([field.magnitude for field in self]) # type: ignore
707707

@@ -807,7 +807,7 @@ def _get_merged_image_data(
807807
transpose: bool = False,
808808
vmin: float | list[float | None] | None = None,
809809
vmax: float | list[float | None] | None = None,
810-
) -> tuple[np.ndarray, dict[str, Any]]:
810+
) -> tuple[NumericArray, dict[str, Any]]:
811811
"""Obtain data required for a merged plot.
812812
813813
Args:

pde/fields/datafield_base.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..tools.numba import get_common_numba_dtype, jit, make_array_constructor
2626
from ..tools.plotting import PlotReference, plot_on_axes
2727
from ..tools.spectral import CorrelationType, make_correlated_noise
28-
from ..tools.typing import ArrayLike, NumberOrArray
28+
from ..tools.typing import ArrayLike, NumberOrArray, NumericArray
2929
from .base import FieldBase, RankError
3030

3131
if TYPE_CHECKING:
@@ -182,7 +182,7 @@ def random_uniform(
182182
# create complex random numbers for the field
183183
real_part = rng.uniform(np.real(vmin), np.real(vmax), size=shape)
184184
imag_part = rng.uniform(np.imag(vmin), np.imag(vmax), size=shape)
185-
data: np.ndarray = real_part + 1j * imag_part
185+
data: NumericArray = real_part + 1j * imag_part
186186
else:
187187
# create real random numbers for the field
188188
data = rng.uniform(vmin, vmax, size=shape)
@@ -293,7 +293,7 @@ def random_normal(
293293
else:
294294
tensor_shape = (grid.dim,) * cls.rank
295295

296-
def make_random_field() -> np.ndarray:
296+
def make_random_field() -> NumericArray:
297297
"""Helper function that creates a single tensor field."""
298298
out = np.empty(tensor_shape + grid.shape)
299299
print(out.shape, tensor_shape, grid)
@@ -313,7 +313,7 @@ def make_random_field() -> np.ndarray:
313313
# create complex random numbers for the field
314314
real_part = np.real(mean) + np.real(std) * scale * make_random_field()
315315
imag_part = np.imag(mean) + np.imag(std) * scale * make_random_field()
316-
data: np.ndarray = real_part + 1j * imag_part
316+
data: NumericArray = real_part + 1j * imag_part
317317
else:
318318
# create real random numbers for the field
319319
data = mean + std * scale * make_random_field()
@@ -481,7 +481,7 @@ def get_class_by_rank(cls, rank: int) -> type[DataFieldBase]:
481481
def from_state(
482482
cls: type[TDataField],
483483
attributes: dict[str, Any],
484-
data: np.ndarray | None = None,
484+
data: NumericArray | None = None,
485485
) -> TDataField:
486486
"""Create a field from given state.
487487
@@ -593,7 +593,7 @@ def make_interpolator(
593593
*,
594594
fill: Number | None = None,
595595
with_ghost_cells: bool = False,
596-
) -> Callable[[np.ndarray, np.ndarray], NumberOrArray]:
596+
) -> Callable[[NumericArray, NumericArray], NumberOrArray]:
597597
r"""Returns a function that can be used to interpolate values.
598598
599599
Args:
@@ -617,7 +617,7 @@ def make_interpolator(
617617
# convert `fill` to dtype of data
618618
if fill is not None:
619619
if self.rank == 0:
620-
fill = self.data.dtype.type(fill)
620+
fill = self.data.dtype.type(fill) # type: ignore
621621
else:
622622
fill = np.broadcast_to(fill, self.data_shape).astype(self.data.dtype)
623623

@@ -636,8 +636,8 @@ def make_interpolator(
636636

637637
@jit
638638
def interpolator(
639-
point: np.ndarray, data: np.ndarray | None = None
640-
) -> np.ndarray:
639+
point: NumericArray, data: NumericArray | None = None
640+
) -> NumericArray:
641641
"""Return the interpolated value at the position `point`
642642
643643
Args:
@@ -680,11 +680,11 @@ def interpolator(
680680
@fill_in_docstring
681681
def interpolate(
682682
self,
683-
point: np.ndarray,
683+
point: NumericArray,
684684
*,
685685
bc: BoundariesData | None = None,
686686
fill: Number | None = None,
687-
) -> np.ndarray:
687+
) -> NumericArray:
688688
r"""Interpolate the field to points between support points.
689689
690690
Args:
@@ -748,7 +748,7 @@ def interpolate_to_grid(
748748
"""
749749
raise NotImplementedError(f"Cannot interpolate {self.__class__.__name__}")
750750

751-
def insert(self, point: np.ndarray, amount: ArrayLike) -> None:
751+
def insert(self, point: NumericArray, amount: ArrayLike) -> None:
752752
"""Adds an (integrated) value to the field at an interpolated position.
753753
754754
Args:
@@ -965,7 +965,7 @@ def apply_operator(
965965

966966
def make_dot_operator(
967967
self, backend: Literal["numpy", "numba"] = "numba", *, conjugate: bool = True
968-
) -> Callable[[np.ndarray, np.ndarray, np.ndarray | None], np.ndarray]:
968+
) -> Callable[[NumericArray, NumericArray, NumericArray | None], NumericArray]:
969969
"""Return operator calculating the dot product between two fields.
970970
971971
This supports both products between two vectors as well as products
@@ -986,13 +986,13 @@ def make_dot_operator(
986986
num_axes = self.grid.num_axes
987987

988988
@register_jitable
989-
def maybe_conj(arr: np.ndarray) -> np.ndarray:
989+
def maybe_conj(arr: NumericArray) -> NumericArray:
990990
"""Helper function implementing optional conjugation."""
991991
return arr.conjugate() if conjugate else arr
992992

993993
def dot(
994-
a: np.ndarray, b: np.ndarray, out: np.ndarray | None = None
995-
) -> np.ndarray:
994+
a: NumericArray, b: NumericArray, out: NumericArray | None = None
995+
) -> NumericArray:
996996
"""Numpy implementation to calculate dot product between two fields."""
997997
rank_a = a.ndim - num_axes
998998
rank_b = b.ndim - num_axes
@@ -1040,8 +1040,8 @@ def get_rank(arr: nb.types.Type | nb.types.Optional) -> int:
10401040

10411041
@overload(dot, inline="always")
10421042
def dot_ol(
1043-
a: np.ndarray, b: np.ndarray, out: np.ndarray | None = None
1044-
) -> np.ndarray:
1043+
a: NumericArray, b: NumericArray, out: NumericArray | None = None
1044+
) -> NumericArray:
10451045
"""Numba implementation to calculate dot product between two fields."""
10461046
# get (and check) rank of the input arrays
10471047
rank_a = get_rank(a)
@@ -1050,15 +1050,19 @@ def dot_ol(
10501050
if rank_a == 1 and rank_b == 1: # result is scalar field
10511051

10521052
@register_jitable
1053-
def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
1053+
def calc(
1054+
a: NumericArray, b: NumericArray, out: NumericArray
1055+
) -> None:
10541056
out[:] = a[0] * maybe_conj(b[0])
10551057
for j in range(1, dim):
10561058
out[:] += a[j] * maybe_conj(b[j])
10571059

10581060
elif rank_a == 2 and rank_b == 1: # result is vector field
10591061

10601062
@register_jitable
1061-
def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
1063+
def calc(
1064+
a: NumericArray, b: NumericArray, out: NumericArray
1065+
) -> None:
10621066
for i in range(dim):
10631067
out[i] = a[i, 0] * maybe_conj(b[0])
10641068
for j in range(1, dim):
@@ -1067,7 +1071,9 @@ def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
10671071
elif rank_a == 1 and rank_b == 2: # result is vector field
10681072

10691073
@register_jitable
1070-
def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
1074+
def calc(
1075+
a: NumericArray, b: NumericArray, out: NumericArray
1076+
) -> None:
10711077
for i in range(dim):
10721078
out[i] = a[0] * maybe_conj(b[0, i])
10731079
for j in range(1, dim):
@@ -1076,7 +1082,9 @@ def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
10761082
elif rank_a == 2 and rank_b == 2: # result is tensor-2 field
10771083

10781084
@register_jitable
1079-
def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
1085+
def calc(
1086+
a: NumericArray, b: NumericArray, out: NumericArray
1087+
) -> None:
10801088
for i in range(dim):
10811089
for j in range(dim):
10821090
out[i, j] = a[i, 0] * maybe_conj(b[0, j])
@@ -1095,8 +1103,10 @@ def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> None:
10951103
dtype = get_common_numba_dtype(a, b)
10961104

10971105
def dot_impl(
1098-
a: np.ndarray, b: np.ndarray, out: np.ndarray | None = None
1099-
) -> np.ndarray:
1106+
a: NumericArray,
1107+
b: NumericArray,
1108+
out: NumericArray | None = None,
1109+
) -> NumericArray:
11001110
"""Helper function allocating output array."""
11011111
assert a.shape == a_shape
11021112
assert b.shape == b_shape
@@ -1108,8 +1118,10 @@ def dot_impl(
11081118
# function is called with `out` argument -> reuse `out` array
11091119

11101120
def dot_impl(
1111-
a: np.ndarray, b: np.ndarray, out: np.ndarray | None = None
1112-
) -> np.ndarray:
1121+
a: NumericArray,
1122+
b: NumericArray,
1123+
out: NumericArray | None = None,
1124+
) -> NumericArray:
11131125
"""Helper function without allocating output array."""
11141126
assert a.shape == a_shape
11151127
assert b.shape == b_shape
@@ -1121,8 +1133,8 @@ def dot_impl(
11211133

11221134
@jit
11231135
def dot_compiled(
1124-
a: np.ndarray, b: np.ndarray, out: np.ndarray | None = None
1125-
) -> np.ndarray:
1136+
a: NumericArray, b: NumericArray, out: NumericArray | None = None
1137+
) -> NumericArray:
11261138
"""Numba implementation to calculate dot product between two fields."""
11271139
return dot(a, b, out)
11281140

pde/fields/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ..grids.boundaries.axes import BoundariesData
1919
from ..tools.docstrings import fill_in_docstring
2020
from ..tools.misc import Number
21-
from ..tools.typing import NumberOrArray
21+
from ..tools.typing import NumberOrArray, NumericArray
2222
from .datafield_base import DataFieldBase
2323

2424
if TYPE_CHECKING:

pde/fields/tensorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..tools.docstrings import fill_in_docstring
1616
from ..tools.misc import get_common_dtype
1717
from ..tools.plotting import PlotReference, plot_on_figure
18-
from ..tools.typing import NumberOrArray
18+
from ..tools.typing import NumberOrArray, NumericArray
1919
from .datafield_base import DataFieldBase
2020
from .scalar import ScalarField
2121
from .vectorial import VectorField
@@ -102,7 +102,7 @@ def from_expression(
102102
points = [grid.cell_coords[..., i] for i in range(grid.num_axes)]
103103

104104
# evaluate all vector components at all points
105-
data: list[list[np.ndarray]] = [[None] * grid.dim for _ in range(grid.dim)] # type: ignore
105+
data: list[list[NumericArray]] = [[None] * grid.dim for _ in range(grid.dim)] # type: ignore
106106
for i in range(grid.dim):
107107
for j in range(grid.dim):
108108
expr = ScalarExpression(
@@ -245,7 +245,7 @@ def divergence(
245245
return self.apply_operator("tensor_divergence", bc=bc, out=out, **kwargs) # type: ignore
246246

247247
@property
248-
def integral(self) -> np.ndarray:
248+
def integral(self) -> NumericArray:
249249
""":class:`~numpy.ndarray`: integral of each component over space."""
250250
return self.grid.integrate(self.data) # type: ignore
251251

0 commit comments

Comments
 (0)