Skip to content

Commit 59a188d

Browse files
authored
avoid breakage in old jax version without jax.extend (#1647)
* avoid breakage in old jax version without jax.extend * fix lint
1 parent 6e3f007 commit 59a188d

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

numpyro/ops/provenance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from jax.api_util import flatten_fun, shaped_abstractify
66
import jax.core as core
77
from jax.experimental.pjit import pjit_p
8-
import jax.extend.linear_util as lu
8+
9+
try:
10+
import jax.extend.linear_util as lu
11+
except ImportError:
12+
import jax.linear_util as lu
913
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
1014
from jax.interpreters.pxla import xla_pmap_p
1115
import jax.numpy as jnp

test/ops/test_provenance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
import jax
99
from jax.api_util import flatten_fun_nokwargs
1010
import jax.core as core
11-
import jax.extend.linear_util as lu
11+
12+
try:
13+
import jax.extend.linear_util as lu
14+
except ImportError:
15+
import jax.linear_util as lu
1216
import jax.numpy as jnp
1317

1418
from numpyro.ops.provenance import eval_provenance

0 commit comments

Comments
 (0)