Skip to content

Commit 93b96a2

Browse files
committed
Improved backend handling in solvers
1 parent aa34522 commit 93b96a2

File tree

3 files changed

+125
-17
lines changed

3 files changed

+125
-17
lines changed

pde/backends/base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,24 @@ def make_sde_rhs(
341341
raise NotImplementedError(
342342
f"SDE right hand side not defined for backend {self.name}"
343343
)
344+
345+
def make_stepper(
346+
self, solver: SolverBase, state: TField, dt: float | None = None
347+
) -> Callable[[TField, float, float], float]:
348+
"""Return a stepper function using an explicit scheme.
349+
350+
Args:
351+
solver (:class:`~pde.solvers.base.SolverBase`):
352+
The solver instance, which determines how the stepper is constructed
353+
state (:class:`~pde.fields.base.FieldBase`):
354+
An example for the state from which the grid and other information can
355+
be extracted
356+
dt (float):
357+
Time step used (Uses :attr:`SolverBase.dt_default` if `None`)
358+
359+
Returns:
360+
Function that can be called to advance the `state` from time `t_start` to
361+
time `t_end`. The function call signature is `(state: numpy.ndarray,
362+
t_start: float, t_end: float)`
363+
"""
364+
raise NotImplementedError(f"Steppers are not defined for backend {self.name}")

pde/backends/numpy/backend.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
import numpy as np
1111

12-
from ...fields import DataFieldBase, FieldBase, VectorField
12+
from ...fields import DataFieldBase, VectorField
1313
from ...grids import BoundariesBase, GridBase
1414
from ...pdes import PDEBase
15+
from ...solvers import SolverBase
1516
from ...tools.typing import DataSetter, GhostCellSetter, NumericArray, TField
1617
from ..base import BackendBase, OperatorInfo
1718

@@ -283,3 +284,24 @@ def pde_rhs(state_data: NumericArray, t: float) -> NumericArray:
283284

284285
pde_rhs._backend = "numpy" # type: ignore
285286
return pde_rhs
287+
288+
def make_stepper(
289+
self, solver: SolverBase, state: TField, dt: float | None = None
290+
) -> Callable[[TField, float, float], float]:
291+
"""Return a stepper function using an explicit scheme.
292+
293+
Args:
294+
solver (:class:`~pde.solvers.base.SolverBase`):
295+
The solver instance, which determines how the stepper is constructed
296+
state (:class:`~pde.fields.base.FieldBase`):
297+
An example for the state from which the grid and other information can
298+
be extracted
299+
dt (float):
300+
Time step used (Uses :attr:`SolverBase.dt_default` if `None`)
301+
302+
Returns:
303+
Function that can be called to advance the `state` from time `t_start` to
304+
time `t_end`. The function call signature is `(state: numpy.ndarray,
305+
t_start: float, t_end: float)`
306+
"""
307+
return solver.make_stepper(state, dt=dt)

pde/solvers/base.py

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
import warnings
1313
from inspect import isabstract
14-
from typing import Any, Callable
14+
from typing import TYPE_CHECKING, Any, Callable
1515

1616
import numba as nb
1717
import numpy as np
@@ -23,6 +23,10 @@
2323
from ..tools.numba import is_jitted, jit
2424
from ..tools.typing import BackendType, NumericArray, StepperHook, TField
2525

26+
if TYPE_CHECKING:
27+
from ..backends.base import BackendBase
28+
29+
2630
_base_logger = logging.getLogger(__name__.rsplit(".", 1)[0])
2731
""":class:`logging.Logger`: Base logger for solvers."""
2832

@@ -49,22 +53,30 @@ class SolverBase:
4953
"""dict: dictionary of all inheriting classes"""
5054

5155
_logger: logging.Logger
56+
_backend_name: BackendType
57+
__backend_obj: BackendBase | None
5258

53-
def __init__(self, pde: PDEBase, *, backend: BackendType = "auto"):
59+
def __init__(self, pde: PDEBase, *, backend: BackendBase | BackendType = "auto"):
5460
"""
5561
Args:
5662
pde (:class:`~pde.pdes.base.PDEBase`):
5763
The partial differential equation that should be solved
58-
backend (str):
64+
backend (str or :class:`~pde.backends.base.BackendBase`):
5965
Determines how the function is created. Accepted values are 'numpy` and
6066
'numba'. Alternatively, 'auto' lets the code decide for the most optimal
6167
backend.
6268
"""
6369
self.pde = pde
64-
self.backend = backend
6570
self.info: dict[str, Any] = {"class": self.__class__.__name__}
6671
if self.pde:
6772
self.info["pde_class"] = self.pde.__class__.__name__
73+
if isinstance(backend, str):
74+
self._backend_name = backend
75+
self.__backend_obj = None
76+
else:
77+
# assume that `backend` is of type BackendBase
78+
self._backend_name = backend.name # type: ignore
79+
self.__backend_obj = backend
6880

6981
def __init_subclass__(cls, **kwargs):
7082
"""Initialize class-level attributes of subclasses."""
@@ -124,6 +136,68 @@ def registered_solvers(cls) -> list[str]:
124136
"""list of str: the names of the registered solvers"""
125137
return sorted(cls._subclasses.keys())
126138

139+
@property
140+
def backend(self) -> BackendType:
141+
"""str: The name of the backend used for this solver."""
142+
return self._backend_name
143+
144+
@backend.setter
145+
def backend(self, value: BackendBase | BackendType) -> None:
146+
"""sets a new backend for the solver
147+
148+
This setter tries to ensure consistency and make sure that backends are not
149+
changed after the object has been loaded (i.e., after _backend_obj has been
150+
accessed). The method also raises a warning when the backend is changed (except
151+
if it was set to `auto` before). This allows solvers to react flexibly to
152+
changes in the backend, e.g., demanded by the PDE implementation.
153+
154+
Args:
155+
value:
156+
The backend object or name
157+
"""
158+
# determine the name of the new backend
159+
if isinstance(value, str):
160+
new_backend = value
161+
else:
162+
# assume value is of type BackendBase
163+
new_backend = value.name # type: ignore
164+
165+
# check whether the new name contradicts the old backend name
166+
if self._backend_name in {"auto", new_backend}:
167+
pass # nothing to do
168+
else:
169+
self._logger.warning(
170+
"Changing the backend of the solver from `%s` to `%s`",
171+
self._backend_name,
172+
new_backend,
173+
)
174+
175+
# check whether the new backend contradicts the old backend object
176+
if self.__backend_obj is not None and self.__backend_obj.name != new_backend:
177+
raise TypeError(
178+
"Tried changing the loaded backend of the solver from `%s` to `%s`",
179+
self.__backend_obj.name,
180+
new_backend,
181+
)
182+
183+
# set the new backend
184+
self._backend_name = new_backend
185+
if not isinstance(value, str):
186+
self.__backend_obj = value
187+
188+
@property
189+
def _backend_obj(self) -> BackendBase:
190+
""":class:`~pde.backends.base.BackendBase`: The backend used for this solver."""
191+
if self.__backend_obj is None:
192+
from ..backends import backends
193+
194+
if self._backend_name == "auto":
195+
self._backend_name = "numpy" # conservative fall-back
196+
197+
self.__backend_obj = backends[self._backend_name]
198+
199+
return self.__backend_obj
200+
127201
@property
128202
def _compiled(self) -> bool:
129203
"""bool: indicates whether functions need to be compiled"""
@@ -220,18 +294,9 @@ def _check_backend(self, rhs: Callable) -> None:
220294
else:
221295
self.info["backend"] = "undetermined"
222296

223-
if self.backend != self.info["backend"]:
224-
if self.backend == "auto":
225-
# solver did not care about a backend, so we simply use the solver one
226-
self.backend = self.info["backend"]
227-
else:
228-
# there is a mismatch, which we need to report
229-
self._logger.warning(
230-
"The PDE class requested a different backend (%s) than the solver "
231-
"(%s), which might lead to incompatibilities",
232-
self.backend,
233-
self.info["backend"],
234-
)
297+
# adjust the backend of the solver to the one requested by the PDE
298+
if self.info["backend"] != "undetermined":
299+
self.backend = self.info["backend"]
235300

236301
def _make_pde_rhs(
237302
self, state: TField, backend: BackendType = "auto"

0 commit comments

Comments
 (0)