@@ -41,6 +41,15 @@ def normalize_name(name: str) -> str:
     return name.replace(".", "_")
 
 
+# Given a node: torch._C.Node, map from node.kind() to a standard operator
+kind_to_standard_operators = {
+    "prim::TupleIndex": operator.getitem,
+    "aten::__is__": operator.is_,
+    "aten::__isnot__": operator.is_not,
+    "aten::__not__": operator.not_,
+}
+
+
 def get_op_overload(node: torch._C.Node):
     schema_str = node.schema()
     schema = FunctionSchema.parse(schema_str)
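For reference, a minimal sketch (illustrative only, not part of the patch) of what the standard-library `operator` functions in this mapping compute on plain Python values:

```python
import operator

t = ("a", "b", "c")
assert operator.getitem(t, 1) == "b"   # prim::TupleIndex -> t[1]
assert operator.is_(None, None)        # aten::__is__     -> x is y
assert operator.is_not(1, None)        # aten::__isnot__  -> x is not y
assert operator.not_(False)            # aten::__not__    -> not x
```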
@@ -285,13 +294,6 @@ def convert_prim_DictConstruct(self, node: torch._C.Node):
         output_name = node.output().debugName()
         self.name_to_node[output_name] = output_dict
 
-    def convert_prim_TupleIndex(self, node: torch._C.Node):
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
-        getitem_node = self.fx_graph.call_function(operator.getitem, args)
-
-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = getitem_node
-
     def convert_aten_Int(self, node: torch._C.Node):
         # converts aten::Int as aten._to_copy + aten::_local_scalar_dense
         target = torch.ops.aten._to_copy.default
@@ -438,6 +440,28 @@ def convert_as_noop(self, node: torch._C.Node):
         output_name = node.output().debugName()
         self.name_to_node[output_name] = args[0]
 
+    def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
+        target = torch.ops.profiler._record_function_enter_new
+        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        fx_node = self.fx_graph.call_function(target, args)
+        output_name = node.output().debugName()
+        self.name_to_node[output_name] = fx_node
+
+    def convert_profiler__record_function_exit(self, node: torch._C.Node):
+        # _record_function_exit has side effect so we keep it in fx.graph
+        # currently, _record_function_enter_new and _record_function_exit are
+        # discarded during `retrace_as_exported_program`.
+        target = torch.ops.profiler._record_function_exit
+        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        self.fx_graph.call_function(target, args)
+
+    def convert_standard_operators(self, node: torch._C.Node):
+        target = kind_to_standard_operators[node.kind()]
+        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        fx_node = self.fx_graph.call_function(target, args)
+        output_name = node.output().debugName()
+        self.name_to_node[output_name] = fx_node
+
     def convert_node(self, node: torch._C.Node):
         node_kind = node.kind()
         if node_kind == "prim::CreateObject":
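To make the conversion concrete, here is a minimal, self-contained sketch (illustrative only; `tup` and `idx` are stand-in names, not the converter's real variables) of how a `prim::TupleIndex` node ends up as an `operator.getitem` call in a `torch.fx` graph, mirroring what `convert_standard_operators` does via `self.fx_graph.call_function`:

```python
import operator

import torch
import torch.fx

graph = torch.fx.Graph()
tup = graph.placeholder("tup")   # stand-in for the converted tuple input
idx = graph.placeholder("idx")   # stand-in for the converted index input
getitem_node = graph.call_function(operator.getitem, (tup, idx))
graph.output(getitem_node)

gm = torch.fx.GraphModule(torch.nn.Module(), graph)
assert gm(("a", "b", "c"), 1) == "b"   # operator.getitem(("a", "b", "c"), 1)
```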
@@ -457,8 +481,6 @@ def convert_node(self, node: torch._C.Node):
             self.convert_prim_dtype(node)
         elif node_kind == "prim::DictConstruct":
             self.convert_prim_DictConstruct(node)
-        elif node_kind == "prim::TupleIndex":
-            self.convert_prim_TupleIndex(node)
         # elif node_kind == "aten::Int":
         #     convert_aten_Int(node)
         elif node_kind == "aten::_convolution":
@@ -471,7 +493,14 @@ def convert_node(self, node: torch._C.Node):
             self.convert_prim_if(node)
         elif node_kind == "aten::Bool":
             self.convert_as_noop(node)
+        elif node_kind == "profiler::_record_function_enter_new":
+            self.convert_profiler__record_function_enter_new(node)
+        elif node_kind == "profiler::_record_function_exit":
+            self.convert_profiler__record_function_exit(node)
+        elif node_kind in kind_to_standard_operators:
+            self.convert_standard_operators(node)
         elif node_kind.startswith("aten::"):
+            # order matters! this should be handled after kind_to_standard_operators
             self.convert_aten_op(node)
         else:
             raise ValueError(f"Unsupported node kind: {node_kind}")
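As the inline "order matters" comment indicates, the `kind_to_standard_operators` check has to come before the generic `aten::` branch: `aten::__is__`, `aten::__isnot__`, and `aten::__not__` also match `node_kind.startswith("aten::")` and would otherwise be routed to `convert_aten_op` instead of being lowered to the corresponding Python operators.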