Skip to content

Commit 0bef99a

Browse files
Merge pull request #726 from zwicker-group/precise_types
Use more precise array types
2 parents 58d4dbe + df49fd0 commit 0bef99a

File tree

13 files changed

+130
-122
lines changed

13 files changed

+130
-122
lines changed

pde/fields/datafield_base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..tools.numba import get_common_numba_dtype, jit, make_array_constructor
2626
from ..tools.plotting import PlotReference, plot_on_axes
2727
from ..tools.spectral import CorrelationType, make_correlated_noise
28-
from ..tools.typing import ArrayLike, NumberOrArray, NumericArray
28+
from ..tools.typing import ArrayLike, FloatingArray, NumberOrArray, NumericArray
2929
from .base import FieldBase, RankError
3030

3131
if TYPE_CHECKING:
@@ -593,7 +593,7 @@ def make_interpolator(
593593
*,
594594
fill: Number | None = None,
595595
with_ghost_cells: bool = False,
596-
) -> Callable[[NumericArray, NumericArray], NumberOrArray]:
596+
) -> Callable[[FloatingArray, NumericArray], NumberOrArray]:
597597
r"""Returns a function that can be used to interpolate values.
598598
599599
Args:
@@ -636,7 +636,7 @@ def make_interpolator(
636636

637637
@jit
638638
def interpolator(
639-
point: NumericArray, data: NumericArray | None = None
639+
point: FloatingArray, data: NumericArray | None = None
640640
) -> NumericArray:
641641
"""Return the interpolated value at the position `point`
642642
@@ -680,7 +680,7 @@ def interpolator(
680680
@fill_in_docstring
681681
def interpolate(
682682
self,
683-
point: NumericArray,
683+
point: FloatingArray,
684684
*,
685685
bc: BoundariesData | None = None,
686686
fill: Number | None = None,
@@ -748,7 +748,7 @@ def interpolate_to_grid(
748748
"""
749749
raise NotImplementedError(f"Cannot interpolate {self.__class__.__name__}")
750750

751-
def insert(self, point: NumericArray, amount: ArrayLike) -> None:
751+
def insert(self, point: FloatingArray, amount: ArrayLike) -> None:
752752
"""Adds an (integrated) value to the field at an interpolated position.
753753
754754
Args:

pde/grids/base.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
import numpy as np
2121
from numba.extending import is_jitted, register_jitable
2222
from numba.extending import overload as nb_overload
23-
from numpy.typing import ArrayLike
23+
from numpy.typing import ArrayLike, NDArray
2424

2525
from ..tools.cache import cached_method, cached_property
2626
from ..tools.docstrings import fill_in_docstring
2727
from ..tools.misc import Number, hybridmethod
2828
from ..tools.numba import jit
2929
from ..tools.typing import (
3030
CellVolume,
31+
FloatingArray,
3132
FloatOrArray,
3233
NumberOrArray,
3334
NumericArray,
@@ -79,7 +80,7 @@ def _check_shape(shape: int | Sequence[int]) -> tuple[int, ...]:
7980

8081
def discretize_interval(
8182
x_min: float, x_max: float, num: int
82-
) -> tuple[NumericArray, float]:
83+
) -> tuple[FloatingArray, float]:
8384
r"""Construct a list of equidistantly placed intervals.
8485
8586
The discretization is defined as
@@ -143,8 +144,8 @@ class GridBase(metaclass=ABCMeta):
143144
_axes_symmetric: tuple[int, ...] = ()
144145
_axes_described: tuple[int, ...]
145146
_axes_bounds: tuple[tuple[float, float], ...]
146-
_axes_coords: tuple[NumericArray, ...]
147-
_discretization: NumericArray
147+
_axes_coords: tuple[FloatingArray, ...]
148+
_discretization: FloatingArray
148149
_periodic: list[bool]
149150
_shape: tuple[int, ...]
150151

@@ -230,7 +231,7 @@ def axes_bounds(self) -> tuple[tuple[float, float], ...]:
230231
return self._axes_bounds
231232

232233
@property
233-
def axes_coords(self) -> tuple[NumericArray, ...]:
234+
def axes_coords(self) -> tuple[FloatingArray, ...]:
234235
"""tuple: coordinates of the cells for each axis"""
235236
return self._axes_coords
236237

@@ -299,7 +300,7 @@ def _get_boundary_index(self, index: str | tuple[int, bool]) -> tuple[int, bool]
299300
return axis, upper
300301

301302
@property
302-
def discretization(self) -> NumericArray:
303+
def discretization(self) -> FloatingArray:
303304
""":class:`numpy.array`: the linear size of a cell along each axis."""
304305
return self._discretization
305306

@@ -514,17 +515,17 @@ def numba_type(self) -> str:
514515
return "f8[" + ", ".join([":"] * self.num_axes) + "]"
515516

516517
@cached_property()
517-
def coordinate_arrays(self) -> tuple[NumericArray, ...]:
518+
def coordinate_arrays(self) -> tuple[FloatingArray, ...]:
518519
"""tuple: for each axes: coordinate values for all cells"""
519520
return tuple(np.meshgrid(*self.axes_coords, indexing="ij"))
520521

521522
@cached_property()
522-
def cell_coords(self) -> NumericArray:
523+
def cell_coords(self) -> FloatingArray:
523524
""":class:`~numpy.ndarray`: coordinate values for all axes of each cell."""
524525
return np.moveaxis(self.coordinate_arrays, 0, -1) # type: ignore
525526

526527
@cached_property()
527-
def cell_volumes(self) -> NumericArray:
528+
def cell_volumes(self) -> FloatingArray:
528529
""":class:`~numpy.ndarray`: volume of each cell."""
529530
if self.cell_volume_data is None:
530531
# use the self.c to calculate cell volumes
@@ -548,13 +549,13 @@ def uniform_cell_volumes(self) -> bool:
548549

549550
def _difference_vector(
550551
self,
551-
p1: NumericArray,
552-
p2: NumericArray,
552+
p1: FloatingArray,
553+
p2: FloatingArray,
553554
*,
554555
coords: CoordsType,
555556
periodic: Sequence[bool],
556557
axes_bounds: tuple[tuple[float, float], ...] | None,
557-
) -> NumericArray:
558+
) -> FloatingArray:
558559
"""Return Cartesian vector(s) pointing from p1 to p2.
559560
560561
In case of periodic boundary conditions, the shortest vector is returned.
@@ -590,8 +591,8 @@ def _difference_vector(
590591
return diff # type: ignore
591592

592593
def difference_vector(
593-
self, p1: NumericArray, p2: NumericArray, *, coords: CoordsType = "grid"
594-
) -> NumericArray:
594+
self, p1: FloatingArray, p2: FloatingArray, *, coords: CoordsType = "grid"
595+
) -> FloatingArray:
595596
"""Return Cartesian vector(s) pointing from p1 to p2.
596597
597598
In case of periodic boundary conditions, the shortest vector is returned.
@@ -614,7 +615,7 @@ def difference_vector(
614615
)
615616

616617
def distance(
617-
self, p1: NumericArray, p2: NumericArray, *, coords: CoordsType = "grid"
618+
self, p1: FloatingArray, p2: FloatingArray, *, coords: CoordsType = "grid"
618619
) -> float:
619620
"""Calculate the distance between two points given in real coordinates.
620621
@@ -647,7 +648,7 @@ def _iter_boundaries(self) -> Iterator[tuple[int, bool]]:
647648

648649
def _boundary_coordinates(
649650
self, axis: int, upper: bool, *, offset: float = 0
650-
) -> NumericArray:
651+
) -> FloatingArray:
651652
"""Get coordinates of points on the boundary.
652653
653654
Args:
@@ -686,7 +687,7 @@ def volume(self) -> float:
686687
# this property should be overwritten when the volume can be calculated directly
687688
return self.cell_volumes.sum() # type: ignore
688689

689-
def point_to_cartesian(self, points: NumericArray) -> NumericArray:
690+
def point_to_cartesian(self, points: FloatingArray) -> FloatingArray:
690691
"""Convert coordinates of a point in grid coordinates to Cartesian coordinates.
691692
692693
Args:
@@ -698,7 +699,7 @@ def point_to_cartesian(self, points: NumericArray) -> NumericArray:
698699
"""
699700
return self.c.pos_to_cart(self._coords_full(points))
700701

701-
def point_from_cartesian(self, points: NumericArray) -> NumericArray:
702+
def point_from_cartesian(self, points: FloatingArray) -> FloatingArray:
702703
"""Convert points given in Cartesian coordinates to grid coordinates.
703704
704705
Args:
@@ -711,7 +712,7 @@ def point_from_cartesian(self, points: NumericArray) -> NumericArray:
711712
return self._coords_symmetric(self.c.pos_from_cart(points))
712713

713714
def _vector_to_cartesian(
714-
self, points: ArrayLike, components: ArrayLike
715+
self, points: FloatingArray, components: ArrayLike
715716
) -> NumericArray:
716717
"""Convert the vectors at given points into a Cartesian basis.
717718
@@ -745,8 +746,8 @@ def _vector_to_cartesian(
745746
return np.einsum("j...,ji...->i...", components, rot_mat) # type: ignore
746747

747748
def normalize_point(
748-
self, point: NumericArray, *, reflect: bool = False
749-
) -> NumericArray:
749+
self, point: FloatingArray, *, reflect: bool = False
750+
) -> FloatingArray:
750751
"""Normalize grid coordinates by applying periodic boundary conditions.
751752
752753
Here, points are assumed to be specified by the physical values along the
@@ -808,7 +809,7 @@ def normalize_point(
808809

809810
return point
810811

811-
def _coords_symmetric(self, points: NumericArray) -> NumericArray:
812+
def _coords_symmetric(self, points: FloatingArray) -> FloatingArray:
812813
"""Return only non-symmetric point coordinates.
813814
814815
Args:
@@ -824,8 +825,8 @@ def _coords_symmetric(self, points: NumericArray) -> NumericArray:
824825
return points[..., self._axes_described]
825826

826827
def _coords_full(
827-
self, points: NumericArray, *, value: Literal["min", "max"] | float = 0.0
828-
) -> NumericArray:
828+
self, points: FloatingArray, *, value: Literal["min", "max"] | float = 0.0
829+
) -> FloatingArray:
829830
"""Specify point coordinates along symmetric axes on grids.
830831
831832
Args:
@@ -861,8 +862,8 @@ def _coords_full(
861862
return res # type: ignore
862863

863864
def transform(
864-
self, coordinates: NumericArray, source: CoordsType, target: CoordsType
865-
) -> NumericArray:
865+
self, coordinates: FloatingArray, source: CoordsType, target: CoordsType
866+
) -> FloatingArray:
866867
"""Converts coordinates from one coordinate system to another.
867868
868869
Supported coordinate systems include the following:
@@ -956,10 +957,10 @@ def transform(
956957

957958
def contains_point(
958959
self,
959-
points: NumericArray,
960+
points: FloatingArray,
960961
*,
961962
coords: Literal["cartesian", "cell", "grid"] = "cartesian",
962-
) -> NumericArray:
963+
) -> NDArray[np.bool]:
963964
"""Check whether the point is contained in the grid.
964965
965966
Args:
@@ -973,11 +974,11 @@ def contains_point(
973974
the grid
974975
"""
975976
cell_coords = self.transform(points, source=coords, target="cell")
976-
return np.all((cell_coords >= 0) & (cell_coords <= self.shape), axis=-1) # type: ignore
977+
return np.all((cell_coords >= 0) & (cell_coords <= self.shape), axis=-1)
977978

978979
def iter_mirror_points(
979-
self, point: NumericArray, with_self: bool = False, only_periodic: bool = True
980-
) -> Generator:
980+
self, point: FloatingArray, with_self: bool = False, only_periodic: bool = True
981+
) -> Iterator[FloatingArray]:
981982
"""Generates all mirror points corresponding to `point`
982983
983984
Args:
@@ -1105,7 +1106,7 @@ def get_random_point(
11051106
boundary_distance: float = 0,
11061107
coords: CoordsType = "cartesian",
11071108
rng: np.random.Generator | None = None,
1108-
) -> NumericArray:
1109+
) -> FloatingArray:
11091110
"""Return a random point within the grid.
11101111
11111112
Args:
@@ -1555,7 +1556,7 @@ def integrate(
15551556
@cached_method()
15561557
def make_normalize_point_compiled(
15571558
self, reflect: bool = True
1558-
) -> Callable[[NumericArray], None]:
1559+
) -> Callable[[FloatingArray], None]:
15591560
"""Return a compiled function that normalizes a point.
15601561
15611562
Here, the point is assumed to be specified by the physical values along
@@ -1582,7 +1583,7 @@ def make_normalize_point_compiled(
15821583
size = bounds[:, 1] - bounds[:, 0]
15831584

15841585
@jit
1585-
def normalize_point(point: NumericArray) -> None:
1586+
def normalize_point(point: FloatingArray) -> None:
15861587
"""Helper function normalizing a single point."""
15871588
assert point.ndim == 1 # only support single points
15881589
for i in range(num_axes):
@@ -1730,7 +1731,7 @@ def _make_interpolator_compiled(
17301731
fill: Number | None = None,
17311732
with_ghost_cells: bool = False,
17321733
cell_coords: bool = False,
1733-
) -> Callable[[NumericArray, NumericArray], NumericArray]:
1734+
) -> Callable[[NumericArray, FloatingArray], NumericArray]:
17341735
"""Return a compiled function for linear interpolation on the grid.
17351736
17361737
Args:
@@ -1760,7 +1761,7 @@ def _make_interpolator_compiled(
17601761

17611762
@jit
17621763
def interpolate_single(
1763-
data: NumericArray, point: NumericArray
1764+
data: NumericArray, point: FloatingArray
17641765
) -> NumberOrArray:
17651766
"""Obtain interpolated value of data at a point.
17661767
@@ -1792,7 +1793,7 @@ def interpolate_single(
17921793

17931794
@jit
17941795
def interpolate_single(
1795-
data: NumericArray, point: NumericArray
1796+
data: NumericArray, point: FloatingArray
17961797
) -> NumberOrArray:
17971798
"""Obtain interpolated value of data at a point.
17981799
@@ -1832,7 +1833,7 @@ def interpolate_single(
18321833

18331834
@jit
18341835
def interpolate_single(
1835-
data: NumericArray, point: NumericArray
1836+
data: NumericArray, point: FloatingArray
18361837
) -> NumberOrArray:
18371838
"""Obtain interpolated value of data at a point.
18381839
@@ -1878,7 +1879,7 @@ def interpolate_single(
18781879

18791880
def make_inserter_compiled(
18801881
self, *, with_ghost_cells: bool = False
1881-
) -> Callable[[NumericArray, NumericArray, NumberOrArray], None]:
1882+
) -> Callable[[NumericArray, FloatingArray, NumberOrArray], None]:
18821883
"""Return a compiled function to insert values at interpolated positions.
18831884
18841885
Args:
@@ -1903,7 +1904,7 @@ def make_inserter_compiled(
19031904

19041905
@jit
19051906
def insert(
1906-
data: NumericArray, point: NumericArray, amount: NumberOrArray
1907+
data: NumericArray, point: FloatingArray, amount: NumberOrArray
19071908
) -> None:
19081909
"""Add an amount to a field at an interpolated position.
19091910
@@ -1938,7 +1939,7 @@ def insert(
19381939

19391940
@jit
19401941
def insert(
1941-
data: NumericArray, point: NumericArray, amount: NumberOrArray
1942+
data: NumericArray, point: FloatingArray, amount: NumberOrArray
19421943
) -> None:
19431944
"""Add an amount to a field at an interpolated position.
19441945
@@ -1985,7 +1986,7 @@ def insert(
19851986

19861987
@jit
19871988
def insert(
1988-
data: NumericArray, point: NumericArray, amount: NumberOrArray
1989+
data: NumericArray, point: FloatingArray, amount: NumberOrArray
19891990
) -> None:
19901991
"""Add an amount to a field at an interpolated position.
19911992

pde/grids/cartesian.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from ..tools.cuboid import Cuboid
1515
from ..tools.plotting import plot_on_axes
16-
from ..tools.typing import NumericArray
16+
from ..tools.typing import FloatingArray, NumericArray
1717
from .base import (
1818
CoordsType,
1919
DimensionError,
@@ -238,7 +238,7 @@ def get_random_point(
238238
boundary_distance: float = 0,
239239
coords: CoordsType = "cartesian",
240240
rng: np.random.Generator | None = None,
241-
) -> NumericArray:
241+
) -> FloatingArray:
242242
"""Return a random point within the grid.
243243
244244
Args:
@@ -274,8 +274,8 @@ def get_random_point(
274274
raise ValueError(f"Unknown coordinate system `{coords}`")
275275

276276
def difference_vector(
277-
self, p1: NumericArray, p2: NumericArray, *, coords: CoordsType = "grid"
278-
) -> NumericArray:
277+
self, p1: FloatingArray, p2: FloatingArray, *, coords: CoordsType = "grid"
278+
) -> FloatingArray:
279279
return self._difference_vector(
280280
p1, p2, coords=coords, periodic=self.periodic, axes_bounds=self.axes_bounds
281281
)

0 commit comments

Comments
 (0)