Skip to content

Commit 289bac2

Browse files
committed
Move the operator registry from individual backends to the registry class to allow flexible addition of operators (without loading the backend).
1 parent 767d5ad commit 289bac2

File tree

17 files changed

+186
-109
lines changed

17 files changed

+186
-109
lines changed

pde/backends/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
.. codeauthor:: David Zwicker <[email protected]>
1212
"""
1313

14-
# load the registry, which manages the backends
14+
# load the registry, which manages all backends
1515
from .registry import backends # noqa: I001
1616

1717
# load and register the default backend

pde/backends/base.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
import inspect
99
import logging
10-
from collections import defaultdict
11-
from typing import Any, Callable, Literal, NamedTuple
10+
from typing import Any, Callable, Literal
1211

1312
from ..fields import DataFieldBase
1413
from ..grids import BoundariesBase, GridBase
@@ -22,31 +21,21 @@
2221
NumberOrArray,
2322
NumericArray,
2423
OperatorFactory,
24+
OperatorInfo,
2525
TField,
2626
)
2727

2828
_base_logger = logging.getLogger(__name__.rsplit(".", 1)[0])
2929
""":class:`logging.Logger`: Base logger for backends."""
3030

3131

32-
class OperatorInfo(NamedTuple):
33-
"""Stores information about an operator."""
34-
35-
factory: OperatorFactory
36-
rank_in: int
37-
rank_out: int
38-
name: str = "" # attach a unique name to help caching
39-
40-
4132
class BackendBase:
4233
"""Basic backend from which all other backends inherit."""
4334

44-
_operators: dict[type[GridBase], dict[str, OperatorInfo]]
4535
_logger: logging.Logger # logger instance to output information
4636

4737
def __init__(self, name: str = "numpy"):
4838
self.name = name
49-
self._operators = defaultdict(dict)
5039

5140
def __init_subclass__(cls, **kwargs) -> None:
5241
"""Initialize class-level attributes of subclasses."""
@@ -59,8 +48,7 @@ def register_operator(
5948
grid_cls: type[GridBase],
6049
name: str,
6150
factory_func: OperatorFactory | None = None,
62-
rank_in: int = 0,
63-
rank_out: int = 0,
51+
**kwargs,
6452
):
6553
"""Register an operator for a particular grid.
6654
@@ -95,21 +83,9 @@ def make_operator(grid: GridBase): ...
9583
rank_out (int):
9684
The rank of the field that is returned by the operator
9785
"""
86+
from .registry import backends
9887

99-
def register_operator(factor_func_arg: OperatorFactory):
100-
"""Helper function to register the operator."""
101-
self._operators[grid_cls][name] = OperatorInfo(
102-
factory=factor_func_arg, rank_in=rank_in, rank_out=rank_out, name=name
103-
)
104-
return factor_func_arg
105-
106-
if factory_func is None:
107-
# method is used as a decorator, so return the helper function
108-
return register_operator
109-
else:
110-
# method is used directly
111-
register_operator(factory_func)
112-
return None
88+
backends.register_operator(self.name, grid_cls, name, factory_func, **kwargs)
11389

11490
def get_registered_operators(self, grid_id: GridBase | type[GridBase]) -> set[str]:
11591
"""Returns all operators defined for a grid.
@@ -118,14 +94,16 @@ def get_registered_operators(self, grid_id: GridBase | type[GridBase]) -> set[st
11894
grid (:class:`~pde.grid.base.GridBase` or its type):
11995
Grid for which the operator need to be returned
12096
"""
97+
from . import backends
98+
12199
grid_cls = grid_id if inspect.isclass(grid_id) else grid_id.__class__
122100

123101
# get all operators registered on the class
124102
operators = set()
125103
# add all custom defined operators
126104
classes = inspect.getmro(grid_cls)[:-1] # type: ignore
127105
for cls in classes:
128-
operators |= set(self._operators[cls].keys())
106+
operators |= set(backends._operators[self.name][cls].keys())
129107

130108
return operators
131109

@@ -149,11 +127,13 @@ def get_operator_info(
149127
return operator
150128
assert isinstance(operator, str)
151129

130+
from . import backends
131+
152132
# look for defined operators on all parent grid classes (except `object`)
153133
classes = inspect.getmro(grid.__class__)[:-1]
154134
for cls in classes:
155-
if operator in self._operators[cls]:
156-
return self._operators[cls][operator]
135+
if operator in backends._operators[self.name][cls]:
136+
return backends._operators[self.name][cls][operator]
157137

158138
# throw an error since operator was not found
159139
raise NotImplementedError(

pde/backends/numba/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from .backend import NumbaBackend
1313

1414
# add the loaded numba backend to the registry
15-
numba_backend = NumbaBackend("numba")
16-
backends.add(numba_backend)
15+
backends.add(NumbaBackend("numba"))
1716

1817
# register all the standard operators
1918
from . import operators

pde/backends/numba/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def make_inner_stepper(
728728
time `t_end`. The function call signature is `(state: numpy.ndarray,
729729
t_start: float, t_end: float)`
730730
"""
731-
assert solver.backend == "numba"
731+
assert solver.backend == self.name
732732

733733
from .solvers import make_adaptive_stepper, make_fixed_stepper
734734

pde/backends/numba/operators/cartesian.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ....tools.misc import module_available
2727
from ....tools.numba import jit
2828
from ....tools.typing import NumericArray, OperatorType
29-
from .. import numba_backend
29+
from ...registry import backends
3030

3131

3232
def make_corner_point_setter_2d(grid: CartesianGrid) -> Callable[[NumericArray], None]:
@@ -132,13 +132,13 @@ def laplace(arr: NumericArray, out: NumericArray) -> None:
132132
else:
133133
# use 9-point stencil with interpolated boundary conditions
134134
w = corner_weight
135-
numba_backend._logger.info(
135+
backends["numba"]._logger.info(
136136
"Create 2D Cartesian Laplacian with 9-point stencil (w=%.3g)", w
137137
)
138138

139139
if not np.isclose(*grid.discretization):
140140
# we have not yet found a good expression for the 9-point stencil for dx!=dy
141-
numba_backend._logger.warning(
141+
backends["numba"]._logger.warning(
142142
"9-point stencils with anisotropic grids are not tested and might "
143143
"produce wrong results."
144144
)
@@ -292,7 +292,7 @@ def laplace(arr: NumericArray, out: NumericArray) -> None:
292292
return laplace # type: ignore
293293

294294

295-
@numba_backend.register_operator(CartesianGrid, "laplace", rank_in=0, rank_out=0)
295+
@backends.register_operator("numba", CartesianGrid, "laplace", rank_in=0, rank_out=0)
296296
def make_laplace(
297297
grid: CartesianGrid, *, spectral: bool | None = None, **kwargs
298298
) -> OperatorType:
@@ -484,7 +484,7 @@ def gradient(arr: NumericArray, out: NumericArray) -> None:
484484
return gradient # type: ignore
485485

486486

487-
@numba_backend.register_operator(CartesianGrid, "gradient", rank_in=0, rank_out=1)
487+
@backends.register_operator("numba", CartesianGrid, "gradient", rank_in=0, rank_out=1)
488488
def make_gradient(
489489
grid: CartesianGrid,
490490
*,
@@ -681,8 +681,8 @@ def gradient_squared(arr: NumericArray, out: NumericArray) -> None:
681681
return gradient_squared # type: ignore
682682

683683

684-
@numba_backend.register_operator(
685-
CartesianGrid, "gradient_squared", rank_in=0, rank_out=0
684+
@backends.register_operator(
685+
"numba", CartesianGrid, "gradient_squared", rank_in=0, rank_out=0
686686
)
687687
def make_gradient_squared(grid: CartesianGrid, *, central: bool = True) -> OperatorType:
688688
"""Make a gradient operator on a Cartesian grid.
@@ -841,7 +841,7 @@ def divergence(arr: NumericArray, out: NumericArray) -> None:
841841
return divergence # type: ignore
842842

843843

844-
@numba_backend.register_operator(CartesianGrid, "divergence", rank_in=1, rank_out=0)
844+
@backends.register_operator("numba", CartesianGrid, "divergence", rank_in=1, rank_out=0)
845845
def make_divergence(
846846
grid: CartesianGrid,
847847
*,
@@ -904,8 +904,8 @@ def vectorized_operator(arr: NumericArray, out: NumericArray) -> None:
904904
return register_jitable(vectorized_operator) # type: ignore
905905

906906

907-
@numba_backend.register_operator(
908-
CartesianGrid, "vector_gradient", rank_in=1, rank_out=2
907+
@backends.register_operator(
908+
"numba", CartesianGrid, "vector_gradient", rank_in=1, rank_out=2
909909
)
910910
def make_vector_gradient(
911911
grid: CartesianGrid,
@@ -927,7 +927,9 @@ def make_vector_gradient(
927927
return _vectorize_operator(make_gradient, grid, method=method)
928928

929929

930-
@numba_backend.register_operator(CartesianGrid, "vector_laplace", rank_in=1, rank_out=1)
930+
@backends.register_operator(
931+
"numba", CartesianGrid, "vector_laplace", rank_in=1, rank_out=1
932+
)
931933
def make_vector_laplace(grid: CartesianGrid) -> OperatorType:
932934
"""Make a vector Laplacian on a Cartesian grid.
933935
@@ -941,8 +943,8 @@ def make_vector_laplace(grid: CartesianGrid) -> OperatorType:
941943
return _vectorize_operator(make_laplace, grid)
942944

943945

944-
@numba_backend.register_operator(
945-
CartesianGrid, "tensor_divergence", rank_in=2, rank_out=1
946+
@backends.register_operator(
947+
"numba", CartesianGrid, "tensor_divergence", rank_in=2, rank_out=1
946948
)
947949
def make_tensor_divergence(
948950
grid: CartesianGrid,

pde/backends/numba/operators/cylindrical_sym.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
from ....tools.docstrings import fill_in_docstring
2323
from ....tools.numba import jit
2424
from ....tools.typing import NumericArray, OperatorType
25-
from .. import numba_backend
25+
from ...registry import backends
2626

2727

28-
@numba_backend.register_operator(CylindricalSymGrid, "laplace", rank_in=0, rank_out=0)
28+
@backends.register_operator(
29+
"numba", CylindricalSymGrid, "laplace", rank_in=0, rank_out=0
30+
)
2931
@fill_in_docstring
3032
def make_laplace(grid: CylindricalSymGrid) -> OperatorType:
3133
"""Make a discretized laplace operator for a cylindrical grid.
@@ -64,7 +66,9 @@ def laplace(arr: NumericArray, out: NumericArray) -> None:
6466
return laplace # type: ignore
6567

6668

67-
@numba_backend.register_operator(CylindricalSymGrid, "gradient", rank_in=0, rank_out=1)
69+
@backends.register_operator(
70+
"numba", CylindricalSymGrid, "gradient", rank_in=0, rank_out=1
71+
)
6872
@fill_in_docstring
6973
def make_gradient(grid: CylindricalSymGrid) -> OperatorType:
7074
"""Make a discretized gradient operator for a cylindrical grid.
@@ -97,8 +101,8 @@ def gradient(arr: NumericArray, out: NumericArray) -> None:
97101
return gradient # type: ignore
98102

99103

100-
@numba_backend.register_operator(
101-
CylindricalSymGrid, "gradient_squared", rank_in=0, rank_out=0
104+
@backends.register_operator(
105+
"numba", CylindricalSymGrid, "gradient_squared", rank_in=0, rank_out=0
102106
)
103107
@fill_in_docstring
104108
def make_gradient_squared(
@@ -154,8 +158,8 @@ def gradient_squared(arr: NumericArray, out: NumericArray) -> None:
154158
return gradient_squared # type: ignore
155159

156160

157-
@numba_backend.register_operator(
158-
CylindricalSymGrid, "divergence", rank_in=1, rank_out=0
161+
@backends.register_operator(
162+
"numba", CylindricalSymGrid, "divergence", rank_in=1, rank_out=0
159163
)
160164
@fill_in_docstring
161165
def make_divergence(grid: CylindricalSymGrid) -> OperatorType:
@@ -194,8 +198,8 @@ def divergence(arr: NumericArray, out: NumericArray) -> None:
194198
return divergence # type: ignore
195199

196200

197-
@numba_backend.register_operator(
198-
CylindricalSymGrid, "vector_gradient", rank_in=1, rank_out=2
201+
@backends.register_operator(
202+
"numba", CylindricalSymGrid, "vector_gradient", rank_in=1, rank_out=2
199203
)
200204
@fill_in_docstring
201205
def make_vector_gradient(grid: CylindricalSymGrid) -> OperatorType:
@@ -244,8 +248,8 @@ def vector_gradient(arr: NumericArray, out: NumericArray) -> None:
244248
return vector_gradient # type: ignore
245249

246250

247-
@numba_backend.register_operator(
248-
CylindricalSymGrid, "vector_laplace", rank_in=1, rank_out=1
251+
@backends.register_operator(
252+
"numba", CylindricalSymGrid, "vector_laplace", rank_in=1, rank_out=1
249253
)
250254
@fill_in_docstring
251255
def make_vector_laplace(grid: CylindricalSymGrid) -> OperatorType:
@@ -306,8 +310,8 @@ def vector_laplace(arr: NumericArray, out: NumericArray) -> None:
306310
return vector_laplace # type: ignore
307311

308312

309-
@numba_backend.register_operator(
310-
CylindricalSymGrid, "tensor_divergence", rank_in=2, rank_out=1
313+
@backends.register_operator(
314+
"numba", CylindricalSymGrid, "tensor_divergence", rank_in=2, rank_out=1
311315
)
312316
@fill_in_docstring
313317
def make_tensor_divergence(grid: CylindricalSymGrid) -> OperatorType:

pde/backends/numba/operators/polar_sym.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
from ....tools.docstrings import fill_in_docstring
2222
from ....tools.numba import jit
2323
from ....tools.typing import NumericArray, OperatorType
24-
from .. import numba_backend
24+
from ...registry import backends
2525

2626

27-
@numba_backend.register_operator(PolarSymGrid, "laplace", rank_in=0, rank_out=0)
27+
@backends.register_operator("numba", PolarSymGrid, "laplace", rank_in=0, rank_out=0)
2828
@fill_in_docstring
2929
def make_laplace(grid: PolarSymGrid) -> OperatorType:
3030
"""Make a discretized laplace operator for a polar grid.
@@ -56,7 +56,7 @@ def laplace(arr: NumericArray, out: NumericArray) -> None:
5656
return laplace # type: ignore
5757

5858

59-
@numba_backend.register_operator(PolarSymGrid, "gradient", rank_in=0, rank_out=1)
59+
@backends.register_operator("numba", PolarSymGrid, "gradient", rank_in=0, rank_out=1)
6060
@fill_in_docstring
6161
def make_gradient(
6262
grid: PolarSymGrid, *, method: Literal["central", "forward", "backward"] = "central"
@@ -101,8 +101,8 @@ def gradient(arr: NumericArray, out: NumericArray) -> None:
101101
return gradient # type: ignore
102102

103103

104-
@numba_backend.register_operator(
105-
PolarSymGrid, "gradient_squared", rank_in=0, rank_out=0
104+
@backends.register_operator(
105+
"numba", PolarSymGrid, "gradient_squared", rank_in=0, rank_out=0
106106
)
107107
@fill_in_docstring
108108
def make_gradient_squared(grid: PolarSymGrid, *, central: bool = True) -> OperatorType:
@@ -152,7 +152,7 @@ def gradient_squared(arr: NumericArray, out: NumericArray) -> None:
152152
return gradient_squared # type: ignore
153153

154154

155-
@numba_backend.register_operator(PolarSymGrid, "divergence", rank_in=1, rank_out=0)
155+
@backends.register_operator("numba", PolarSymGrid, "divergence", rank_in=1, rank_out=0)
156156
@fill_in_docstring
157157
def make_divergence(grid: PolarSymGrid) -> OperatorType:
158158
"""Make a discretized divergence operator for a polar grid.
@@ -185,7 +185,9 @@ def divergence(arr: NumericArray, out: NumericArray) -> None:
185185
return divergence # type: ignore
186186

187187

188-
@numba_backend.register_operator(PolarSymGrid, "vector_gradient", rank_in=1, rank_out=2)
188+
@backends.register_operator(
189+
"numba", PolarSymGrid, "vector_gradient", rank_in=1, rank_out=2
190+
)
189191
@fill_in_docstring
190192
def make_vector_gradient(grid: PolarSymGrid) -> OperatorType:
191193
"""Make a discretized vector gradient operator for a polar grid.
@@ -224,8 +226,8 @@ def vector_gradient(arr: NumericArray, out: NumericArray) -> None:
224226
return vector_gradient # type: ignore
225227

226228

227-
@numba_backend.register_operator(
228-
PolarSymGrid, "tensor_divergence", rank_in=2, rank_out=1
229+
@backends.register_operator(
230+
"numba", PolarSymGrid, "tensor_divergence", rank_in=2, rank_out=1
229231
)
230232
@fill_in_docstring
231233
def make_tensor_divergence(grid: PolarSymGrid) -> OperatorType:

0 commit comments

Comments
 (0)