Skip to content

Commit 7c02321

Browse files
committed
Moved some code from solvers to backend
Split the ExplicitSolver into EulerSolver and RungeKuttaSolver
1 parent 93b96a2 commit 7c02321

18 files changed

+540
-265
lines changed

pde/backends/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import inspect
99
import logging
1010
from collections import defaultdict
11-
from typing import Any, Callable, NamedTuple
11+
from typing import Any, Callable, Literal, NamedTuple
1212

1313
from ..fields import DataFieldBase
1414
from ..grids import BoundariesBase, GridBase
@@ -342,9 +342,13 @@ def make_sde_rhs(
342342
f"SDE right hand side not defined for backend {self.name}"
343343
)
344344

345-
def make_stepper(
346-
self, solver: SolverBase, state: TField, dt: float | None = None
347-
) -> Callable[[TField, float, float], float]:
345+
def make_inner_stepper(
346+
self,
347+
solver: SolverBase,
348+
stepper_style: Literal["fixed", "adaptive"],
349+
state: TField,
350+
dt: float,
351+
) -> Callable:
348352
"""Return a stepper function using an explicit scheme.
349353
350354
Args:
@@ -353,8 +357,8 @@ def make_stepper(
353357
state (:class:`~pde.fields.base.FieldBase`):
354358
An example for the state from which the grid and other information can
355359
be extracted
356-
dt (float):
357-
Time step used (Uses :attr:`SolverBase.dt_default` if `None`)
360+
stepper_style (str):
361+
Determines how the stepper is expected to work
358362
359363
Returns:
360364
Function that can be called to advance the `state` from time `t_start` to

pde/backends/numba/backend.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations
77

88
import functools
9-
from typing import Callable
9+
from typing import Callable, Literal
1010

1111
import numba as nb
1212
import numpy as np
@@ -16,6 +16,7 @@
1616
from ...fields import DataFieldBase, VectorField
1717
from ...grids import BoundariesBase, DimensionError, GridBase
1818
from ...pdes import PDEBase
19+
from ...solvers import AdaptiveSolverBase, SolverBase
1920
from ...tools.numba import get_common_numba_dtype, jit, make_array_constructor
2021
from ...tools.typing import (
2122
DataSetter,
@@ -703,3 +704,38 @@ def make_sde_rhs(
703704
together with a noise realization.
704705
"""
705706
return eq._make_sde_rhs_numba_cached(state, **kwargs)
707+
708+
def make_inner_stepper(
709+
self,
710+
solver: SolverBase,
711+
stepper_style: Literal["fixed", "adaptive"],
712+
state: TField,
713+
dt: float,
714+
) -> Callable:
715+
"""Return a stepper function using an explicit scheme.
716+
717+
Args:
718+
solver (:class:`~pde.solvers.base.SolverBase`):
719+
The solver instance, which determines how the stepper is constructed
720+
state (:class:`~pde.fields.base.FieldBase`):
721+
An example for the state from which the grid and other information can
722+
be extracted
723+
dt (float):
724+
Time step used (Uses :attr:`SolverBase.dt_default` if `None`)
725+
726+
Returns:
727+
Function that can be called to advance the `state` from time `t_start` to
728+
time `t_end`. The function call signature is `(state: numpy.ndarray,
729+
t_start: float, t_end: float)`
730+
"""
731+
assert solver.backend == "numba"
732+
733+
from .solvers import make_adaptive_stepper, make_fixed_stepper
734+
735+
if stepper_style == "fixed":
736+
return make_fixed_stepper(solver, state, dt=dt)
737+
elif stepper_style == "adaptive":
738+
assert isinstance(solver, AdaptiveSolverBase)
739+
return make_adaptive_stepper(solver, state)
740+
else:
741+
raise NotImplementedError

pde/backends/numba/solvers.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""Implements numba-accelerated solvers.
2+
3+
.. codeauthor:: David Zwicker <[email protected]>
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from typing import Callable
9+
10+
import numba as nb
11+
import numpy as np
12+
13+
from ...solvers import *
14+
from ...solvers.base import (
15+
AdaptiveSolverBase,
16+
AdaptiveStepperType,
17+
FixedStepperType,
18+
SolverBase,
19+
)
20+
from ...tools.math import OnlineStatistics
21+
from ...tools.numba import jit
22+
from ...tools.typing import NumericArray, TField
23+
24+
SingleStepType = Callable[[NumericArray, float], None]
25+
26+
27+
def _make_fixed_stepper(
28+
solver: SolverBase, state: TField, dt: float
29+
) -> FixedStepperType:
30+
"""Return a stepper function using an explicit scheme with fixed time steps.
31+
32+
Args:
33+
solver (:class:`~pde.solvers.base.SolverBase`):
34+
The solver instance, which determines how the stepper is constructed
35+
state (:class:`~pde.fields.base.FieldBase`):
36+
An example for the state from which the grid and other information can
37+
be extracted
38+
dt (float):
39+
Time step of the explicit stepping.
40+
"""
41+
# get compiled version of a single step
42+
single_step = solver._make_single_step_fixed_dt(state, dt)
43+
single_step_signature = (nb.typeof(state.data), nb.double)
44+
single_step = jit(single_step_signature)(single_step)
45+
post_step_hook = solver._make_post_step_hook(state)
46+
47+
# provide compiled function doing all steps
48+
fixed_stepper_signature = (
49+
nb.typeof(state.data),
50+
nb.double,
51+
nb.int_,
52+
nb.typeof(solver._post_step_data_init),
53+
)
54+
55+
@jit(fixed_stepper_signature)
56+
def fixed_stepper(
57+
state_data: NumericArray, t_start: float, steps: int, post_step_data
58+
) -> float:
59+
"""Perform `steps` steps with fixed time steps."""
60+
for i in range(steps):
61+
# calculate the right hand side
62+
t = t_start + i * dt
63+
single_step(state_data, t)
64+
post_step_hook(state_data, t, post_step_data)
65+
66+
return t + dt
67+
68+
return fixed_stepper # type: ignore
69+
70+
71+
def _make_adams_bashforth_stepper(
72+
solver: AdamsBashforthSolver, state: TField, dt: float
73+
) -> FixedStepperType:
74+
"""Return a stepper function using an explicit scheme with fixed time steps.
75+
76+
Args:
77+
solver (:class:`~pde.solvers.adams_bashforth.AdamsBashforthSolver`):
78+
The solver instance, which determines how the stepper is constructed
79+
state (:class:`~pde.fields.base.FieldBase`):
80+
An example for the state from which the grid and other information can
81+
be extracted
82+
dt (float):
83+
Time step of the explicit stepping.
84+
"""
85+
if solver.pde.is_sde:
86+
raise NotImplementedError
87+
88+
rhs_pde = solver._make_pde_rhs(state, backend=solver.backend)
89+
post_step_hook = solver._make_post_step_hook(state)
90+
sig_single_step = (nb.typeof(state.data), nb.double, nb.typeof(state.data))
91+
92+
@jit(sig_single_step)
93+
def single_step(
94+
state_data: NumericArray, t: float, state_prev: NumericArray
95+
) -> None:
96+
"""Perform a single Adams-Bashforth step."""
97+
rhs_prev = rhs_pde(state_prev, t - dt).copy()
98+
rhs_cur = rhs_pde(state_data, t)
99+
state_prev[:] = state_data # save the previous state
100+
state_data += dt * (1.5 * rhs_cur - 0.5 * rhs_prev)
101+
102+
# allocate memory to store the state of the previous time step
103+
state_prev = np.empty_like(state.data)
104+
init_state_prev = True
105+
106+
def fixed_stepper(
107+
state_data: NumericArray, t_start: float, steps: int, post_step_data
108+
) -> float:
109+
"""Perform `steps` steps with fixed time steps."""
110+
nonlocal state_prev, init_state_prev
111+
112+
if init_state_prev:
113+
# initialize the state_prev with an estimate of the previous step
114+
state_prev[:] = state_data - dt * rhs_pde(state_data, t_start)
115+
init_state_prev = False
116+
117+
for i in range(steps):
118+
# calculate the right hand side
119+
t = t_start + i * dt
120+
single_step(state_data, t, state_prev)
121+
post_step_hook(state_data, t, post_step_data=post_step_data)
122+
123+
return t + dt
124+
125+
solver._logger.info("Init explicit Adams-Bashforth stepper with dt=%g", dt)
126+
127+
return fixed_stepper
128+
129+
130+
def make_fixed_stepper(
131+
solver: SolverBase, state: TField, dt: float
132+
) -> FixedStepperType:
133+
"""Return a stepper function using an explicit scheme with fixed time steps.
134+
135+
Args:
136+
solver (:class:`~pde.solvers.base.SolverBase`):
137+
The solver instance, which determines how the stepper is constructed
138+
state (:class:`~pde.fields.base.FieldBase`):
139+
An example for the state from which the grid and other information can
140+
be extracted
141+
dt (float):
142+
Time step of the explicit stepping.
143+
"""
144+
if isinstance(solver, AdamsBashforthSolver):
145+
return _make_adams_bashforth_stepper(solver, state, dt)
146+
else:
147+
return _make_fixed_stepper(solver, state, dt)
148+
149+
150+
def _make_adaptive_stepper_general(
151+
solver: AdaptiveSolverBase, state: TField
152+
) -> AdaptiveStepperType:
153+
"""Return a stepper function using an explicit scheme.
154+
155+
Args:
156+
solver (:class:`~pde.solvers.base.AdaptiveSolverBase`):
157+
The solver instance, which determines how the stepper is constructed
158+
state (:class:`~pde.fields.base.FieldBase`):
159+
An example for the state from which the grid and other information can
160+
be extracted
161+
162+
Returns:
163+
Function that can be called to advance the `state` from time `t_start` to
164+
time `t_end`. The function call signature is `(state: numpy.ndarray,
165+
t_start: float, t_end: float)`
166+
"""
167+
# obtain functions determining how the PDE is evolved
168+
single_step_error = solver._make_single_step_error_estimate(state)
169+
signature_single_step = (nb.typeof(state.data), nb.double, nb.double)
170+
single_step_error = jit(signature_single_step)(single_step_error)
171+
post_step_hook = solver._make_post_step_hook(state)
172+
sync_errors = solver._make_error_synchronizer()
173+
174+
# obtain auxiliary functions
175+
adjust_dt = solver._make_dt_adjuster()
176+
tolerance = solver.tolerance
177+
dt_min = solver.dt_min
178+
179+
signature_stepper = (
180+
nb.typeof(state.data),
181+
nb.double,
182+
nb.double,
183+
nb.double,
184+
nb.typeof(solver.info["dt_statistics"]),
185+
nb.typeof(solver._post_step_data_init),
186+
)
187+
188+
@jit(signature_stepper)
189+
def adaptive_stepper(
190+
state_data: NumericArray,
191+
t_start: float,
192+
t_end: float,
193+
dt_init: float,
194+
dt_stats: OnlineStatistics | None = None,
195+
post_step_data=None,
196+
) -> tuple[float, float, int]:
197+
"""Adaptive stepper that advances the state in time."""
198+
dt_opt = dt_init
199+
t = t_start
200+
steps = 0
201+
while True:
202+
# use a smaller (but not too small) time step if close to t_end
203+
dt_step = max(min(dt_opt, t_end - t), dt_min)
204+
205+
# try two different step sizes to estimate errors
206+
new_state, error = single_step_error(state_data, t, dt_step)
207+
208+
error_rel = error / tolerance # normalize error to given tolerance
209+
# synchronize the error between all processes (necessary for MPI)
210+
error_rel = sync_errors(error_rel)
211+
212+
# do the step if the error is sufficiently small
213+
if error_rel <= 1:
214+
steps += 1
215+
t += dt_step
216+
state_data[...] = new_state
217+
post_step_hook(state_data, t, post_step_data)
218+
219+
if dt_stats is not None:
220+
dt_stats.add(dt_step)
221+
222+
if t < t_end:
223+
# adjust the time step and continue (happens in every MPI process)
224+
dt_opt = adjust_dt(dt_step, error_rel)
225+
else:
226+
break # return to the controller
227+
228+
return t, dt_opt, steps
229+
230+
solver._logger.info("Initialized adaptive stepper")
231+
return adaptive_stepper # type: ignore
232+
233+
234+
def _make_adaptive_stepper_euler(
235+
solver: EulerSolver, state: TField
236+
) -> AdaptiveStepperType:
237+
"""Return a stepper function using an explicit scheme.
238+
239+
Args:
240+
solver (:class:`~pde.solvers.explicit.EulerSolver`):
241+
The solver instance, which determines how the stepper is constructed
242+
state (:class:`~pde.fields.base.FieldBase`):
243+
An example for the state from which the grid and other information can
244+
be extracted
245+
246+
Returns:
247+
Function that can be called to advance the `state` from time `t_start` to
248+
time `t_end`. The function call signature is `(state: numpy.ndarray,
249+
t_start: float, t_end: float)`
250+
"""
251+
stepper = solver._make_adaptive_stepper(state)
252+
signature = (
253+
nb.typeof(state.data),
254+
nb.double,
255+
nb.double,
256+
nb.double,
257+
nb.typeof(solver.info["dt_statistics"]),
258+
nb.typeof(solver._post_step_data_init),
259+
)
260+
return jit(signature)(stepper) # type: ignore
261+
262+
263+
def make_adaptive_stepper(
264+
solver: AdaptiveSolverBase, state: TField
265+
) -> AdaptiveStepperType:
266+
"""Return a stepper function using an explicit scheme.
267+
268+
Args:
269+
solver (:class:`~pde.solvers.base.AdaptiveSolverBase`):
270+
The solver instance, which determines how the stepper is constructed
271+
state (:class:`~pde.fields.base.FieldBase`):
272+
An example for the state from which the grid and other information can
273+
be extracted
274+
275+
Returns:
276+
Function that can be called to advance the `state` from time `t_start` to
277+
time `t_end`. The function call signature is `(state: numpy.ndarray,
278+
t_start: float, t_end: float)`
279+
"""
280+
if isinstance(solver, EulerSolver):
281+
return _make_adaptive_stepper_euler(solver, state)
282+
else:
283+
return _make_adaptive_stepper_general(solver, state)

0 commit comments

Comments
 (0)