File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change 55from jax .api_util import flatten_fun , shaped_abstractify
66import jax .core as core
77from jax .experimental .pjit import pjit_p
8- import jax .extend .linear_util as lu
8+
9+ try :
10+ import jax .extend .linear_util as lu
11+ except ImportError :
12+ import jax .linear_util as lu
913from jax .interpreters .partial_eval import trace_to_jaxpr_dynamic
1014from jax .interpreters .pxla import xla_pmap_p
1115import jax .numpy as jnp
Original file line number Diff line number Diff line change 88import jax
99from jax .api_util import flatten_fun_nokwargs
1010import jax .core as core
11- import jax .extend .linear_util as lu
11+
12+ try :
13+ import jax .extend .linear_util as lu
14+ except ImportError :
15+ import jax .linear_util as lu
1216import jax .numpy as jnp
1317
1418from numpyro .ops .provenance import eval_provenance
You can’t perform that action at this time.
0 commit comments