Skip to content

Commit 86d4b7d

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[FX][export][dynamo] use tuple instead of list in normalized args_spec (pytorch#138212)
Pull Request resolved: pytorch#138212 Approved by: https://github.com/jansel
1 parent ce63193 commit 86d4b7d

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

torch/_inductor/pattern_matcher.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -582,18 +582,25 @@ def simple_flatten(
582582
def pytree_flatten(
583583
args: Sequence[Any], kwargs: Mapping[Any, Any]
584584
) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
585-
def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
586-
if s.type is None:
587-
return s
588-
mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
589-
return pytree.TreeSpec(
590-
mapping.get(s.type, s.type),
591-
s.context,
592-
list(map(norm_spec, s.children_specs)),
593-
)
585+
type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict}
586+
587+
def convert_type(x: Any) -> Any:
588+
cls = type(x)
589+
convert_fn = type_mapping.get(cls)
590+
if convert_fn is not None:
591+
return pytree.tree_map(
592+
convert_type,
593+
convert_fn(x),
594+
is_leaf=lambda x: type(x) in type_mapping,
595+
)
596+
return x
594597

595-
flat, spec = pytree.tree_flatten([args, kwargs])
596-
spec = norm_spec(spec)
598+
normalized_args_tree = pytree.tree_map(
599+
convert_type,
600+
(args, kwargs),
601+
is_leaf=lambda x: type(x) in type_mapping,
602+
)
603+
flat, spec = pytree.tree_flatten(normalized_args_tree)
597604
return flat, spec
598605

599606
def __repr__(self) -> str:

torch/onnx/_internal/io_adapter.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,35 @@ def apply(
136136
# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276
137137

138138

139-
def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec:
140-
_type = list if spec.type == tuple else spec.type
141-
return pytree.TreeSpec(
142-
_type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs))
139+
# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame.
140+
class _DummyLeaf: # use a class instead.
141+
pass
142+
143+
144+
def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec:
145+
def replace_list_with_tuple(x: Any) -> Any:
146+
if type(x) is list:
147+
return pytree.tree_map(
148+
replace_list_with_tuple,
149+
tuple(x),
150+
is_leaf=lambda x: type(x) is list,
151+
)
152+
return x
153+
154+
dummy_leaf = _DummyLeaf()
155+
dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec)
156+
dummy_tree = pytree.tree_map(
157+
replace_list_with_tuple,
158+
dummy_tree,
159+
is_leaf=lambda x: type(x) is list,
143160
)
161+
return pytree.tree_structure(dummy_tree)
144162

145163

146-
def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec:
147-
if spec.type == list and spec.num_children == 1:
164+
def _open_top_level_sequence_if_single_element(
165+
spec: pytree.TreeSpec,
166+
) -> pytree.TreeSpec:
167+
if spec.type in (tuple, list) and spec.num_children == 1:
148168
return spec.children_specs[0]
149169
return spec
150170

@@ -167,10 +187,10 @@ def _assert_identical_pytree_spec(
167187
pass_if_any_checks: Sequence[Callable[[], bool]] = [
168188
lambda: spec1 == spec2,
169189
# FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
170-
lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2),
190+
lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2),
171191
# FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
172-
lambda: _open_top_level_list_if_single_element(spec1) == spec2,
173-
lambda: spec1 == _open_top_level_list_if_single_element(spec2),
192+
lambda: _open_top_level_sequence_if_single_element(spec1) == spec2,
193+
lambda: spec1 == _open_top_level_sequence_if_single_element(spec2),
174194
]
175195

176196
if not any(check() for check in pass_if_any_checks):

0 commit comments

Comments
 (0)