Skip to content

Commit f90a7a4

Browse files
committed
catch optree exception when changing backends
1 parent b9a49ea commit f90a7a4

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

keras/src/tree/optree_impl.py

+25-19
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,32 @@ def register_tree_node_class(cls):
1313
from tensorflow.python.trackable.data_structures import ListWrapper
1414
from tensorflow.python.trackable.data_structures import _DictWrapper
1515

16-
optree.register_pytree_node(
17-
ListWrapper,
18-
lambda x: (x, None),
19-
lambda metadata, children: ListWrapper(list(children)),
20-
namespace="keras",
21-
)
16+
try:
17+
optree.register_pytree_node(
18+
ListWrapper,
19+
lambda x: (x, None),
20+
lambda metadata, children: ListWrapper(list(children)),
21+
namespace="keras",
22+
)
23+
24+
def sorted_keys_and_values(d):
25+
keys = sorted(list(d.keys()))
26+
values = [d[k] for k in keys]
27+
return values, keys, keys
28+
29+
optree.register_pytree_node(
30+
_DictWrapper,
31+
sorted_keys_and_values,
32+
lambda metadata, children: _DictWrapper(
33+
{key: child for key, child in zip(metadata, children)}
34+
),
35+
namespace="keras",
36+
)
37+
except ValueError:
38+
# optree raises a ValueError if the class is already registered.
39+
# Triggered if config.set_backend() is called multiple times.
40+
pass
2241

23-
def sorted_keys_and_values(d):
24-
keys = sorted(list(d.keys()))
25-
values = [d[k] for k in keys]
26-
return values, keys, keys
27-
28-
optree.register_pytree_node(
29-
_DictWrapper,
30-
sorted_keys_and_values,
31-
lambda metadata, children: _DictWrapper(
32-
{key: child for key, child in zip(metadata, children)}
33-
),
34-
namespace="keras",
35-
)
3642

3743

3844
def is_nested(structure):

keras/src/utils/tracking.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ def wrapper(*args, **kwargs):
2727

2828
return wrapper
2929

30+
def safe_register_tree_node_class(cls):
31+
try:
32+
return tree.register_tree_node_class(cls)
33+
except ValueError:
34+
# optree raises a ValueError if the class is already registered.
35+
# Triggered if config.set_backend() is called multiple times.
36+
return cls
37+
38+
3039

3140
class Tracker:
3241
"""Attribute tracker, used for e.g. Variable tracking.
@@ -133,7 +142,7 @@ def replace_tracked_value(self, store_name, old_value, new_value):
133142
self.stored_ids[store_name].add(id(new_value))
134143

135144

136-
@tree.register_tree_node_class
145+
@safe_register_tree_node_class
137146
class TrackedList(list):
138147
def __init__(self, values=None, tracker=None):
139148
self.tracker = tracker
@@ -194,7 +203,7 @@ def tree_unflatten(cls, metadata, children):
194203
return cls(children)
195204

196205

197-
@tree.register_tree_node_class
206+
@safe_register_tree_node_class
198207
class TrackedDict(dict):
199208
def __init__(self, values=None, tracker=None):
200209
self.tracker = tracker
@@ -245,7 +254,7 @@ def tree_unflatten(cls, keys, values):
245254
return cls(zip(keys, values))
246255

247256

248-
@tree.register_tree_node_class
257+
@safe_register_tree_node_class
249258
class TrackedSet(set):
250259
def __init__(self, values=None, tracker=None):
251260
self.tracker = tracker

0 commit comments

Comments
 (0)