Skip to content

Commit dac30fa

Browse files
authored
Fix optree regsitration (#21049)
The following will break as reimporting Keras will try to re-register the Tensorflow list/dict wrappers. Presumably anything that forced an actual reimport of `keras` would trigger the same crash. ```python import keras keras.config.set_backend("tensorflow") ```
1 parent 6e688ab commit dac30fa

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

Diff for: keras/src/tree/optree_impl.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,29 @@ def register_tree_node_class(cls):
1313
from tensorflow.python.trackable.data_structures import ListWrapper
1414
from tensorflow.python.trackable.data_structures import _DictWrapper
1515

16-
optree.register_pytree_node(
17-
ListWrapper,
18-
lambda x: (x, None),
19-
lambda metadata, children: ListWrapper(list(children)),
20-
namespace="keras",
21-
)
16+
try:
17+
optree.register_pytree_node(
18+
ListWrapper,
19+
lambda x: (x, None),
20+
lambda metadata, children: ListWrapper(list(children)),
21+
namespace="keras",
22+
)
2223

23-
def sorted_keys_and_values(d):
24-
keys = sorted(list(d.keys()))
25-
values = [d[k] for k in keys]
26-
return values, keys, keys
27-
28-
optree.register_pytree_node(
29-
_DictWrapper,
30-
sorted_keys_and_values,
31-
lambda metadata, children: _DictWrapper(
32-
{key: child for key, child in zip(metadata, children)}
33-
),
34-
namespace="keras",
35-
)
24+
def sorted_keys_and_values(d):
25+
keys = sorted(list(d.keys()))
26+
values = [d[k] for k in keys]
27+
return values, keys, keys
28+
29+
optree.register_pytree_node(
30+
_DictWrapper,
31+
sorted_keys_and_values,
32+
lambda metadata, children: _DictWrapper(
33+
{key: child for key, child in zip(metadata, children)}
34+
),
35+
namespace="keras",
36+
)
37+
except ValueError:
38+
pass # We may have already registered if we are reiporting keras.
3639

3740

3841
def is_nested(structure):

0 commit comments

Comments
 (0)