Skip to content

Commit

Permalink
Make Numba the default backend for testing
Browse files Browse the repository at this point in the history
Do not merge this commit; it is intended for temporary use in a draft PR.
  • Loading branch information
brandonwillard committed Nov 10, 2022
1 parent 9436553 commit 72a268c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
runs-on: ubuntu-latest
if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
strategy:
fail-fast: true
fail-fast: false
matrix:
python-version: ["3.7", "3.9"]
fast-compile: [0]
Expand Down Expand Up @@ -132,7 +132,7 @@ jobs:
if [[ $FAST_COMPILE == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,mode=FAST_COMPILE; fi
if [[ $FLOAT32 == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,floatX=float32; fi
export AESARA_FLAGS=$AESARA_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
python -m pytest -x -r A --verbose --runslow --cov=aesara/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
python -m pytest --verbose --runslow --cov=aesara/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
env:
MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
MKL_THREADING_LAYER: GNU
Expand Down
8 changes: 4 additions & 4 deletions aesara/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
"c": CLinker(), # Don't support gc. so don't check allow_gc
"c|py": OpWiseCLinker(), # Use allow_gc Aesara flag
"c|py_nogc": OpWiseCLinker(allow_gc=False),
"vm": VMLinker(use_cloop=False), # Use allow_gc Aesara flag
"vm": NumbaLinker(), # VMLinker(use_cloop=False), # Use allow_gc Aesara flag
"cvm": VMLinker(use_cloop=True), # Use allow_gc Aesara flag
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"vm_nogc": NumbaLinker(), # VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
"numba": NumbaLinker(),
Expand Down Expand Up @@ -441,9 +441,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
FAST_RUN = Mode("numba", "fast_run")
else:
FAST_RUN = Mode("vm", "fast_run")
FAST_RUN = Mode("numba", "fast_run")

JAX = Mode(
JAXLinker(),
Expand Down
17 changes: 14 additions & 3 deletions aesara/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def add_compile_configvars():
config.add(
"mode",
"Default compilation mode",
ConfigParam("Mode", apply=_filter_mode),
ConfigParam("NUMBA", apply=_filter_mode),
in_c_key=False,
)

Expand Down Expand Up @@ -463,7 +463,18 @@ def add_compile_configvars():
"linker",
"Default linker used if the aesara flags mode is Mode",
EnumStr(
"cvm", ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
"numba",
[
"c|py",
"py",
"c",
"c|py_nogc",
"vm",
"vm_nogc",
"cvm_nogc",
"numba",
"jax",
],
),
in_c_key=False,
)
Expand All @@ -473,7 +484,7 @@ def add_compile_configvars():
config.add(
"linker",
"Default linker used if the aesara flags mode is Mode",
EnumStr("vm", ["py", "vm_nogc"]),
EnumStr("numba", ["py", "vm_nogc", "vm", "numba", "jax"]),
in_c_key=False,
)
if type(config).cxx.is_default:
Expand Down
20 changes: 15 additions & 5 deletions tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class TestScan:
"rng_type",
[
np.random.default_rng,
np.random.RandomState,
# np.random.RandomState,
],
)
def test_inner_graph_cloning(self, rng_type):
Expand Down Expand Up @@ -396,7 +396,7 @@ def f_pow2(x_tm1):
assert all(i.value is None for i in scan_node.op.fn.input_storage)
assert all(o.value is None for o in scan_node.op.fn.output_storage)

@pytest.mark.parametrize("mode", [Mode(linker="py"), Mode(linker="cvm")])
@pytest.mark.parametrize("mode", ["NUMBA", Mode(linker="py"), Mode(linker="cvm")])
@pytest.mark.parametrize(
"x_init",
[
Expand All @@ -421,7 +421,12 @@ def f_pow(x_tm1):
assert res.dtype == exp_res.dtype

@pytest.mark.parametrize(
"mode", [Mode(linker="py", optimizer=None), Mode(linker="cvm", optimizer=None)]
"mode",
[
"NUMBA",
Mode(linker="py", optimizer=None),
Mode(linker="cvm", optimizer=None),
],
)
@pytest.mark.parametrize(
"x",
Expand Down Expand Up @@ -459,7 +464,12 @@ def inner_fn(x_seq, x_i):
assert res.dtype == exp_res.dtype

@pytest.mark.parametrize(
"mode", [Mode(linker="py", optimizer=None), Mode(linker="cvm", optimizer=None)]
"mode",
[
"NUMBA",
Mode(linker="py", optimizer=None),
Mode(linker="cvm", optimizer=None),
],
)
@pytest.mark.parametrize(
"x",
Expand Down Expand Up @@ -1126,7 +1136,7 @@ def test_inner_grad(self):
utt.assert_allclose(out, vR)

@pytest.mark.parametrize(
"mode", [Mode(linker="cvm", optimizer=None), Mode(linker="cvm")]
"mode", ["NUMBA", Mode(linker="cvm", optimizer=None), Mode(linker="cvm")]
)
def test_sequence_is_scan(self, mode):
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
Expand Down

0 comments on commit 72a268c

Please sign in to comment.