From f90a7a48e8f1220ab13796bf22a95789b1982c81 Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Sun, 9 Feb 2025 07:48:10 -0500 Subject: [PATCH 1/2] catch optree exception when changing backends --- keras/src/tree/optree_impl.py | 44 ++++++++++++++++++++--------------- keras/src/utils/tracking.py | 15 +++++++++--- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/keras/src/tree/optree_impl.py b/keras/src/tree/optree_impl.py index 1aad9c2b135..1e85a254157 100644 --- a/keras/src/tree/optree_impl.py +++ b/keras/src/tree/optree_impl.py @@ -13,26 +13,32 @@ def register_tree_node_class(cls): from tensorflow.python.trackable.data_structures import ListWrapper from tensorflow.python.trackable.data_structures import _DictWrapper - optree.register_pytree_node( - ListWrapper, - lambda x: (x, None), - lambda metadata, children: ListWrapper(list(children)), - namespace="keras", - ) + try: + optree.register_pytree_node( + ListWrapper, + lambda x: (x, None), + lambda metadata, children: ListWrapper(list(children)), + namespace="keras", + ) + + def sorted_keys_and_values(d): + keys = sorted(list(d.keys())) + values = [d[k] for k in keys] + return values, keys, keys + + optree.register_pytree_node( + _DictWrapper, + sorted_keys_and_values, + lambda metadata, children: _DictWrapper( + {key: child for key, child in zip(metadata, children)} + ), + namespace="keras", + ) + except ValueError: + # optree raises a ValueError if the class is already registered. + # Triggered if config.set_backend() is called multiple times. + pass - def sorted_keys_and_values(d): - keys = sorted(list(d.keys())) - values = [d[k] for k in keys] - return values, keys, keys - - optree.register_pytree_node( - _DictWrapper, - sorted_keys_and_values, - lambda metadata, children: _DictWrapper( - {key: child for key, child in zip(metadata, children)} - ), - namespace="keras", - ) def is_nested(structure): diff --git a/keras/src/utils/tracking.py b/keras/src/utils/tracking.py index a2a26679937..f2e5f6e5379 100644 --- a/keras/src/utils/tracking.py +++ b/keras/src/utils/tracking.py @@ -27,6 +27,15 @@ def wrapper(*args, **kwargs): return wrapper +def safe_register_tree_node_class(cls): + try: + return tree.register_tree_node_class(cls) + except ValueError: + # optree raises a ValueError if the class is already registered. + # Triggered if config.set_backend() is called multiple times. + return cls + + class Tracker: """Attribute tracker, used for e.g. Variable tracking. @@ -133,7 +142,7 @@ def replace_tracked_value(self, store_name, old_value, new_value): self.stored_ids[store_name].add(id(new_value)) -@tree.register_tree_node_class +@safe_register_tree_node_class class TrackedList(list): def __init__(self, values=None, tracker=None): self.tracker = tracker @@ -194,7 +203,7 @@ def tree_unflatten(cls, metadata, children): return cls(children) -@tree.register_tree_node_class +@safe_register_tree_node_class class TrackedDict(dict): def __init__(self, values=None, tracker=None): self.tracker = tracker @@ -245,7 +254,7 @@ def tree_unflatten(cls, keys, values): return cls(zip(keys, values)) -@tree.register_tree_node_class +@safe_register_tree_node_class class TrackedSet(set): def __init__(self, values=None, tracker=None): self.tracker = tracker From 6cba7d0818d1516c508e7f8b5030380b674b229a Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Sun, 9 Feb 2025 08:54:32 -0500 Subject: [PATCH 2/2] format --- keras/src/tree/optree_impl.py | 1 - keras/src/utils/tracking.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/keras/src/tree/optree_impl.py b/keras/src/tree/optree_impl.py index 1e85a254157..26126d47034 100644 --- a/keras/src/tree/optree_impl.py +++ b/keras/src/tree/optree_impl.py @@ -40,7 +40,6 @@ def sorted_keys_and_values(d): pass - def is_nested(structure): return not optree.tree_is_leaf( structure, none_is_leaf=True, namespace="keras" diff --git a/keras/src/utils/tracking.py b/keras/src/utils/tracking.py index f2e5f6e5379..44623d4ad65 100644 --- a/keras/src/utils/tracking.py +++ b/keras/src/utils/tracking.py @@ -27,6 +27,7 @@ def wrapper(*args, **kwargs): return wrapper + def safe_register_tree_node_class(cls): try: return tree.register_tree_node_class(cls) @@ -36,7 +37,6 @@ def safe_register_tree_node_class(cls): return cls - class Tracker: """Attribute tracker, used for e.g. Variable tracking.