1313from inspect import isabstract
1414from typing import TYPE_CHECKING , Any , Callable , Union
1515
16- import numba as nb
1716import numpy as np
1817from numba .extending import register_jitable
1918
2019from ..pdes .base import PDEBase
2120from ..tools .math import OnlineStatistics
2221from ..tools .misc import classproperty
23- from ..tools .numba import is_jitted , jit
22+ from ..tools .numba import is_jitted
2423from ..tools .typing import BackendType , NumericArray , StepperHook , TField
2524
2625if TYPE_CHECKING :
@@ -205,12 +204,6 @@ def _backend_obj(self) -> BackendBase:
205204
206205 return self .__backend_obj
207206
208- @property
209- def _compiled (self ) -> bool :
210- """bool: indicates whether functions need to be compiled"""
211- jit_enabled = not nb .config .DISABLE_JIT
212- return jit_enabled and self .backend != "numpy"
213-
214207 def _make_error_synchronizer (
215208 self , operator : int | str = "MAX"
216209 ) -> Callable [[float ], float ]:
@@ -283,12 +276,6 @@ def post_step_hook(
283276 # ensure that the initial values is a mutable array
284277 self ._post_step_data_init = np .array (self ._post_step_data_init , copy = True )
285278
286- self ._post_step_data_type = nb .typeof (self ._post_step_data_init )
287- if self ._compiled :
288- sig_hook = (nb .typeof (state .data ), nb .float64 , self ._post_step_data_type )
289- post_step_hook = jit (sig_hook )(post_step_hook )
290- self ._logger .debug ("Compiled post-step hook" )
291-
292279 return post_step_hook # type: ignore
293280
294281 def _check_backend (self , rhs : Callable ) -> None :
@@ -498,54 +485,6 @@ def __init__(
498485 self .adaptive = adaptive
499486 self .tolerance = tolerance
500487
501- def _make_dt_adjuster (self ) -> Callable [[float , float ], float ]:
502- """Return a function that can be used to adjust time steps."""
503- dt_min = self .dt_min
504- dt_min_nan_err = f"Encountered NaN even though dt < { dt_min } "
505- dt_min_err = f"Time step below { dt_min } "
506- dt_max = self .dt_max
507-
508- def adjust_dt (dt : float , error_rel : float ) -> float :
509- """Helper function that adjust the time step.
510-
511- The goal is to keep the relative error `error_rel` close to 1.
512-
513- Args:
514- dt (float): Current time step
515- error_rel (float): Current (normalized) error estimate
516-
517- Returns:
518- float: Time step of the next iteration
519- """
520- # adjust the time step
521- if error_rel < 0.00057665 :
522- # error was very small => maximal increase in dt
523- # The constant on the right hand side of the comparison is chosen to
524- # agree with the equation for adjusting dt below
525- dt *= 4.0
526- elif np .isnan (error_rel ):
527- # state contained NaN => decrease time step strongly
528- dt *= 0.25
529- else :
530- # otherwise, adjust time step according to error
531- dt *= max (0.9 * error_rel ** - 0.2 , 0.1 )
532-
533- # limit time step to permissible bracket
534- if dt > dt_max :
535- dt = dt_max
536- elif dt < dt_min :
537- if np .isnan (error_rel ):
538- raise RuntimeError (dt_min_nan_err )
539- else :
540- raise RuntimeError (dt_min_err )
541-
542- return dt
543-
544- if self ._compiled :
545- adjust_dt = jit ((nb .double , nb .double ))(adjust_dt )
546-
547- return adjust_dt
548-
549488 def _make_single_step_variable_dt (
550489 self , state : TField
551490 ) -> Callable [[NumericArray , float , float ], NumericArray ]:
@@ -583,9 +522,6 @@ def _make_single_step_error_estimate(
583522 raise RuntimeError ("Cannot use adaptive stepper with stochastic equation" )
584523
585524 single_step = self ._make_single_step_variable_dt (state )
586- if self ._compiled :
587- sig_single_step = (nb .typeof (state .data ), nb .double , nb .double )
588- single_step = jit (sig_single_step )(single_step )
589525
590526 def single_step_error_estimate (
591527 state_data : NumericArray , t : float , dt : float
@@ -605,13 +541,18 @@ def single_step_error_estimate(
605541
606542 return single_step_error_estimate
607543
608- def _make_adaptive_stepper (self , state : TField ) -> AdaptiveStepperType :
544+ def _make_adaptive_stepper (
545+ self , state : TField , * , adjust_dt : Callable [[float , float ], float ] | None = None
546+ ) -> AdaptiveStepperType :
609547 """Make an adaptive Euler stepper.
610548
611549 Args:
612550 state (:class:`~pde.fields.base.FieldBase`):
613551 An example for the state from which the grid and other information can
614552 be extracted
553+ adjust_dt (callable or None):
554+ A function that is used to adjust the time step. The function takes the
555+ current time step and a relative error and returns an adjusted time step.
615556
616557 Returns:
617558 Function that can be called to advance the `state` from time `t_start` to
@@ -624,15 +565,11 @@ def _make_adaptive_stepper(self, state: TField) -> AdaptiveStepperType:
624565 sync_errors = self ._make_error_synchronizer ()
625566
626567 # obtain auxiliary functions
627- adjust_dt = self ._make_dt_adjuster ()
568+ if adjust_dt is None :
569+ adjust_dt = _make_dt_adjuster (self .dt_min , self .dt_max )
628570 tolerance = self .tolerance
629571 dt_min = self .dt_min
630572
631- if self ._compiled :
632- # compile paired stepper
633- sig_stepper = (nb .typeof (state .data ), nb .double , nb .double )
634- single_step_error = jit (sig_stepper )(single_step_error )
635-
636573 def adaptive_stepper (
637574 state_data : NumericArray ,
638575 t_start : float ,
@@ -674,18 +611,6 @@ def adaptive_stepper(
674611
675612 return t , dt_opt , steps
676613
677- if self ._compiled :
678- # compile inner stepper
679- sig_adaptive = (
680- nb .typeof (state .data ),
681- nb .double ,
682- nb .double ,
683- nb .double ,
684- nb .typeof (self .info ["dt_statistics" ]),
685- self ._post_step_data_type ,
686- )
687- adaptive_stepper = jit (sig_adaptive )(adaptive_stepper )
688-
689614 self ._logger .info ("Initialized adaptive stepper" )
690615 return adaptive_stepper
691616
@@ -754,3 +679,68 @@ def wrapped_stepper(state: TField, t_start: float, t_end: float) -> float:
754679 return t_last
755680
756681 return wrapped_stepper
682+
683+
684+ def _make_dt_adjuster (dt_min : float , dt_max : float ) -> Callable [[float , float ], float ]:
685+ """Return a function that can be used to adjust time steps.
686+
687+ The returned function adjust_dt(dt, error_rel) adjusts the current time step
688+ `dt` based on the normalized error estimate `error_rel` with the goal of
689+ keeping `error_rel` close to 1.
690+
691+ Behavior:
692+ - If the error is very small the time step is increased (up to a factor 4).
693+ - If the error is NaN the time step is reduced strongly.
694+ - Otherwise the time step is scaled according to error_rel**-0.2 with a
695+ conservative lower bound for the scaling factor.
696+ - The adjusted time step is clamped to the interval [dt_min, dt_max].
697+ - If the adjusted time step falls below dt_min a RuntimeError is raised.
698+
699+ Args:
700+ dt_min (float): Minimal allowed time step.
701+ dt_max (float): Maximal allowed time step.
702+
703+ Returns:
704+ Callable[[float, float], float]:
705+ Function that takes (dt, error_rel) and returns the adjusted dt.
706+ """
707+ dt_min_nan_err = f"Encountered NaN even though dt < { dt_min } "
708+ dt_min_err = f"Time step below { dt_min } "
709+
710+ def adjust_dt (dt : float , error_rel : float ) -> float :
711+ """Helper function that adjust the time step.
712+
713+ The goal is to keep the relative error `error_rel` close to 1.
714+
715+ Args:
716+ dt (float): Current time step
717+ error_rel (float): Current (normalized) error estimate
718+
719+ Returns:
720+ float: Time step of the next iteration
721+ """
722+ # adjust the time step
723+ if error_rel < 0.00057665 :
724+ # error was very small => maximal increase in dt
725+ # The constant on the right hand side of the comparison is chosen to
726+ # agree with the equation for adjusting dt below
727+ dt *= 4.0
728+ elif np .isnan (error_rel ):
729+ # state contained NaN => decrease time step strongly
730+ dt *= 0.25
731+ else :
732+ # otherwise, adjust time step according to error
733+ dt *= max (0.9 * error_rel ** - 0.2 , 0.1 )
734+
735+ # limit time step to permissible bracket
736+ if dt > dt_max :
737+ dt = dt_max
738+ elif dt < dt_min :
739+ if np .isnan (error_rel ):
740+ raise RuntimeError (dt_min_nan_err )
741+ else :
742+ raise RuntimeError (dt_min_err )
743+
744+ return dt
745+
746+ return adjust_dt
0 commit comments