Skip to content

Commit 786e2f4

Browse files
committed
Unpin JAX and Tensorflow versions.
To get the latest ones: JAX 0.5.2 and Tensorflow 2.19.0
1 parent decd6ba commit 786e2f4

9 files changed

+13
-18
lines changed

integration_tests/import_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def setup_package():
2828
whl_path = re.findall(
2929
r"[^\s]*\.whl",
3030
build_process.stdout,
31-
)[-1]
31+
)
3232
if not whl_path:
3333
print(build_process.stderr)
3434
raise ValueError("Installing Keras package unsuccessful. ")
35-
return whl_path
35+
return whl_path[-1]
3636

3737

3838
def create_virtualenv():

keras/src/backend/jax/image.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def affine_transform(
465465
# transform the indices
466466
coordinates = jnp.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
467467
coordinates = jnp.moveaxis(coordinates, source=-1, destination=1)
468-
coordinates += jnp.reshape(a=offset, shape=(*offset.shape, 1, 1, 1))
468+
coordinates += jnp.reshape(offset, shape=(*offset.shape, 1, 1, 1))
469469

470470
# apply affine transformation
471471
_map_coordinates = functools.partial(

keras/src/backend/torch/image.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def affine_transform(
424424
# transform the indices
425425
coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
426426
coordinates = torch.moveaxis(coordinates, source=-1, destination=1)
427-
coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1))
427+
coordinates += torch.reshape(offset, shape=(*offset.shape, 1, 1, 1))
428428

429429
# Note: torch.stack is faster than torch.vmap when the batch size is small.
430430
affined = torch.stack(

keras/src/export/onnx.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def export_onnx(model, filepath, verbose=None, input_signature=None, **kwargs):
8080
decorated_fn = get_concrete_fn(model, input_signature, **kwargs)
8181

8282
# Use `tf2onnx` to convert the `decorated_fn` to the ONNX format.
83-
patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2.
83+
# patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2.
8484
tf2onnx.convert.from_function(
8585
decorated_fn, input_signature, output_path=filepath
8686
)

requirements-common.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ absl-py
1010
requests
1111
h5py
1212
ml-dtypes
13-
protobuf
13+
protobuf>=4.21.6 # Earlier versions break Tensorflow>=2.19
1414
tensorboard-plugin-profile
1515
rich
1616
build
1717
optree
1818
pytest-cov
1919
packaging
20+
tf2onnx>=1.16.1 # For Numpy 2 support
2021
# for tree_test.py
2122
dm_tree
2223
coverage!=7.6.5 # 7.6.5 breaks CI

requirements-jax-cuda.txt

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.0
3-
tf2onnx
2+
tensorflow-cpu
43

54
# Torch cpu-only version (needed for testing).
65
--extra-index-url https://download.pytorch.org/whl/cpu
76
torch==2.6.0+cpu
87

98
# Jax with cuda support.
109
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
11-
jax[cuda12]==0.4.28
10+
jax[cuda12]
1211
flax
1312

1413
-r requirements-common.txt

requirements-tensorflow-cuda.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Tensorflow with cuda support.
2-
tensorflow[and-cuda]~=2.18.0
3-
tf2onnx
2+
tensorflow[and-cuda]
43

54
# Torch cpu-only version (needed for testing).
65
--extra-index-url https://download.pytorch.org/whl/cpu

requirements-torch-cuda.txt

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Tensorflow cpu-only version (needed for testing).
22
tensorflow-cpu~=2.18.0
3-
tf2onnx
43

54
# Torch with cuda support.
65
# - torch is pinned to a version that is compatible with torch-xla

requirements.txt

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
# Tensorflow.
2-
tensorflow-cpu~=2.18.0;sys_platform != 'darwin'
3-
tensorflow~=2.18.0;sys_platform == 'darwin'
2+
tensorflow-cpu~=2.19.0;sys_platform != 'darwin'
3+
tensorflow~=2.19.0;sys_platform == 'darwin'
44
tf_keras
5-
tf2onnx
65

76
# Torch.
87
--extra-index-url https://download.pytorch.org/whl/cpu
98
torch==2.6.0+cpu
109
torch-xla==2.6.0;sys_platform != 'darwin'
1110

1211
# Jax.
13-
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
14-
# Note that we test against the latest JAX on GPU.
15-
jax[cpu]==0.5.0
12+
jax[cpu]
1613
flax
1714

1815
# Common deps.

0 commit comments

Comments
 (0)