Skip to content

Commit 2ad0e41

Browse files
BoyuanFeng authored and pytorchmergebot committed on Jun 4, 2024
[ts-migration] support aten::__is__, aten::__isnot__, aten::__not__, profiler::_record_function_enter_new, profiler::_record_function_exit (pytorch#127656)
Support more ops in ts converter and add unit tests. Pull Request resolved: pytorch#127656 Approved by: https://github.com/SherlockNoMad
1 parent 8d153e0 commit 2ad0e41

File tree

2 files changed

+96
-12
lines changed

2 files changed

+96
-12
lines changed
 

‎test/export/test_converter.py

+58-3
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,7 @@
11
# Owner(s): ["oncall: export"]
22

33
import unittest
4+
from typing import Tuple
45

56
import torch
67

@@ -9,7 +10,6 @@
910
from torch._dynamo.test_case import TestCase
1011
from torch._export.converter import TS2EPConverter
1112
from torch.export import ExportedProgram
12-
1313
from torch.testing._internal.common_utils import run_tests
1414

1515
requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
@@ -23,8 +23,11 @@ def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram:
2323
orig_out, _ = pytree.tree_flatten(mod(*inp))
2424
self.assertEqual(len(ep_out), len(orig_out))
2525
for ep_t, orig_t in zip(ep_out, orig_out):
26-
self.assertEqual(ep_t.shape, orig_t.shape)
27-
self.assertTrue(torch.allclose(ep_t, orig_t))
26+
if isinstance(ep_t, torch.Tensor):
27+
self.assertEqual(ep_t.shape, orig_t.shape)
28+
self.assertTrue(torch.allclose(ep_t, orig_t))
29+
else:
30+
self.assertEqual(ep_t, orig_t)
2831
return ep
2932

3033
def test_ts2ep_converter_basic(self):
@@ -192,6 +195,58 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
192195
M()(torch.tensor(False), torch.tensor(4)),
193196
)
194197

198+
def test_profiler__record_function(self):
199+
class Module(torch.nn.Module):
200+
def forward(self, x: torch.Tensor) -> torch.Tensor:
201+
handle = torch.ops.profiler._record_function_enter_new("foo", None)
202+
y = x * 2 + 4
203+
torch.ops.profiler._record_function_exit(handle)
204+
return y
205+
206+
x = torch.randn(10, 10)
207+
self._check_equal_ts_ep_converter(Module(), (x,))
208+
209+
def test_aten_floordiv(self):
210+
class Module(torch.nn.Module):
211+
def forward(self, x: torch.Tensor) -> torch.Tensor:
212+
return x // 2
213+
214+
x = torch.randn(10, 10)
215+
self._check_equal_ts_ep_converter(Module(), (x,))
216+
217+
def test_aten___is__(self):
218+
class Module(torch.nn.Module):
219+
def forward(
220+
self, x: torch.Tensor, y: torch.Tensor
221+
) -> Tuple[bool, torch.Tensor]:
222+
z = x + 1
223+
return x is y, z
224+
225+
inp = (torch.randn(10, 10), torch.rand(10, 10))
226+
self._check_equal_ts_ep_converter(Module(), inp)
227+
228+
def test_aten___isnot__(self):
229+
class Module(torch.nn.Module):
230+
def forward(
231+
self, x: torch.Tensor, y: torch.Tensor
232+
) -> Tuple[bool, torch.Tensor]:
233+
z = x + 1
234+
return x is not y, z
235+
236+
inp = (torch.randn(10, 10), torch.rand(10, 10))
237+
self._check_equal_ts_ep_converter(Module(), inp)
238+
239+
def test_aten___not__(self):
240+
class Module(torch.nn.Module):
241+
def forward(
242+
self, x: torch.Tensor, y: torch.Tensor
243+
) -> Tuple[bool, torch.Tensor]:
244+
z = x + 1
245+
return not (x is not y), z
246+
247+
inp = (torch.randn(10, 10), torch.rand(10, 10))
248+
self._check_equal_ts_ep_converter(Module(), inp)
249+
195250

196251
if __name__ == "__main__":
197252
run_tests()

‎torch/_export/converter.py

+38-9
Original file line number · Diff line number · Diff line change
@@ -41,6 +41,15 @@ def normalize_name(name: str) -> str:
4141
return name.replace(".", "_")
4242

4343

44+
# Given a node: torch._C.Node, map from node.kind() to a standard operator
45+
kind_to_standard_operators = {
46+
"prim::TupleIndex": operator.getitem,
47+
"aten::__is__": operator.is_,
48+
"aten::__isnot__": operator.is_not,
49+
"aten::__not__": operator.not_,
50+
}
51+
52+
4453
def get_op_overload(node: torch._C.Node):
4554
schema_str = node.schema()
4655
schema = FunctionSchema.parse(schema_str)
@@ -285,13 +294,6 @@ def convert_prim_DictConstruct(self, node: torch._C.Node):
285294
output_name = node.output().debugName()
286295
self.name_to_node[output_name] = output_dict
287296

288-
def convert_prim_TupleIndex(self, node: torch._C.Node):
289-
args = tuple(self.get_fx_value(input) for input in node.inputs())
290-
getitem_node = self.fx_graph.call_function(operator.getitem, args)
291-
292-
output_name = node.output().debugName()
293-
self.name_to_node[output_name] = getitem_node
294-
295297
def convert_aten_Int(self, node: torch._C.Node):
296298
# converts aten::Int as aten._to_copy + aten::_local_scalar_dense
297299
target = torch.ops.aten._to_copy.default
@@ -438,6 +440,28 @@ def convert_as_noop(self, node: torch._C.Node):
438440
output_name = node.output().debugName()
439441
self.name_to_node[output_name] = args[0]
440442

443+
def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
444+
target = torch.ops.profiler._record_function_enter_new
445+
args = tuple(self.get_fx_value(input) for input in node.inputs())
446+
fx_node = self.fx_graph.call_function(target, args)
447+
output_name = node.output().debugName()
448+
self.name_to_node[output_name] = fx_node
449+
450+
def convert_profiler__record_function_exit(self, node: torch._C.Node):
451+
# _record_function_exit has side effect so we keep it in fx.graph
452+
# currently, _record_function_enter_new and _record_function_exit are
453+
# discarded during `retrace_as_exported_program`.
454+
target = torch.ops.profiler._record_function_exit
455+
args = tuple(self.get_fx_value(input) for input in node.inputs())
456+
self.fx_graph.call_function(target, args)
457+
458+
def convert_standard_operators(self, node: torch._C.Node):
459+
target = kind_to_standard_operators[node.kind()]
460+
args = tuple(self.get_fx_value(input) for input in node.inputs())
461+
fx_node = self.fx_graph.call_function(target, args)
462+
output_name = node.output().debugName()
463+
self.name_to_node[output_name] = fx_node
464+
441465
def convert_node(self, node: torch._C.Node):
442466
node_kind = node.kind()
443467
if node_kind == "prim::CreateObject":
@@ -457,8 +481,6 @@ def convert_node(self, node: torch._C.Node):
457481
self.convert_prim_dtype(node)
458482
elif node_kind == "prim::DictConstruct":
459483
self.convert_prim_DictConstruct(node)
460-
elif node_kind == "prim::TupleIndex":
461-
self.convert_prim_TupleIndex(node)
462484
# elif node_kind == "aten::Int":
463485
# convert_aten_Int(node)
464486
elif node_kind == "aten::_convolution":
@@ -471,7 +493,14 @@ def convert_node(self, node: torch._C.Node):
471493
self.convert_prim_if(node)
472494
elif node_kind == "aten::Bool":
473495
self.convert_as_noop(node)
496+
elif node_kind == "profiler::_record_function_enter_new":
497+
self.convert_profiler__record_function_enter_new(node)
498+
elif node_kind == "profiler::_record_function_exit":
499+
self.convert_profiler__record_function_exit(node)
500+
elif node_kind in kind_to_standard_operators:
501+
self.convert_standard_operators(node)
474502
elif node_kind.startswith("aten::"):
503+
# order matters! this should be handled after kind_to_standard_operators
475504
self.convert_aten_op(node)
476505
else:
477506
raise ValueError(f"Unsupported node kind: {node_kind}")

0 commit comments

Comments (0)
Please sign in to comment.