diff --git a/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt b/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
index 8f4afc084..1c581d93e 100644
--- a/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
+++ b/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
@@ -37,6 +37,9 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel
     tools/compiler.py
     tools/gen_extra_library.py
     tools/extra_fn.mlir
+
+    utils/__init__.py
+    utils/jit_transforms.py
 )
 
 ################################################################################
diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_utils/test_jit_transforms.py b/frontends/torch-frontend/torch-frontend/python/test/test_utils/test_jit_transforms.py
new file mode 100755
index 000000000..3ffd561db
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/test/test_utils/test_jit_transforms.py
@@ -0,0 +1,137 @@
+import torch
+import sys
+
+from torch_frontend.utils import replace_copy_fill_with_slice_scatter
+
+
+def _test_helper(model_class, inputs):
+    model = model_class()
+    golden = model(*inputs)
+
+    ts_model = torch.jit.trace(model, inputs, check_trace=False)
+    replace_copy_fill_with_slice_scatter(ts_model.graph)
+    print(f"{model_class.__name__}: {ts_model.graph}")
+
+    # Validate the rewritten graph: copy_/fill_/select must all be gone,
+    # every remaining slice must feed a slice_scatter, and at least one
+    # slice_scatter must have been created.
+    has_slice_scatter = False
+    for node in ts_model.graph.nodes():
+        assert node.kind() not in [
+            "aten::copy_",
+            "aten::fill_",
+            "aten::select",
+        ], ts_model.graph
+
+        if node.kind() == "aten::slice":
+            uses = node.output().uses()
+            assert len(uses) == 1
+            user = uses[0].user
+            assert user.kind() == "aten::slice_scatter"
+
+        if node.kind() == "aten::slice_scatter":
+            has_slice_scatter = True
+    assert has_slice_scatter
+
+    out = ts_model(*inputs)
+    torch.testing.assert_close(golden, out)
+
+######################################################################
+# fill_ related
+
+class NSliceSliceFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.slice(x3, 0, 0, 2, 1)
+        _ = torch.ops.aten.fill_(x4, 2.0)
+        x = x + 1
+        return x
+
+def test_nslice_slice_fill():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSliceFill, inputs)
+
+
+class NSliceSelectFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.select(x3, 0, 2)
+        _ = torch.ops.aten.fill_(x4, 2.0)
+        x = x + 1
+        return x
+
+def test_nslice_select_fill():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSelectFill, inputs)
+
+
+class SelectSliceFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.select(x, 0, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize)
+        _ = torch.ops.aten.fill_(x2, -torch.inf)
+        x = x + 1
+        return x
+
+def test_select_slice_fill():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(SelectSliceFill, inputs)
+
+
+######################################################################
+# copy_ related
+
+class NSliceSliceCopy(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.slice(x3, 0, 0, 2, 1)
+        zeros = torch.zeros((2, 4, 4))
+        _ = torch.ops.aten.copy_(x4, zeros)
+        x = x + 1
+        return x
+
+def test_nslice_slice_copy():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSliceCopy, inputs)
+
+
+class NSliceSelectCopy(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.select(x3, 0, 2)
+        zeros = torch.zeros((4, 4))
+        _ = torch.ops.aten.copy_(x4, zeros)
+        x = x + 1
+        return x
+
+def test_nslice_select_copy():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSelectCopy, inputs)
+
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py
index ed4122aa8..387637f6c 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py
@@ -6,3 +6,5 @@
 from .fx_utils import list_decomposed_ops, preprocess_fx_graph, get_none_indices
 from .flash_attn_op import replace_flash_attn
 from .fx_rewrite import fx_replace_attn_pattern
+
+from . import utils
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py
index dd4f383fd..0a68f4f1c 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py
@@ -80,6 +80,9 @@ def compile_torchscript(args):
         sample_inputs_placeholder.append(TensorPlaceholder(shape, dtype_str_to_torch_dtype(dtype)))
 
     ts_model = torch.jit.load(args.model_path, map_location="cpu")
+    if args.enable_jit_rewrite:
+        torch_frontend.utils.replace_copy_fill_with_slice_scatter(ts_model.graph)
+
     module = torch_frontend.compile(ts_model, sample_inputs_placeholder, args.output_type, verbose=args.verbose, debug=torch_frontend.DebugType(1))
     if len(args.output_file_path) != 0:
         with open(args.output_file_path, "w") as f:
@@ -102,6 +105,7 @@ def main():
     parser.add_argument("--output_type", type=str, default="stablehlo", choices=["raw", "torch", "stablehlo"])
     parser.add_argument("--elide", default=False, action="store_true")
     parser.add_argument("--verbose", default=False, action="store_true")
+    parser.add_argument("--enable_jit_rewrite", default=False, action="store_true")
     parser.add_argument(
         "--input_name_and_shapes",
         nargs="+",
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/__init__.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/__init__.py
new file mode 100644
index 000000000..0301fe40e
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/__init__.py
@@ -0,0 +1 @@
+from .jit_transforms import replace_copy_fill_with_slice_scatter
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/jit_transforms.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/jit_transforms.py
new file mode 100755
index 000000000..88a6f8827
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/jit_transforms.py
@@ -0,0 +1,294 @@
+import torch
+import sys
+import functools
+
+
+def _nslice_slice_or_select_copy(graph, pattern):
+    """
+    slice
+    ...
+    slice/select
+    copy_
+    """
+
+    def is_valid():
+        copy_node = pattern[-1]
+        if copy_node.kind() != "aten::copy_":
+            return False
+        if pattern[-2].kind() not in ["aten::slice", "aten::select"]:
+            return False
+        # every leading slice must cover the full range (start 0, end INT64_MAX).
+        for node in pattern[:-2]:
+            if node.kind() != "aten::slice":
+                return False
+            if (
+                node.inputsAt(2).toIValue() != 0
+                or node.inputsAt(3).toIValue() != sys.maxsize
+            ):
+                return False
+        return True
+
+    if not is_valid():
+        return None, graph
+
+    # available args.
+    sc_input1 = pattern[0].inputsAt(0)  # self
+    sc_input2 = pattern[-1].inputsAt(1)  # src
+    dim = pattern[-2].inputsAt(1)
+    start = pattern[-2].inputsAt(2)
+
+    graph.setInsertPoint(pattern[-1])
+    if pattern[-2].kind() == "aten::slice":
+        # for slice + copy_
+        end = pattern[-2].inputsAt(3)
+        step = pattern[-2].inputsAt(4)
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, sc_input2, dim, start, end, step],
+            1,
+        )
+    else:
+        # for select + copy_: unsqueeze src, then scatter a width-1 slice.
+        one = graph.insertConstant(1)
+        end = graph.create("aten::add", [start, one], 1)
+        end.output().setType(torch._C.IntType.get())
+        unsqueeze = graph.create("aten::unsqueeze", [sc_input2, dim], 1)
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, unsqueeze.output(), dim, start, end.output(), one],
+            1,
+        )
+        graph.insertNode(end)
+        graph.insertNode(unsqueeze)
+
+    graph.insertNode(slice_scatter)
+
+    graph.lint()
+
+    return slice_scatter, graph
+
+
+def _nslice_slice_or_select_fill(graph, pattern):
+    """
+    slice
+    ...
+    select/slice
+    fill_
+    """
+
+    def is_valid():
+        fill = pattern[-1]
+        if fill.kind() != "aten::fill_":
+            return False
+        if pattern[-2].kind() not in ["aten::slice", "aten::select"]:
+            return False
+        # every leading slice must cover the full range (start 0, end INT64_MAX).
+        for node in pattern[:-2]:
+            if node.kind() != "aten::slice":
+                return False
+            if (
+                node.inputsAt(2).toIValue() != 0
+                or node.inputsAt(3).toIValue() != sys.maxsize
+            ):
+                return False
+        return True
+
+    if not is_valid():
+        return None, graph
+
+    graph.setInsertPoint(pattern[-1])
+    sc_input1 = pattern[0].inputsAt(0)
+
+    # build src: zeros_like(self) + fill value, then slice out the written region.
+    one = graph.insertConstant(1)
+    none = graph.insertConstant(None)
+    zeros = graph.create(
+        "aten::zeros_like", [sc_input1, none, none, none, none, none], 1
+    )
+
+    fill = pattern[-1]
+    value = fill.inputsAt(1)
+    fill_value = graph.create("aten::add", [zeros.output(), value, one], 1)
+
+    target_dim = pattern[-2].inputsAt(1)
+    start = pattern[-2].inputsAt(2)
+    if pattern[-2].kind() == "aten::slice":
+        # for slice + fill_
+        end = pattern[-2].inputsAt(3)
+        step = pattern[-2].inputsAt(4)
+        sc_input2 = graph.create(
+            "aten::slice", [fill_value.output(), target_dim, start, end, step], 1
+        )
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, sc_input2.output(), target_dim, start, end, step],
+            1,
+        )
+    else:
+        # for select + fill_
+        end = graph.create("aten::add", [start, one], 1)
+        end.output().setType(torch._C.IntType.get())
+        sc_input2 = graph.create(
+            "aten::slice",
+            [fill_value.output(), target_dim, start, end.output(), one],
+            1,
+        )
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, sc_input2.output(), target_dim, start, end.output(), one],
+            1,
+        )
+        graph.insertNode(end)
+
+    graph.insertNode(zeros)
+    graph.insertNode(fill_value)
+    graph.insertNode(sc_input2)
+    graph.insertNode(slice_scatter)
+
+    graph.lint()
+
+    return slice_scatter, graph
+
+
+def _select_slice_fill_(graph, pattern):
+    """
+    select
+    slice
+    fill_
+    """
+
+    def is_valid():
+        if len(pattern) != 3:
+            return False
+
+        select_node = pattern[0]
+        slice_node = pattern[1]
+        fill_node = pattern[2]
+        if (
+            select_node.kind() != "aten::select"
+            or slice_node.kind() != "aten::slice"
+            or fill_node.kind() != "aten::fill_"
+        ):
+            return False
+        start = slice_node.inputsAt(2).toIValue()
+        end = slice_node.inputsAt(3).toIValue()
+        if start != 0 or end != sys.maxsize:
+            return False
+
+        return True
+
+    if not is_valid():
+        return None, graph
+
+    graph.setInsertPoint(pattern[-1])
+
+    # rewrite select + slice + fill_ to slice_scatter.
+    select_node = pattern[0]
+    target_dim = select_node.inputsAt(1)
+    start = select_node.inputsAt(2)
+    one = graph.insertConstant(1)
+    end = graph.create("aten::add", [start, one], 1)
+    end.output().setType(torch._C.IntType.get())
+
+    fill = pattern[2]
+    sc_input1 = select_node.inputsAt(0)  # self
+    none = graph.insertConstant(None)
+    value = fill.inputsAt(1)
+    zeros = graph.create(
+        "aten::zeros_like", [sc_input1, none, none, none, none, none], 1
+    )
+    fill_value = graph.create("aten::add", [zeros.output(), value, one], 1)
+    sc_input2 = graph.create(
+        "aten::slice", [fill_value.output(), target_dim, start, end.output(), one], 1
+    )
+
+    slice_scatter = graph.create(
+        "aten::slice_scatter",
+        [sc_input1, sc_input2.output(), target_dim, start, end.output(), one],
+        1,
+    )
+
+    graph.insertNode(end)
+    graph.insertNode(zeros)
+    graph.insertNode(fill_value)
+    graph.insertNode(sc_input2)
+    graph.insertNode(slice_scatter)
+
+    graph.lint()
+
+    return slice_scatter, graph
+
+
+copy_fill_rewrites = [
+    _nslice_slice_or_select_copy,
+    _nslice_slice_or_select_fill,
+    _select_slice_fill_,
+]
+
+
+def replace_copy_fill_with_slice_scatter(graph):
+    """
+    Rewrite chains of the form
+    - slice[(2, 0), (3, 9223372036854775807)]
+    - any select/slice
+    - copy_/fill_
+    into an out-of-place aten::slice_scatter.
+    """
+
+    def is_valid_node(node, level):
+        # level 0 is the copy_/fill_ root; every node above it must be a view op.
+        if level == 0 or node.kind() in ["aten::slice", "aten::select"]:
+            return True
+
+        return False
+
+    def dfs(node, level):
+        pattern = []
+        if is_valid_node(node, level):
+            pattern.append(node)
+        else:
+            return pattern
+        in_tensor = node.inputsAt(0)
+        prev = in_tensor.node()
+        pattern = dfs(prev, level + 1) + pattern
+        return pattern
+
+    def node_compare(lhs, rhs):
+        if lhs[-1].isAfter(rhs[-1]):
+            return 1
+        elif lhs[-1].isBefore(rhs[-1]):
+            return -1
+        else:
+            return 0
+
+    # 1. find pattern roots: walk up from every copy_/fill_ through its view chain.
+    valid_patterns = []
+    for node in graph.nodes():
+        if node.kind() not in ["aten::copy_", "aten::fill_"]:
+            continue
+        pattern = dfs(node, 0)
+        if len(pattern) >= 2 and pattern[0].inputsAt(0).node().kind() not in [
+            "aten::slice",
+            "aten::select",
+        ]:
+            valid_patterns.append(pattern)
+
+    def post_process(slice_scatter, graph):
+        # NOTE: `pattern` is the loop variable of the rewrite loop below,
+        # captured by closure at call time. Redirect every later use of the
+        # original tensor to the slice_scatter result, then erase the
+        # now-dead view chain.
+        slice_scatter_use = None
+        sc_input1 = slice_scatter.inputsAt(0)
+        for use in sc_input1.uses():
+            if use.user == slice_scatter:
+                slice_scatter_use = use
+
+        if slice_scatter_use:
+            for use in sc_input1.uses():
+                user = use.user
+                if user == pattern[0] or not use.isAfter(slice_scatter_use):
+                    continue
+
+                for idx, value in enumerate(user.inputs()):
+                    if value == sc_input1:
+                        user.replaceInput(idx, slice_scatter.output())
+
+        for old_node in reversed(pattern):
+            if len(old_node.output().uses()) == 0:
+                old_node.destroy()
+        graph.lint()
+
+    # 2. do rewrite, in graph order.
+    sorted_patterns = sorted(valid_patterns, key=functools.cmp_to_key(node_compare))
+    for pattern in sorted_patterns:
+        for rewrite_func in copy_fill_rewrites:
+            slice_scatter, graph = rewrite_func(graph, pattern)
+            if slice_scatter is None:
+                continue
+            post_process(slice_scatter, graph)
+            break
+
+    return graph
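
As a usage sketch, the new pass can be driven the same way the added tests drive it: trace a module, run the rewrite over the TorchScript graph, and check numerics. This is a minimal sketch assuming the trace-based workflow from test_jit_transforms.py; the module name TwoSliceFill is illustrative only, not part of the change.

import sys
import torch
from torch_frontend.utils import replace_copy_fill_with_slice_scatter


class TwoSliceFill(torch.nn.Module):
    def forward(self, x):
        # A full-range slice, a sub-slice, then an in-place fill_:
        # exactly the slice + ... + slice + fill_ chain the pass targets.
        x1 = torch.ops.aten.slice(x, 0, 0, sys.maxsize, 1)
        x2 = torch.ops.aten.slice(x1, 0, 0, 2, 1)
        _ = torch.ops.aten.fill_(x2, 2.0)
        return x + 1


inputs = [torch.randn(4, 4, 4)]
model = TwoSliceFill()
golden = model(*inputs)

ts_model = torch.jit.trace(model, inputs, check_trace=False)
replace_copy_fill_with_slice_scatter(ts_model.graph)

# The rewritten graph contains no copy_/fill_ but computes the same result.
torch.testing.assert_close(golden, ts_model(*inputs))
print(ts_model.graph)

The same rewrite is exposed on the command line through the new --enable_jit_rewrite flag added to tools/compiler.py above.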