Skip to content

Conversation

@t-kalinowski
Copy link
Contributor

This PR fixes an exception raised when keras.config.set_backend("tensorflow") is called.

ValueError: PyTree type <class 'tensorflow.python.trackable.data_structures.ListWrapper'> is already registered in namespace 'keras'.

To reproduce:

import keras
keras.config.set_backend("jax")
keras.config.set_backend("tensorflow")

Example python session with all output:

$ /home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/bin/python3
Python 3.11.11 (main, Jan 14 2025, 22:49:08) [Clang 19.1.6 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import keras
2025-02-09 07:46:12.849103: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> keras.config.set_backend("jax")
>>> keras.config.set_backend("tensorflow")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/backend_utils.py", line 130, in set_backend
    import keras
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/__init__.py", line 2, in <module>
    from keras.api import DTypePolicy
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/api/__init__.py", line 8, in <module>
    from keras.api import activations
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/api/activations/__init__.py", line 7, in <module>
    from keras.src.activations import deserialize
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/__init__.py", line 1, in <module>
    from keras.src import activations
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/activations/__init__.py", line 3, in <module>
    from keras.src.activations.activations import celu
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/activations/activations.py", line 1, in <module>
    from keras.src import backend
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/__init__.py", line 10, in <module>
    from keras.src.backend.common.dtypes import result_type
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/common/__init__.py", line 2, in <module>
    from keras.src.backend.common.dtypes import result_type
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/common/dtypes.py", line 5, in <module>
    from keras.src.backend.common.variables import standardize_dtype
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/backend/common/variables.py", line 11, in <module>
    from keras.src.utils.module_utils import tensorflow as tf
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/__init__.py", line 1, in <module>
    from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/audio_dataset_utils.py", line 4, in <module>
    from keras.src.utils import dataset_utils
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/utils/dataset_utils.py", line 9, in <module>
    from keras.src import tree
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/tree/__init__.py", line 1, in <module>
    from keras.src.tree.tree_api import assert_same_paths
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/tree/tree_api.py", line 8, in <module>
    from keras.src.tree import optree_impl as tree_impl
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/keras/src/tree/optree_impl.py", line 16, in <module>
    optree.register_pytree_node(
  File "/home/tomasz/.cache/r-reticulate/uv-cache/archive-v0/dToXXMV7F8HhQ1QmEA2SZ/lib/python3.11/site-packages/optree/registry.py", line 319, in register_pytree_node
    _C.register_node(
ValueError: PyTree type <class 'tensorflow.python.trackable.data_structures.ListWrapper'> is already registered in namespace 'keras'.
>>> keras.__version__
'3.8.0'

@t-kalinowski t-kalinowski force-pushed the config-set-backend-fix branch from 7df740a to f90a7a4 Compare February 9, 2025 13:01
@codecov-commenter
Copy link

codecov-commenter commented Feb 9, 2025

Codecov Report

Attention: Patch coverage is 76.47059% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.24%. Comparing base (b9a49ea) to head (6cba7d0).

Files with missing lines Patch % Lines
keras/src/tree/optree_impl.py 77.77% 2 Missing ⚠️
keras/src/utils/tracking.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20886      +/-   ##
==========================================
- Coverage   82.24%   82.24%   -0.01%     
==========================================
  Files         561      561              
  Lines       52647    52655       +8     
  Branches     8136     8136              
==========================================
+ Hits        43302    43306       +4     
- Misses       7341     7345       +4     
  Partials     2004     2004              
Flag Coverage Δ
keras 82.05% <76.47%> (-0.01%) ⬇️
keras-jax 64.21% <35.29%> (-0.01%) ⬇️
keras-numpy 59.02% <35.29%> (-0.01%) ⬇️
keras-openvino 32.51% <35.29%> (+<0.01%) ⬆️
keras-tensorflow 64.84% <76.47%> (-0.01%) ⬇️
keras-torch 64.27% <35.29%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@t-kalinowski
Copy link
Contributor Author

Calling set_backend() seems to come with a few other issues, besides this one.

E.g., when attempting to make a simple mnist model:

AttributeError: 'Functional' object has no attribute 'name'

getting past that, attempting to call model.summary():

AttributeError: 'Functional' object has no attribute '_operations'

Switching back to jax, recreating the same model:

TypeError: unexpected PRNG key type <class 'tensorflow.python.framework.ops.EagerTensor'>

I think it might be easier to restructure the project to avoid needing set_backend() and instead start fresh sessions.

Comment on lines -16 to -21
optree.register_pytree_node(
ListWrapper,
lambda x: (x, None),
lambda metadata, children: ListWrapper(list(children)),
namespace="keras",
)
Copy link

@XuehaiPan XuehaiPan Mar 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing unregister first would be better than ignoring arbitrary exceptions.

    with contextlib.suppress(ValueError):
        optree.unregister_pytree_node(
            ListWrapper,
            namespace="keras",
        )

    optree.register_pytree_node(
        ListWrapper,
        lambda x: (x, None),
        lambda metadata, children: ListWrapper(list(children)),
        namespace="keras",
    )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the note! If I recall correctly, unregistering was unsupported for the types when I tried while drafting this PR. There was a new optree release recently, so maybe this changed. It's worth a try. I agree that approach would be better if it works.

Also, this set of changes is already on master after #21049

@hertschuh
Copy link
Collaborator

This was addressed in #21049

@hertschuh hertschuh closed this Aug 8, 2025
@github-project-automation github-project-automation bot moved this from Assigned Reviewer to Closed/Rejected in PR Queue Aug 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

Status: Closed/Rejected

Development

Successfully merging this pull request may close these issues.

5 participants