Skip to content

Commit ed4df38

Browse files
[onnx] Add torch-mlir-import-onnx tool. (#2637)
Simple Python console script to import an ONNX protobuf to the torch dialect for additional processing. For installed wheels, this can be used with something like: ``` torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx ``` Or from a dev setup: ``` python -m torch_mlir.tools.import_onnx ... ```
1 parent 7cf52ae commit ed4df38

File tree

6 files changed

+109
-2
lines changed

6 files changed

+109
-2
lines changed

python/CMakeLists.txt

+7
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
3838
extras/onnx_importer.py
3939
)
4040

41+
declare_mlir_python_sources(TorchMLIRPythonSources.Tools
42+
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
43+
ADD_TO_PARENT TorchMLIRPythonSources
44+
SOURCES
45+
tools/import_onnx/__main__.py
46+
)
47+
4148
################################################################################
4249
# Extensions
4350
################################################################################
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
"""Console tool for converting an ONNX proto to torch IR.
7+
8+
Typically, when installed from a wheel, this can be invoked as:
9+
10+
torch-mlir-import-onnx some.pb
11+
12+
Or from Python:
13+
14+
python -m torch_mlir.tools.import_onnx ...
15+
"""
16+
import argparse
17+
from pathlib import Path
18+
import sys
19+
20+
import onnx
21+
22+
from ...extras import onnx_importer
23+
24+
from ...dialects import torch as torch_d
25+
from ...ir import (
26+
Context,
27+
)
28+
29+
30+
def main(args):
31+
model_proto = load_onnx_model(args.input_file)
32+
context = Context()
33+
torch_d.register_dialect(context)
34+
model_info = onnx_importer.ModelInfo(model_proto)
35+
m = model_info.create_module(context=context)
36+
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
37+
imp.import_all()
38+
if not args.no_verify:
39+
m.verify()
40+
41+
# TODO: This isn't very efficient output. If these files ever
42+
# get large, enable bytecode and direct binary emission to save
43+
# some copies.
44+
if args.output_file and args.output_file != "-":
45+
with open(args.output_file, "wt") as f:
46+
print(m.get_asm(assume_verified=not args.no_verify), file=f)
47+
else:
48+
print(m.get_asm(assume_verified=not args.no_verify))
49+
50+
51+
def load_onnx_model(file_path: Path) -> onnx.ModelProto:
52+
raw_model = onnx.load(file_path)
53+
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
54+
return inferred_model
55+
56+
57+
def parse_arguments(argv=None):
58+
parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool")
59+
parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
60+
parser.add_argument(
61+
"-o", dest="output_file", help="Output path (or '-' for stdout)"
62+
)
63+
parser.add_argument(
64+
"--no-verify",
65+
action="store_true",
66+
help="Disable verification prior to printing",
67+
)
68+
args = parser.parse_args(argv)
69+
return args
70+
71+
72+
def _cli_main():
73+
sys.exit(main(parse_arguments()))
74+
75+
76+
if __name__ == "__main__":
77+
_cli_main()

setup.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,11 @@ def build_extension(self, ext):
186186
"onnx": [
187187
"onnx>=1.15.0",
188188
],
189-
}
189+
},
190+
entry_points={
191+
"console_scripts": [
192+
"torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main",
193+
],
194+
},
190195
zip_safe=False,
191196
)

test/lit.cfg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
2525

2626
# suffixes: A list of file extensions to treat as test files.
27-
config.suffixes = ['.mlir', '.py']
27+
config.suffixes = ['.mlir', '.py', '.runlit']
2828

2929
# test_source_root: The root path where tests are located.
3030
config.test_source_root = os.path.dirname(__file__)
+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
pytorch0.3:h
2+
"
3+
01" LeakyRelu*
4+
alpha
5+
�#<�torch-jit-exportZ
6+
0
7+

8+

9+

10+
b
11+
1
12+

13+

14+

15+
B
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s
2+
3+
# CHECK: torch.operator "onnx.LeakyRelu"

0 commit comments

Comments
 (0)