Skip to content

Commit 2cc31d6

Browse files
Backend-legal-ops argument for fx lowering (#3956)
Added `backend-legal-ops` argument in `fx.import_and_export` to stop decomposition of certain torch ops. This PR is based on this [issue](#3953)
1 parent f42c7e4 commit 2cc31d6

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

python/torch_mlir/fx.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,32 @@ def _module_lowering(
2929
output_type,
3030
torch_mod,
3131
extra_library_file_name=None,
32+
backend_legal_ops=None,
3233
):
3334

3435
if output_type == OutputType.RAW:
3536
if verbose:
3637
print(torch_mod)
3738
return torch_mod
3839
# TODO: pass extra_library_file_name by caller
40+
41+
backend_legal_op_arg_str = ""
42+
if backend_legal_ops is not None:
43+
if not len(backend_legal_ops) == 0:
44+
backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
45+
backend_legal_ops
46+
)
47+
3948
if extra_library_file_name is None:
4049
extra_library_file_name = ""
41-
option_string = "{extra-library=" + extra_library_file_name + "}"
50+
option_string = (
51+
"{"
52+
+ backend_legal_op_arg_str
53+
+ " extra-library="
54+
+ extra_library_file_name
55+
+ "}"
56+
)
57+
4258
run_pipeline_with_repro_report(
4359
torch_mod,
4460
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
@@ -61,6 +77,7 @@ def export_and_import(
6177
func_name: str = "main",
6278
enable_graph_printing: bool = False,
6379
enable_ir_printing: bool = False,
80+
backend_legal_ops: Optional[list[str]] = None,
6481
**kwargs,
6582
):
6683
context = ir.Context()
@@ -98,7 +115,10 @@ def export_and_import(
98115
)
99116

100117
return _module_lowering(
101-
enable_ir_printing, OutputType.get(output_type), fx_importer.module
118+
enable_ir_printing,
119+
OutputType.get(output_type),
120+
fx_importer.module,
121+
backend_legal_ops=backend_legal_ops,
102122
)
103123

104124

@@ -110,6 +130,7 @@ def stateless_fx_import(
110130
model_name: str = "main",
111131
enable_graph_printing: bool = False,
112132
enable_ir_printing: bool = False,
133+
backend_legal_ops: Optional[list[str]] = None,
113134
):
114135
if enable_graph_printing:
115136
gm.print_readable()
@@ -119,5 +140,8 @@ def stateless_fx_import(
119140
fx_importer = FxImporter(context=context, hooks=hooks)
120141
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
121142
return _module_lowering(
122-
enable_ir_printing, OutputType.get(output_type), fx_importer.module
143+
enable_ir_printing,
144+
OutputType.get(output_type),
145+
fx_importer.module,
146+
backend_legal_ops=backend_legal_ops,
123147
)

0 commit comments

Comments
 (0)