1111import logging
1212import warnings
1313from inspect import isabstract
14- from typing import Any , Callable
14+ from typing import TYPE_CHECKING , Any , Callable
1515
1616import numba as nb
1717import numpy as np
2323from ..tools .numba import is_jitted , jit
2424from ..tools .typing import BackendType , NumericArray , StepperHook , TField
2525
26+ if TYPE_CHECKING :
27+ from ..backends .base import BackendBase
28+
29+
2630_base_logger = logging .getLogger (__name__ .rsplit ("." , 1 )[0 ])
2731""":class:`logging.Logger`: Base logger for solvers."""
2832
@@ -49,22 +53,30 @@ class SolverBase:
4953 """dict: dictionary of all inheriting classes"""
5054
5155 _logger : logging .Logger
56+ _backend_name : BackendType
57+ __backend_obj : BackendBase | None
5258
53- def __init__ (self , pde : PDEBase , * , backend : BackendType = "auto" ):
59+ def __init__ (self , pde : PDEBase , * , backend : BackendBase | BackendType = "auto" ):
5460 """
5561 Args:
5662 pde (:class:`~pde.pdes.base.PDEBase`):
5763 The partial differential equation that should be solved
58- backend (str):
64+ backend (str or :class:`~pde.backends.base.BackendBase` ):
5965 Determines how the function is created. Accepted values are 'numpy` and
6066 'numba'. Alternatively, 'auto' lets the code decide for the most optimal
6167 backend.
6268 """
6369 self .pde = pde
64- self .backend = backend
6570 self .info : dict [str , Any ] = {"class" : self .__class__ .__name__ }
6671 if self .pde :
6772 self .info ["pde_class" ] = self .pde .__class__ .__name__
73+ if isinstance (backend , str ):
74+ self ._backend_name = backend
75+ self .__backend_obj = None
76+ else :
77+ # assume that `backend` is of type BackendBase
78+ self ._backend_name = backend .name # type: ignore
79+ self .__backend_obj = backend
6880
6981 def __init_subclass__ (cls , ** kwargs ):
7082 """Initialize class-level attributes of subclasses."""
@@ -124,6 +136,68 @@ def registered_solvers(cls) -> list[str]:
124136 """list of str: the names of the registered solvers"""
125137 return sorted (cls ._subclasses .keys ())
126138
139+ @property
140+ def backend (self ) -> BackendType :
141+ """str: The name of the backend used for this solver."""
142+ return self ._backend_name
143+
144+ @backend .setter
145+ def backend (self , value : BackendBase | BackendType ) -> None :
146+ """sets a new backend for the solver
147+
148+ This setter tries to ensure consistency and make sure that backends are not
149+ changed after the object has been loaded (i.e., after _backend_obj has been
150+ accessed). The method also raises a warning when the backend is changed (except
151+ if it was set to `auto` before). This allows solvers to react flexibly to
152+ changes in the backend, e.g., demanded by the PDE implementation.
153+
154+ Args:
155+ value:
156+ The backend object or name
157+ """
158+ # determine the name of the new backend
159+ if isinstance (value , str ):
160+ new_backend = value
161+ else :
162+ # assume value is of type BackendBase
163+ new_backend = value .name # type: ignore
164+
165+ # check whether the new name contradicts the old backend name
166+ if self ._backend_name in {"auto" , new_backend }:
167+ pass # nothing to do
168+ else :
169+ self ._logger .warning (
170+ "Changing the backend of the solver from `%s` to `%s`" ,
171+ self ._backend_name ,
172+ new_backend ,
173+ )
174+
175+ # check whether the new backend contradicts the old backend object
176+ if self .__backend_obj is not None and self .__backend_obj .name != new_backend :
177+ raise TypeError (
178+ "Tried changing the loaded backend of the solver from `%s` to `%s`" ,
179+ self .__backend_obj .name ,
180+ new_backend ,
181+ )
182+
183+ # set the new backend
184+ self ._backend_name = new_backend
185+ if not isinstance (value , str ):
186+ self .__backend_obj = value
187+
188+ @property
189+ def _backend_obj (self ) -> BackendBase :
190+ """:class:`~pde.backends.base.BackendBase`: The backend used for this solver."""
191+ if self .__backend_obj is None :
192+ from ..backends import backends
193+
194+ if self ._backend_name == "auto" :
195+ self ._backend_name = "numpy" # conservative fall-back
196+
197+ self .__backend_obj = backends [self ._backend_name ]
198+
199+ return self .__backend_obj
200+
127201 @property
128202 def _compiled (self ) -> bool :
129203 """bool: indicates whether functions need to be compiled"""
@@ -220,18 +294,9 @@ def _check_backend(self, rhs: Callable) -> None:
220294 else :
221295 self .info ["backend" ] = "undetermined"
222296
223- if self .backend != self .info ["backend" ]:
224- if self .backend == "auto" :
225- # solver did not care about a backend, so we simply use the solver one
226- self .backend = self .info ["backend" ]
227- else :
228- # there is a mismatch, which we need to report
229- self ._logger .warning (
230- "The PDE class requested a different backend (%s) than the solver "
231- "(%s), which might lead to incompatibilities" ,
232- self .backend ,
233- self .info ["backend" ],
234- )
297+ # adjust the backend of the solver to the one requested by the PDE
298+ if self .info ["backend" ] != "undetermined" :
299+ self .backend = self .info ["backend" ]
235300
236301 def _make_pde_rhs (
237302 self , state : TField , backend : BackendType = "auto"
0 commit comments