File tree Expand file tree Collapse file tree 2 files changed +3
-5
lines changed Expand file tree Collapse file tree 2 files changed +3
-5
lines changed Original file line number Diff line number Diff line change 22# SPDX-License-Identifier: Apache-2.0
33
44import jax
5- from jax ._src .pjit import pjit_p
65from jax .api_util import flatten_fun , shaped_abstractify
76import jax .core as core
7+ from jax .experimental .pjit import pjit_p
88from jax .interpreters .partial_eval import trace_to_jaxpr_dynamic
99from jax .interpreters .pxla import xla_pmap_p
10- from jax .interpreters .xla import xla_call_p
1110import jax .linear_util as lu
1211import jax .numpy as jnp
1312
@@ -102,7 +101,6 @@ def track_deps_call_rule(eqn, provenance_inputs):
102101
103102
104103track_deps_rules [core .call_p ] = track_deps_call_rule
105- track_deps_rules [xla_call_p ] = track_deps_call_rule
106104track_deps_rules [xla_pmap_p ] = track_deps_call_rule
107105
108106
Original file line number Diff line number Diff line change 99from setuptools import find_packages , setup
1010
1111PROJECT_PATH = os .path .dirname (os .path .abspath (__file__ ))
12- _jax_version_constraints = ">=0.4"
13- _jaxlib_version_constraints = ">=0.4"
12+ _jax_version_constraints = ">=0.4.7 "
13+ _jaxlib_version_constraints = ">=0.4.7 "
1414
1515# Find version
1616for line in open (os .path .join (PROJECT_PATH , "numpyro" , "version.py" )):
You can’t perform that action at this time.
0 commit comments