Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch-frontend] add jit transforms #495

Merged
merged 4 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel
tools/compiler.py
tools/gen_extra_library.py
tools/extra_fn.mlir

utils/__init__.py
utils/jit_transforms.py
)

################################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
import sys

from torch_frontend.utils import replace_copy_fill_with_slice_scatter


def _test_helper(model_class, inputs):
# input = torch.randn(4, 4, 4)
model = model_class()
golden = model(*inputs)

ts_model = torch.jit.trace(model, inputs, check_trace=False)
replace_copy_fill_with_slice_scatter(ts_model.graph)
print(f"{model_class.__name__}: {ts_model.graph}")

# validate graph.
has_slice_scatter = False
for node in ts_model.graph.nodes():
assert node.kind() not in [
"aten::copy_",
"aten::fill_",
"aten::select",
], ts_model.graph

if node.kind() == "aten::slice":
uses = node.output().uses()
assert len(uses) == 1
user = uses[0].user
assert user.kind() == "aten::slice_scatter"

if node.kind() == "aten::slice_scatter":
has_slice_scatter = True
assert has_slice_scatter

out = ts_model(*inputs)
torch.testing.assert_close(golden, out)

######################################################################
# fill_ related

class NSliceSliceFill(torch.nn.Module):
def __init__(self):
super().__init__()

# (4, 4, 4)
def forward(self, x):
x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
x4 = torch.ops.aten.slice(x3, 0, 0, 2, 1)
_ = torch.ops.aten.fill_(x4, 2.0)
x = x + 1
return x

def test_nslice_slice_fill():
inputs = [torch.randn(4, 4, 4)]
_test_helper(NSliceSliceFill, inputs)


class NSliceSelectFill(torch.nn.Module):
def __init__(self):
super().__init__()

# (4, 4, 4)
def forward(self, x):
x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
x4 = torch.ops.aten.select(x3, 0, 2)
_ = torch.ops.aten.fill_(x4, 2.0)
x = x + 1
return x

def test_nslice_select_fill():
inputs = [torch.randn(4, 4, 4)]
_test_helper(NSliceSelectFill, inputs)


class SelectSliceFill(torch.nn.Module):
def __init__(self):
super().__init__()

# (4, 4, 4)
def forward(self, x):
x1 = torch.select(x, 0, 1)
x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize)
_ = torch.ops.aten.fill_(x2, -torch.inf)
x = x + 1
return x

def test_select_slice_fill():
inputs = [torch.randn(4, 4, 4)]
_test_helper(SelectSliceFill, inputs)


######################################################################
# copy_ related

class NSliceSliceCopy(torch.nn.Module):
def __init__(self):
super().__init__()

# (4, 4, 4)
def forward(self, x):
x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
x4 = torch.ops.aten.slice(x3, 0, 0, 2, 1)
zeros = torch.zeros((2, 4, 4))
_ = torch.ops.aten.copy_(x4, zeros)
x = x + 1
return x

def test_nslice_slice_copy():
inputs = [torch.randn(4, 4, 4)]
_test_helper(NSliceSliceCopy, inputs)


class NSliceSelectCopy(torch.nn.Module):
def __init__(self):
super().__init__()

# (4, 4, 4)
def forward(self, x):
x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
x4 = torch.ops.aten.select(x3, 0, 2)
zeros = torch.zeros((4, 4))
_ = torch.ops.aten.copy_(x4, zeros)
x = x + 1
return x

def test_nslice_select_copy():
inputs = [torch.randn(4, 4, 4)]
_test_helper(NSliceSelectCopy, inputs)

Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
from .fx_utils import list_decomposed_ops, preprocess_fx_graph, get_none_indices
from .flash_attn_op import replace_flash_attn
from .fx_rewrite import fx_replace_attn_pattern

from . import utils
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def compile_torchscript(args):
sample_inputs_placeholder.append(TensorPlaceholder(shape, dtype_str_to_torch_dtype(dtype)))

ts_model = torch.jit.load(args.model_path, map_location="cpu")
if args.enable_jit_rewrite:
torch_frontend.utils.replace_copy_fill_with_slice_scatter(ts_model.graph)

module = torch_frontend.compile(ts_model, sample_inputs_placeholder, args.output_type, verbose=args.verbose, debug=torch_frontend.DebugType(1))
if len(args.output_file_path) != 0:
with open(args.output_file_path, "w") as f:
Expand All @@ -102,6 +105,7 @@ def main():
parser.add_argument("--output_type", type=str, default="stablehlo", choices=["raw", "torch", "stablehlo"])
parser.add_argument("--elide", default=False, action="store_true")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--enable_jit_rewrite", default=False, action="store_true")
parser.add_argument(
"--input_name_and_shapes",
nargs="+",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .jit_transforms import replace_copy_fill_with_slice_scatter
Loading
Loading