@@ -13,26 +13,32 @@ def register_tree_node_class(cls):
13
13
from tensorflow .python .trackable .data_structures import ListWrapper
14
14
from tensorflow .python .trackable .data_structures import _DictWrapper
15
15
16
- optree .register_pytree_node (
17
- ListWrapper ,
18
- lambda x : (x , None ),
19
- lambda metadata , children : ListWrapper (list (children )),
20
- namespace = "keras" ,
21
- )
16
+ try :
17
+ optree .register_pytree_node (
18
+ ListWrapper ,
19
+ lambda x : (x , None ),
20
+ lambda metadata , children : ListWrapper (list (children )),
21
+ namespace = "keras" ,
22
+ )
23
+
24
+ def sorted_keys_and_values (d ):
25
+ keys = sorted (list (d .keys ()))
26
+ values = [d [k ] for k in keys ]
27
+ return values , keys , keys
28
+
29
+ optree .register_pytree_node (
30
+ _DictWrapper ,
31
+ sorted_keys_and_values ,
32
+ lambda metadata , children : _DictWrapper (
33
+ {key : child for key , child in zip (metadata , children )}
34
+ ),
35
+ namespace = "keras" ,
36
+ )
37
+ except ValueError :
38
+ # optree raises a ValueError if the class is already registered.
39
+ # Triggered if config.set_backend() is called multiple times.
40
+ pass
22
41
23
- def sorted_keys_and_values (d ):
24
- keys = sorted (list (d .keys ()))
25
- values = [d [k ] for k in keys ]
26
- return values , keys , keys
27
-
28
- optree .register_pytree_node (
29
- _DictWrapper ,
30
- sorted_keys_and_values ,
31
- lambda metadata , children : _DictWrapper (
32
- {key : child for key , child in zip (metadata , children )}
33
- ),
34
- namespace = "keras" ,
35
- )
36
42
37
43
38
44
def is_nested (structure ):
0 commit comments