Skip to content

Commit 1ef5c69

Browse files
committed
Fixed some minor problems
1 parent 7c02321 commit 1ef5c69

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

pde/backends/numba/solvers.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,6 @@ def _make_adaptive_stepper_general(
176176
tolerance = solver.tolerance
177177
dt_min = solver.dt_min
178178

179-
signature_stepper = (
180-
nb.typeof(state.data),
181-
nb.double,
182-
nb.double,
183-
nb.double,
184-
nb.typeof(solver.info["dt_statistics"]),
185-
nb.typeof(solver._post_step_data_init),
186-
)
187-
188-
@jit(signature_stepper)
189179
def adaptive_stepper(
190180
state_data: NumericArray,
191181
t_start: float,
@@ -227,6 +217,19 @@ def adaptive_stepper(
227217

228218
return t, dt_opt, steps
229219

220+
if not nb.config.DISABLE_JIT:
221+
# do the compilation only when JIT is actually being done. This might be
222+
# disabled for debugging numba code or for determining test coverage
223+
signature_stepper = (
224+
nb.typeof(state.data),
225+
nb.double,
226+
nb.double,
227+
nb.double,
228+
nb.typeof(solver.info["dt_statistics"]),
229+
nb.typeof(solver._post_step_data_init),
230+
)
231+
adaptive_stepper = jit(signature_stepper)(adaptive_stepper)
232+
230233
solver._logger.info("Initialized adaptive stepper")
231234
return adaptive_stepper # type: ignore
232235

@@ -249,15 +252,19 @@ def _make_adaptive_stepper_euler(
249252
t_start: float, t_end: float)`
250253
"""
251254
stepper = solver._make_adaptive_stepper(state)
252-
signature = (
253-
nb.typeof(state.data),
254-
nb.double,
255-
nb.double,
256-
nb.double,
257-
nb.typeof(solver.info["dt_statistics"]),
258-
nb.typeof(solver._post_step_data_init),
259-
)
260-
return jit(signature)(stepper) # type: ignore
255+
if nb.config.DISABLE_JIT:
256+
# this can be useful to debug numba implementations and for test coverage checks
257+
return stepper
258+
else:
259+
signature = (
260+
nb.typeof(state.data),
261+
nb.double,
262+
nb.double,
263+
nb.double,
264+
nb.typeof(solver.info["dt_statistics"]),
265+
nb.typeof(solver._post_step_data_init),
266+
)
267+
return jit(signature)(stepper) # type: ignore
261268

262269

263270
def make_adaptive_stepper(

tests/solvers/test_explicit_mpi_solvers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def test_multiple_pdes_mpi(backend, rng):
9393
"t_range": 1.01,
9494
"dt": 0.1,
9595
"adaptive": True,
96-
"scheme": "euler",
9796
"tracker": None,
9897
"ret_info": True,
9998
}

0 commit comments

Comments
 (0)