@@ -29,16 +29,32 @@ def _module_lowering(
29
29
output_type ,
30
30
torch_mod ,
31
31
extra_library_file_name = None ,
32
+ backend_legal_ops = None ,
32
33
):
33
34
34
35
if output_type == OutputType .RAW :
35
36
if verbose :
36
37
print (torch_mod )
37
38
return torch_mod
38
39
# TODO: pass extra_library_file_name by caller
40
+
41
+ backend_legal_op_arg_str = ""
42
+ if backend_legal_ops is not None :
43
+ if not len (backend_legal_ops ) == 0 :
44
+ backend_legal_op_arg_str = "backend-legal-ops=" + "," .join (
45
+ backend_legal_ops
46
+ )
47
+
39
48
if extra_library_file_name is None :
40
49
extra_library_file_name = ""
41
- option_string = "{extra-library=" + extra_library_file_name + "}"
50
+ option_string = (
51
+ "{"
52
+ + backend_legal_op_arg_str
53
+ + " extra-library="
54
+ + extra_library_file_name
55
+ + "}"
56
+ )
57
+
42
58
run_pipeline_with_repro_report (
43
59
torch_mod ,
44
60
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{ option_string } )" ,
@@ -61,6 +77,7 @@ def export_and_import(
61
77
func_name : str = "main" ,
62
78
enable_graph_printing : bool = False ,
63
79
enable_ir_printing : bool = False ,
80
+ backend_legal_ops : Optional [list [str ]] = None ,
64
81
** kwargs ,
65
82
):
66
83
context = ir .Context ()
@@ -98,7 +115,10 @@ def export_and_import(
98
115
)
99
116
100
117
return _module_lowering (
101
- enable_ir_printing , OutputType .get (output_type ), fx_importer .module
118
+ enable_ir_printing ,
119
+ OutputType .get (output_type ),
120
+ fx_importer .module ,
121
+ backend_legal_ops = backend_legal_ops ,
102
122
)
103
123
104
124
@@ -110,6 +130,7 @@ def stateless_fx_import(
110
130
model_name : str = "main" ,
111
131
enable_graph_printing : bool = False ,
112
132
enable_ir_printing : bool = False ,
133
+ backend_legal_ops : Optional [list [str ]] = None ,
113
134
):
114
135
if enable_graph_printing :
115
136
gm .print_readable ()
@@ -119,5 +140,8 @@ def stateless_fx_import(
119
140
fx_importer = FxImporter (context = context , hooks = hooks )
120
141
fx_importer .import_stateless_graph (gm .graph , func_name = model_name )
121
142
return _module_lowering (
122
- enable_ir_printing , OutputType .get (output_type ), fx_importer .module
143
+ enable_ir_printing ,
144
+ OutputType .get (output_type ),
145
+ fx_importer .module ,
146
+ backend_legal_ops = backend_legal_ops ,
123
147
)
0 commit comments