@@ -136,15 +136,35 @@ def apply(
136
136
# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276
137
137
138
138
139
- def _replace_tuple_with_list (spec : pytree .TreeSpec ) -> pytree .TreeSpec :
140
- _type = list if spec .type == tuple else spec .type
141
- return pytree .TreeSpec (
142
- _type , spec .context , list (map (_replace_tuple_with_list , spec .children_specs ))
139
+ # TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame.
140
+ class _DummyLeaf : # use a class instead.
141
+ pass
142
+
143
+
144
+ def _replace_list_with_tuple (spec : pytree .TreeSpec ) -> pytree .TreeSpec :
145
+ def replace_list_with_tuple (x : Any ) -> Any :
146
+ if type (x ) is list :
147
+ return pytree .tree_map (
148
+ replace_list_with_tuple ,
149
+ tuple (x ),
150
+ is_leaf = lambda x : type (x ) is list ,
151
+ )
152
+ return x
153
+
154
+ dummy_leaf = _DummyLeaf ()
155
+ dummy_tree = pytree .tree_unflatten ([dummy_leaf ] * spec .num_leaves , spec )
156
+ dummy_tree = pytree .tree_map (
157
+ replace_list_with_tuple ,
158
+ dummy_tree ,
159
+ is_leaf = lambda x : type (x ) is list ,
143
160
)
161
+ return pytree .tree_structure (dummy_tree )
144
162
145
163
146
- def _open_top_level_list_if_single_element (spec : pytree .TreeSpec ) -> pytree .TreeSpec :
147
- if spec .type == list and spec .num_children == 1 :
164
+ def _open_top_level_sequence_if_single_element (
165
+ spec : pytree .TreeSpec ,
166
+ ) -> pytree .TreeSpec :
167
+ if spec .type in (tuple , list ) and spec .num_children == 1 :
148
168
return spec .children_specs [0 ]
149
169
return spec
150
170
@@ -167,10 +187,10 @@ def _assert_identical_pytree_spec(
167
187
pass_if_any_checks : Sequence [Callable [[], bool ]] = [
168
188
lambda : spec1 == spec2 ,
169
189
# FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
170
- lambda : _replace_tuple_with_list (spec1 ) == _replace_tuple_with_list (spec2 ),
190
+ lambda : _replace_list_with_tuple (spec1 ) == _replace_list_with_tuple (spec2 ),
171
191
# FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
172
- lambda : _open_top_level_list_if_single_element (spec1 ) == spec2 ,
173
- lambda : spec1 == _open_top_level_list_if_single_element (spec2 ),
192
+ lambda : _open_top_level_sequence_if_single_element (spec1 ) == spec2 ,
193
+ lambda : spec1 == _open_top_level_sequence_if_single_element (spec2 ),
174
194
]
175
195
176
196
if not any (check () for check in pass_if_any_checks ):
0 commit comments