Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch optree exception when changing backends #20886

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

t-kalinowski
Copy link
Contributor

This PR fixes an exception raised when keras.config.set_backend("tensorflow") is called.

ValueError: PyTree type <class 'tensorflow.python.trackable.data_structures.ListWrapper'> is already registered in namespace 'keras'.

To reproduce:

import keras
keras.config.set_backend("jax")
keras.config.set_backend("tensorflow")

Example python session with all output:

$ /home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/bin/python3
Python 3.11.11 (main, Jan 14 2025, 22:49:08) [Clang 19.1.6 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import keras
2025-02-09 07:46:12.849103: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> keras.config.set_backend("jax")
>>> keras.config.set_backend("tensorflow")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/backend_utils.py", line 130, in set_backend
    import keras
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/__init__.py", line 2, in <module>
    from keras.api import DTypePolicy
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/api/__init__.py", line 8, in <module>
    from keras.api import activations
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/api/activations/__init__.py", line 7, in <module>
    from keras.src.activations import deserialize
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/__init__.py", line 1, in <module>
    from keras.src import activations
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/activations/__init__.py", line 3, in <module>
    from keras.src.activations.activations import celu
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/activations/activations.py", line 1, in <module>
    from keras.src import backend
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/__init__.py", line 10, in <module>
    from keras.src.backend.common.dtypes import result_type
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/common/__init__.py", line 2, in <module>
    from keras.src.backend.common.dtypes import result_type
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/common/dtypes.py", line 5, in <module>
    from keras.src.backend.common.variables import standardize_dtype
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/common/variables.py", line 11, in <module>
    from keras.src.utils.module_utils import tensorflow as tf
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/__init__.py", line 1, in <module>
    from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/audio_dataset_utils.py", line 4, in <module>
    from keras.src.utils import dataset_utils
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/dataset_utils.py", line 9, in <module>
    from keras.src import tree
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/tree/__init__.py", line 1, in <module>
    from keras.src.tree.tree_api import assert_same_paths
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/tree/tree_api.py", line 8, in <module>
    from keras.src.tree import optree_impl as tree_impl
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/tree/optree_impl.py", line 16, in <module>
    optree.register_pytree_node(
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/optree/registry.py", line 319, in register_pytree_node
    _C.register_node(
ValueError: PyTree type <class 'tensorflow.python.trackable.data_structures.ListWrapper'> is already registered in namespace 'keras'.
>>> keras.__version__
'3.8.0'

@t-kalinowski t-kalinowski force-pushed the config-set-backend-fix branch from 7df740a to f90a7a4 Compare February 9, 2025 13:01
@codecov-commenter
Copy link

codecov-commenter commented Feb 9, 2025

Codecov Report

Attention: Patch coverage is 76.47059% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.24%. Comparing base (b9a49ea) to head (6cba7d0).

Files with missing lines Patch % Lines
keras/src/tree/optree_impl.py 77.77% 2 Missing ⚠️
keras/src/utils/tracking.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20886      +/-   ##
==========================================
- Coverage   82.24%   82.24%   -0.01%     
==========================================
  Files         561      561              
  Lines       52647    52655       +8     
  Branches     8136     8136              
==========================================
+ Hits        43302    43306       +4     
- Misses       7341     7345       +4     
  Partials     2004     2004              
Flag Coverage Δ
keras 82.05% <76.47%> (-0.01%) ⬇️
keras-jax 64.21% <35.29%> (-0.01%) ⬇️
keras-numpy 59.02% <35.29%> (-0.01%) ⬇️
keras-openvino 32.51% <35.29%> (+<0.01%) ⬆️
keras-tensorflow 64.84% <76.47%> (-0.01%) ⬇️
keras-torch 64.27% <35.29%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@t-kalinowski
Copy link
Contributor Author

Calling set_backend() seems to come with a few other issues, besides this one.

E.g., when attempting to make a simple mnist model:

AttributeError: 'Functional' object has no attribute 'name'

getting past that, attempting to call model.summary():

AttributeError: 'Functional' object has no attribute '_operations'

Switching back to jax, recreating the same model:

TypeError: unexpected PRNG key type <class 'tensorflow.python.framework.ops.EagerTensor'>

I think it might be easier to restructure the project to avoid needing set_backend() and instead start fresh sessions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Assigned Reviewer
Development

Successfully merging this pull request may close these issues.

3 participants