Skip to content

Commit 767d5ad

Browse files
committed
Moved some more functionality to the backends
1 parent 1ef5c69 commit 767d5ad

File tree

4 files changed

+156
-102
lines changed

4 files changed

+156
-102
lines changed

pde/backends/numba/solvers.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,60 @@
1616
AdaptiveStepperType,
1717
FixedStepperType,
1818
SolverBase,
19+
_make_dt_adjuster,
1920
)
2021
from ...tools.math import OnlineStatistics
2122
from ...tools.numba import jit
22-
from ...tools.typing import NumericArray, TField
23+
from ...tools.typing import NumericArray, StepperHook, TField
2324

2425
SingleStepType = Callable[[NumericArray, float], None]
2526

2627

28+
def _make_post_step_hook(solver: SolverBase, state: TField) -> StepperHook:
29+
"""Create a callable that executes the PDE's post-step hook.
30+
31+
The returned function calls the post-step hook provided by the PDE (if any)
32+
after each completed time step. If the PDE implements make_post_step_hook,
33+
this method attempts to obtain both the hook function and an initial value
34+
for the hook's mutable data by calling
35+
``post_step_hook, post_step_data_init = self.pde.make_post_step_hook(state)``.
36+
The initial data is stored on the solver instance as ``self._post_step_data_init``
37+
(copied to ensure mutability) and will be passed to the hook when the stepper
38+
is executed.
39+
40+
If no hook is provided by the PDE (i.e., ``make_post_step_hook`` raises
41+
:class:`NotImplementedError`) or if the solver's ``_use_post_step_hook`` flag
42+
is ``False``, a no-op hook is returned and ``self._post_step_data_init`` is set
43+
to ``None``.
44+
45+
The hook returned by this method always conforms to the signature
46+
``(state_data: numpy.ndarray, t: float, post_step_data: numpy.ndarray) -> None``
47+
and is suitable for JIT compilation where supported.
48+
49+
Args:
50+
solver (:class:`~pde.solvers.base.SolverBase`):
51+
The solver instance, which determines how the hook is constructed
52+
state (:class:`~pde.fields.base.FieldBase`):
53+
Example field providing the array shape and grid information required
54+
by the PDE when constructing the post-step hook.
55+
56+
Returns:
57+
callable:
58+
A function that invokes the PDE's post-step hook (or a no-op) with the
59+
signature described above.
60+
"""
61+
# get uncompiled post_step_hook
62+
post_step_hook = solver._make_post_step_hook(state)
63+
64+
# compile post_step_hook
65+
post_step_data_type = nb.typeof(solver._post_step_data_init)
66+
signature_hook = (nb.typeof(state.data), nb.float64, post_step_data_type)
67+
post_step_hook = jit(signature_hook)(post_step_hook)
68+
69+
solver._logger.debug("Compiled post-step hook")
70+
return post_step_hook # type: ignore
71+
72+
2773
def _make_fixed_stepper(
2874
solver: SolverBase, state: TField, dt: float
2975
) -> FixedStepperType:
@@ -42,7 +88,7 @@ def _make_fixed_stepper(
4288
single_step = solver._make_single_step_fixed_dt(state, dt)
4389
single_step_signature = (nb.typeof(state.data), nb.double)
4490
single_step = jit(single_step_signature)(single_step)
45-
post_step_hook = solver._make_post_step_hook(state)
91+
post_step_hook = _make_post_step_hook(solver, state)
4692

4793
# provide compiled function doing all steps
4894
fixed_stepper_signature = (
@@ -86,7 +132,7 @@ def _make_adams_bashforth_stepper(
86132
raise NotImplementedError
87133

88134
rhs_pde = solver._make_pde_rhs(state, backend=solver.backend)
89-
post_step_hook = solver._make_post_step_hook(state)
135+
post_step_hook = _make_post_step_hook(solver, state)
90136
sig_single_step = (nb.typeof(state.data), nb.double, nb.typeof(state.data))
91137

92138
@jit(sig_single_step)
@@ -168,11 +214,12 @@ def _make_adaptive_stepper_general(
168214
single_step_error = solver._make_single_step_error_estimate(state)
169215
signature_single_step = (nb.typeof(state.data), nb.double, nb.double)
170216
single_step_error = jit(signature_single_step)(single_step_error)
171-
post_step_hook = solver._make_post_step_hook(state)
217+
post_step_hook = _make_post_step_hook(solver, state)
172218
sync_errors = solver._make_error_synchronizer()
173219

174220
# obtain auxiliary functions
175-
adjust_dt = solver._make_dt_adjuster()
221+
signature = (nb.double, nb.double)
222+
adjust_dt = jit(signature)(_make_dt_adjuster(solver.dt_min, solver.dt_max))
176223
tolerance = solver.tolerance
177224
dt_min = solver.dt_min
178225

@@ -231,7 +278,7 @@ def adaptive_stepper(
231278
adaptive_stepper = jit(signature_stepper)(adaptive_stepper)
232279

233280
solver._logger.info("Initialized adaptive stepper")
234-
return adaptive_stepper # type: ignore
281+
return adaptive_stepper
235282

236283

237284
def _make_adaptive_stepper_euler(
@@ -251,11 +298,21 @@ def _make_adaptive_stepper_euler(
251298
time `t_end`. The function call signature is `(state: numpy.ndarray,
252299
t_start: float, t_end: float)`
253300
"""
254-
stepper = solver._make_adaptive_stepper(state)
255301
if nb.config.DISABLE_JIT:
256302
# this can be useful to debug numba implementations and for test coverage checks
257-
return stepper
303+
return solver._make_adaptive_stepper(state)
258304
else:
305+
# create compiled function for adjusting the time step
306+
adjust_dt = _make_dt_adjuster(solver.dt_min, solver.dt_max)
307+
adjust_signature = (nb.double, nb.double)
308+
adjust_dt = jit(adjust_signature)(adjust_dt)
309+
310+
# create the adaptive stepper and compile it
311+
stepper = solver._make_adaptive_stepper(
312+
state,
313+
post_step_hook=_make_post_step_hook(solver, state),
314+
adjust_dt=adjust_dt,
315+
)
259316
signature = (
260317
nb.typeof(state.data),
261318
nb.double,

pde/solvers/base.py

Lines changed: 74 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@
1313
from inspect import isabstract
1414
from typing import TYPE_CHECKING, Any, Callable, Union
1515

16-
import numba as nb
1716
import numpy as np
1817
from numba.extending import register_jitable
1918

2019
from ..pdes.base import PDEBase
2120
from ..tools.math import OnlineStatistics
2221
from ..tools.misc import classproperty
23-
from ..tools.numba import is_jitted, jit
22+
from ..tools.numba import is_jitted
2423
from ..tools.typing import BackendType, NumericArray, StepperHook, TField
2524

2625
if TYPE_CHECKING:
@@ -205,12 +204,6 @@ def _backend_obj(self) -> BackendBase:
205204

206205
return self.__backend_obj
207206

208-
@property
209-
def _compiled(self) -> bool:
210-
"""bool: indicates whether functions need to be compiled"""
211-
jit_enabled = not nb.config.DISABLE_JIT
212-
return jit_enabled and self.backend != "numpy"
213-
214207
def _make_error_synchronizer(
215208
self, operator: int | str = "MAX"
216209
) -> Callable[[float], float]:
@@ -283,12 +276,6 @@ def post_step_hook(
283276
# ensure that the initial values is a mutable array
284277
self._post_step_data_init = np.array(self._post_step_data_init, copy=True)
285278

286-
self._post_step_data_type = nb.typeof(self._post_step_data_init)
287-
if self._compiled:
288-
sig_hook = (nb.typeof(state.data), nb.float64, self._post_step_data_type)
289-
post_step_hook = jit(sig_hook)(post_step_hook)
290-
self._logger.debug("Compiled post-step hook")
291-
292279
return post_step_hook # type: ignore
293280

294281
def _check_backend(self, rhs: Callable) -> None:
@@ -498,54 +485,6 @@ def __init__(
498485
self.adaptive = adaptive
499486
self.tolerance = tolerance
500487

501-
def _make_dt_adjuster(self) -> Callable[[float, float], float]:
502-
"""Return a function that can be used to adjust time steps."""
503-
dt_min = self.dt_min
504-
dt_min_nan_err = f"Encountered NaN even though dt < {dt_min}"
505-
dt_min_err = f"Time step below {dt_min}"
506-
dt_max = self.dt_max
507-
508-
def adjust_dt(dt: float, error_rel: float) -> float:
509-
"""Helper function that adjust the time step.
510-
511-
The goal is to keep the relative error `error_rel` close to 1.
512-
513-
Args:
514-
dt (float): Current time step
515-
error_rel (float): Current (normalized) error estimate
516-
517-
Returns:
518-
float: Time step of the next iteration
519-
"""
520-
# adjust the time step
521-
if error_rel < 0.00057665:
522-
# error was very small => maximal increase in dt
523-
# The constant on the right hand side of the comparison is chosen to
524-
# agree with the equation for adjusting dt below
525-
dt *= 4.0
526-
elif np.isnan(error_rel):
527-
# state contained NaN => decrease time step strongly
528-
dt *= 0.25
529-
else:
530-
# otherwise, adjust time step according to error
531-
dt *= max(0.9 * error_rel**-0.2, 0.1)
532-
533-
# limit time step to permissible bracket
534-
if dt > dt_max:
535-
dt = dt_max
536-
elif dt < dt_min:
537-
if np.isnan(error_rel):
538-
raise RuntimeError(dt_min_nan_err)
539-
else:
540-
raise RuntimeError(dt_min_err)
541-
542-
return dt
543-
544-
if self._compiled:
545-
adjust_dt = jit((nb.double, nb.double))(adjust_dt)
546-
547-
return adjust_dt
548-
549488
def _make_single_step_variable_dt(
550489
self, state: TField
551490
) -> Callable[[NumericArray, float, float], NumericArray]:
@@ -583,9 +522,6 @@ def _make_single_step_error_estimate(
583522
raise RuntimeError("Cannot use adaptive stepper with stochastic equation")
584523

585524
single_step = self._make_single_step_variable_dt(state)
586-
if self._compiled:
587-
sig_single_step = (nb.typeof(state.data), nb.double, nb.double)
588-
single_step = jit(sig_single_step)(single_step)
589525

590526
def single_step_error_estimate(
591527
state_data: NumericArray, t: float, dt: float
@@ -605,13 +541,18 @@ def single_step_error_estimate(
605541

606542
return single_step_error_estimate
607543

608-
def _make_adaptive_stepper(self, state: TField) -> AdaptiveStepperType:
544+
def _make_adaptive_stepper(
545+
self, state: TField, *, adjust_dt: Callable[[float, float], float] | None = None
546+
) -> AdaptiveStepperType:
609547
"""Make an adaptive Euler stepper.
610548
611549
Args:
612550
state (:class:`~pde.fields.base.FieldBase`):
613551
An example for the state from which the grid and other information can
614552
be extracted
553+
adjust_dt (callable or None):
554+
A function that is used to adjust the time step. The function takes the
555+
current time step and a relative error and returns an adjusted time step.
615556
616557
Returns:
617558
Function that can be called to advance the `state` from time `t_start` to
@@ -624,15 +565,11 @@ def _make_adaptive_stepper(self, state: TField) -> AdaptiveStepperType:
624565
sync_errors = self._make_error_synchronizer()
625566

626567
# obtain auxiliary functions
627-
adjust_dt = self._make_dt_adjuster()
568+
if adjust_dt is None:
569+
adjust_dt = _make_dt_adjuster(self.dt_min, self.dt_max)
628570
tolerance = self.tolerance
629571
dt_min = self.dt_min
630572

631-
if self._compiled:
632-
# compile paired stepper
633-
sig_stepper = (nb.typeof(state.data), nb.double, nb.double)
634-
single_step_error = jit(sig_stepper)(single_step_error)
635-
636573
def adaptive_stepper(
637574
state_data: NumericArray,
638575
t_start: float,
@@ -674,18 +611,6 @@ def adaptive_stepper(
674611

675612
return t, dt_opt, steps
676613

677-
if self._compiled:
678-
# compile inner stepper
679-
sig_adaptive = (
680-
nb.typeof(state.data),
681-
nb.double,
682-
nb.double,
683-
nb.double,
684-
nb.typeof(self.info["dt_statistics"]),
685-
self._post_step_data_type,
686-
)
687-
adaptive_stepper = jit(sig_adaptive)(adaptive_stepper)
688-
689614
self._logger.info("Initialized adaptive stepper")
690615
return adaptive_stepper
691616

@@ -754,3 +679,68 @@ def wrapped_stepper(state: TField, t_start: float, t_end: float) -> float:
754679
return t_last
755680

756681
return wrapped_stepper
682+
683+
684+
def _make_dt_adjuster(dt_min: float, dt_max: float) -> Callable[[float, float], float]:
685+
"""Return a function that can be used to adjust time steps.
686+
687+
The returned function adjust_dt(dt, error_rel) adjusts the current time step
688+
`dt` based on the normalized error estimate `error_rel` with the goal of
689+
keeping `error_rel` close to 1.
690+
691+
Behavior:
692+
- If the error is very small the time step is increased (up to a factor 4).
693+
- If the error is NaN the time step is reduced strongly.
694+
- Otherwise the time step is scaled according to error_rel**-0.2 with a
695+
conservative lower bound for the scaling factor.
696+
- The adjusted time step is clamped to the interval [dt_min, dt_max].
697+
- If the adjusted time step falls below dt_min a RuntimeError is raised.
698+
699+
Args:
700+
dt_min (float): Minimal allowed time step.
701+
dt_max (float): Maximal allowed time step.
702+
703+
Returns:
704+
Callable[[float, float], float]:
705+
Function that takes (dt, error_rel) and returns the adjusted dt.
706+
"""
707+
dt_min_nan_err = f"Encountered NaN even though dt < {dt_min}"
708+
dt_min_err = f"Time step below {dt_min}"
709+
710+
def adjust_dt(dt: float, error_rel: float) -> float:
711+
"""Helper function that adjust the time step.
712+
713+
The goal is to keep the relative error `error_rel` close to 1.
714+
715+
Args:
716+
dt (float): Current time step
717+
error_rel (float): Current (normalized) error estimate
718+
719+
Returns:
720+
float: Time step of the next iteration
721+
"""
722+
# adjust the time step
723+
if error_rel < 0.00057665:
724+
# error was very small => maximal increase in dt
725+
# The constant on the right hand side of the comparison is chosen to
726+
# agree with the equation for adjusting dt below
727+
dt *= 4.0
728+
elif np.isnan(error_rel):
729+
# state contained NaN => decrease time step strongly
730+
dt *= 0.25
731+
else:
732+
# otherwise, adjust time step according to error
733+
dt *= max(0.9 * error_rel**-0.2, 0.1)
734+
735+
# limit time step to permissible bracket
736+
if dt > dt_max:
737+
dt = dt_max
738+
elif dt < dt_min:
739+
if np.isnan(error_rel):
740+
raise RuntimeError(dt_min_nan_err)
741+
else:
742+
raise RuntimeError(dt_min_err)
743+
744+
return dt
745+
746+
return adjust_dt

pde/solvers/crank_nicolson.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from typing import Callable
99

10-
import numba as nb
1110
import numpy as np
1211

1312
from ..pdes.base import PDEBase

0 commit comments

Comments
 (0)