Skip to content

Commit c2f0c24

Browse files
committed
Remove leading underscores as in jax-ml/jax@760deb3
1 parent 346a314 commit c2f0c24

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pmwd/ode_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from jax._src import core
3535
from jax import custom_derivatives
3636
from jax import lax
37-
from jax._src.numpy.util import _promote_dtypes_inexact
37+
from jax._src.numpy.util import promote_dtypes_inexact
3838
from jax._src.util import safe_map, safe_zip
3939
from jax.flatten_util import ravel_pytree
4040
from jax.tree_util import tree_leaves, tree_map
@@ -80,7 +80,7 @@ def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
8080
# Algorithm from:
8181
# E. Hairer, S. P. Norsett G. Wanner,
8282
# Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
83-
y0, f0 = _promote_dtypes_inexact(y0, f0)
83+
y0, f0 = promote_dtypes_inexact(y0, f0)
8484
dtype = y0.dtype
8585

8686
scale = atol + jnp.abs(y0) * rtol

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
packages=find_packages(),
2525
python_requires='>=3.7',
2626
install_requires=[
27-
'jax',
27+
'jax>=0.4.7',
2828
'numpy>=1.20', # numpy.typing
2929
'mcfit',
3030
],

0 commit comments

Comments
 (0)