diff --git a/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt b/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
index 8f4afc084..1c581d93e 100644
--- a/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
+++ b/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
@@ -37,6 +37,9 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel
     tools/compiler.py
     tools/gen_extra_library.py
     tools/extra_fn.mlir
+
+    utils/__init__.py
+    utils/jit_transforms.py
 )
 
 ################################################################################
diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_utils/test_jit_transforms.py b/frontends/torch-frontend/torch-frontend/python/test/test_utils/test_jit_transforms.py
new file mode 100755
index 000000000..3ffd561db
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/test/test_utils/test_jit_transforms.py
@@ -0,0 +1,137 @@
+import torch
+import sys
+
+from torch_frontend.utils import replace_copy_fill_with_slice_scatter
+
+
+def _test_helper(model_class, inputs):
+    model = model_class()
+    golden = model(*inputs)
+
+    ts_model = torch.jit.trace(model, inputs, check_trace=False)
+    replace_copy_fill_with_slice_scatter(ts_model.graph)
+    print(f"{model_class.__name__}: {ts_model.graph}")
+
+    # Validate the rewritten graph: copy_/fill_/select must all be gone,
+    # every remaining slice must feed a slice_scatter, and at least one
+    # slice_scatter must have been created.
+    has_slice_scatter = False
+    for node in ts_model.graph.nodes():
+        assert node.kind() not in [
+            "aten::copy_",
+            "aten::fill_",
+            "aten::select",
+        ], ts_model.graph
+
+        if node.kind() == "aten::slice":
+            uses = node.output().uses()
+            assert len(uses) == 1
+            user = uses[0].user
+            assert user.kind() == "aten::slice_scatter"
+
+        if node.kind() == "aten::slice_scatter":
+            has_slice_scatter = True
+    assert has_slice_scatter
+
+    out = ts_model(*inputs)
+    torch.testing.assert_close(golden, out)
+
+######################################################################
+# fill_ related
+
+class NSliceSliceFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.slice(x3, 0, 0, 2, 1)
+        _ = torch.ops.aten.fill_(x4, 2.0)
+        x = x + 1
+        return x
+
+def test_nslice_slice_fill():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSliceFill, inputs)
+
+
+class NSliceSelectFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.select(x3, 0, 2)
+        _ = torch.ops.aten.fill_(x4, 2.0)
+        x = x + 1
+        return x
+
+def test_nslice_select_fill():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSelectFill, inputs)
+
+
+class SelectSliceFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.select(x, 0, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize)
+        _ = torch.ops.aten.fill_(x2, -torch.inf)
+        x = x + 1
+        return x
+
+def test_select_slice_fill():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(SelectSliceFill, inputs)
+
+
+######################################################################
+# copy_ related
+
+class NSliceSliceCopy(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.slice(x3, 0, 0, 2, 1)
+        zeros = torch.zeros((2, 4, 4))
+        _ = torch.ops.aten.copy_(x4, zeros)
+        x = x + 1
+        return x
+
+def test_nslice_slice_copy():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSliceCopy, inputs)
+
+
+class NSliceSelectCopy(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # input shape: (4, 4, 4)
+    def forward(self, x):
+        x1 = torch.ops.aten.slice(x, 1, 0, sys.maxsize, 1)
+        x2 = torch.ops.aten.slice(x1, 0, 0, sys.maxsize, 1)
+        x3 = torch.ops.aten.slice(x2, 2, 0, sys.maxsize, 1)
+        x4 = torch.ops.aten.select(x3, 0, 2)
+        zeros = torch.zeros((4, 4))
+        _ = torch.ops.aten.copy_(x4, zeros)
+        x = x + 1
+        return x
+
+def test_nslice_select_copy():
+    inputs = [torch.randn(4, 4, 4)]
+    _test_helper(NSliceSelectCopy, inputs)
+
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py
index ed4122aa8..387637f6c 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py
@@ -6,3 +6,5 @@
 from .fx_utils import list_decomposed_ops, preprocess_fx_graph, get_none_indices
 from .flash_attn_op import replace_flash_attn
 from .fx_rewrite import fx_replace_attn_pattern
+
+from . import utils
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py
index dd4f383fd..0a68f4f1c 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/tools/compiler.py
@@ -80,6 +80,9 @@ def compile_torchscript(args):
         sample_inputs_placeholder.append(TensorPlaceholder(shape, dtype_str_to_torch_dtype(dtype)))
 
     ts_model = torch.jit.load(args.model_path, map_location="cpu")
+    if args.enable_jit_rewrite:
+        torch_frontend.utils.replace_copy_fill_with_slice_scatter(ts_model.graph)
+
     module = torch_frontend.compile(ts_model, sample_inputs_placeholder, args.output_type, verbose=args.verbose, debug=torch_frontend.DebugType(1))
     if len(args.output_file_path) != 0:
         with open(args.output_file_path, "w") as f:
@@ -102,6 +105,7 @@ def main():
     parser.add_argument("--output_type", type=str, default="stablehlo", choices=["raw", "torch", "stablehlo"])
     parser.add_argument("--elide", default=False, action="store_true")
     parser.add_argument("--verbose", default=False, action="store_true")
+    parser.add_argument("--enable_jit_rewrite", default=False, action="store_true")
     parser.add_argument(
         "--input_name_and_shapes",
         nargs="+",
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/__init__.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/__init__.py
new file mode 100644
index 000000000..0301fe40e
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/__init__.py
@@ -0,0 +1 @@
+from .jit_transforms import replace_copy_fill_with_slice_scatter
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/jit_transforms.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/jit_transforms.py
new file mode 100755
index 000000000..88a6f8827
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/utils/jit_transforms.py
@@ -0,0 +1,294 @@
+import torch
+import sys
+import functools
+
+
+def _nslice_slice_or_select_copy(graph, pattern):
+    """
+    slice
+    ...
+    slice/select
+    copy_
+    """
+
+    def is_valid():
+        copy_node = pattern[-1]
+        if copy_node.kind() != "aten::copy_":
+            return False
+        if pattern[-2].kind() not in ["aten::slice", "aten::select"]:
+            return False
+        # every leading slice must cover the full range (start 0, end INT64_MAX).
+        for node in pattern[:-2]:
+            if node.kind() != "aten::slice":
+                return False
+            if (
+                node.inputsAt(2).toIValue() != 0
+                or node.inputsAt(3).toIValue() != sys.maxsize
+            ):
+                return False
+        return True
+
+    if not is_valid():
+        return None, graph
+
+    # available args.
+    sc_input1 = pattern[0].inputsAt(0)  # self
+    sc_input2 = pattern[-1].inputsAt(1)  # src
+    dim = pattern[-2].inputsAt(1)
+    start = pattern[-2].inputsAt(2)
+
+    graph.setInsertPoint(pattern[-1])
+    if pattern[-2].kind() == "aten::slice":
+        # for slice + copy_
+        end = pattern[-2].inputsAt(3)
+        step = pattern[-2].inputsAt(4)
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, sc_input2, dim, start, end, step],
+            1,
+        )
+    else:
+        # for select + copy_: unsqueeze src, then scatter a width-1 slice.
+        one = graph.insertConstant(1)
+        end = graph.create("aten::add", [start, one], 1)
+        end.output().setType(torch._C.IntType.get())
+        unsqueeze = graph.create("aten::unsqueeze", [sc_input2, dim], 1)
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, unsqueeze.output(), dim, start, end.output(), one],
+            1,
+        )
+        graph.insertNode(end)
+        graph.insertNode(unsqueeze)
+
+    graph.insertNode(slice_scatter)
+
+    graph.lint()
+
+    return slice_scatter, graph
+
+
+def _nslice_slice_or_select_fill(graph, pattern):
+    """
+    slice
+    ...
+    select/slice
+    fill_
+    """
+
+    def is_valid():
+        fill = pattern[-1]
+        if fill.kind() != "aten::fill_":
+            return False
+        if pattern[-2].kind() not in ["aten::slice", "aten::select"]:
+            return False
+        # every leading slice must cover the full range (start 0, end INT64_MAX).
+        for node in pattern[:-2]:
+            if node.kind() != "aten::slice":
+                return False
+            if (
+                node.inputsAt(2).toIValue() != 0
+                or node.inputsAt(3).toIValue() != sys.maxsize
+            ):
+                return False
+        return True
+
+    if not is_valid():
+        return None, graph
+
+    graph.setInsertPoint(pattern[-1])
+    sc_input1 = pattern[0].inputsAt(0)
+
+    # build src: zeros_like(self) + fill value, then slice out the written region.
+    one = graph.insertConstant(1)
+    none = graph.insertConstant(None)
+    zeros = graph.create(
+        "aten::zeros_like", [sc_input1, none, none, none, none, none], 1
+    )
+
+    fill = pattern[-1]
+    value = fill.inputsAt(1)
+    fill_value = graph.create("aten::add", [zeros.output(), value, one], 1)
+
+    target_dim = pattern[-2].inputsAt(1)
+    start = pattern[-2].inputsAt(2)
+    if pattern[-2].kind() == "aten::slice":
+        # for slice + fill_
+        end = pattern[-2].inputsAt(3)
+        step = pattern[-2].inputsAt(4)
+        sc_input2 = graph.create(
+            "aten::slice", [fill_value.output(), target_dim, start, end, step], 1
+        )
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, sc_input2.output(), target_dim, start, end, step],
+            1,
+        )
+    else:
+        # for select + fill_
+        end = graph.create("aten::add", [start, one], 1)
+        end.output().setType(torch._C.IntType.get())
+        sc_input2 = graph.create(
+            "aten::slice",
+            [fill_value.output(), target_dim, start, end.output(), one],
+            1,
+        )
+        slice_scatter = graph.create(
+            "aten::slice_scatter",
+            [sc_input1, sc_input2.output(), target_dim, start, end.output(), one],
+            1,
+        )
+        graph.insertNode(end)
+
+    graph.insertNode(zeros)
+    graph.insertNode(fill_value)
+    graph.insertNode(sc_input2)
+    graph.insertNode(slice_scatter)
+
+    graph.lint()
+
+    return slice_scatter, graph
+
+
+def _select_slice_fill_(graph, pattern):
+    """
+    select
+    slice
+    fill_
+    """
+
+    def is_valid():
+        if len(pattern) != 3:
+            return False
+
+        select_node = pattern[0]
+        slice_node = pattern[1]
+        fill_node = pattern[2]
+        if (
+            select_node.kind() != "aten::select"
+            or slice_node.kind() != "aten::slice"
+            or fill_node.kind() != "aten::fill_"
+        ):
+            return False
+        start = slice_node.inputsAt(2).toIValue()
+        end = slice_node.inputsAt(3).toIValue()
+        if start != 0 or end != sys.maxsize:
+            return False
+
+        return True
+
+    if not is_valid():
+        return None, graph
+
+    graph.setInsertPoint(pattern[-1])
+
+    # rewrite select + slice + fill_ to slice_scatter.
+    select_node = pattern[0]
+    target_dim = select_node.inputsAt(1)
+    start = select_node.inputsAt(2)
+    one = graph.insertConstant(1)
+    end = graph.create("aten::add", [start, one], 1)
+    end.output().setType(torch._C.IntType.get())
+
+    fill = pattern[2]
+    sc_input1 = select_node.inputsAt(0)  # self
+    none = graph.insertConstant(None)
+    value = fill.inputsAt(1)
+    zeros = graph.create(
+        "aten::zeros_like", [sc_input1, none, none, none, none, none], 1
+    )
+    fill_value = graph.create("aten::add", [zeros.output(), value, one], 1)
+    sc_input2 = graph.create(
+        "aten::slice", [fill_value.output(), target_dim, start, end.output(), one], 1
+    )
+
+    slice_scatter = graph.create(
+        "aten::slice_scatter",
+        [sc_input1, sc_input2.output(), target_dim, start, end.output(), one],
+        1,
+    )
+
+    graph.insertNode(end)
+    graph.insertNode(zeros)
+    graph.insertNode(fill_value)
+    graph.insertNode(sc_input2)
+    graph.insertNode(slice_scatter)
+
+    graph.lint()
+
+    return slice_scatter, graph
+
+
+copy_fill_rewrites = [
+    _nslice_slice_or_select_copy,
+    _nslice_slice_or_select_fill,
+    _select_slice_fill_,
+]
+
+
+def replace_copy_fill_with_slice_scatter(graph):
+    """
+    Rewrite chains of the form
+    - slice[(2, 0), (3, 9223372036854775807)]
+    - any select/slice
+    - copy_/fill_
+    into an out-of-place aten::slice_scatter.
+    """
+
+    def is_valid_node(node, level):
+        # level 0 is the copy_/fill_ root; every node above it must be a view op.
+        if level == 0 or node.kind() in ["aten::slice", "aten::select"]:
+            return True
+
+        return False
+
+    def dfs(node, level):
+        pattern = []
+        if is_valid_node(node, level):
+            pattern.append(node)
+        else:
+            return pattern
+        in_tensor = node.inputsAt(0)
+        prev = in_tensor.node()
+        pattern = dfs(prev, level + 1) + pattern
+        return pattern
+
+    def node_compare(lhs, rhs):
+        if lhs[-1].isAfter(rhs[-1]):
+            return 1
+        elif lhs[-1].isBefore(rhs[-1]):
+            return -1
+        else:
+            return 0
+
+    # 1. find pattern roots: walk up from every copy_/fill_ through its view chain.
+    valid_patterns = []
+    for node in graph.nodes():
+        if node.kind() not in ["aten::copy_", "aten::fill_"]:
+            continue
+        pattern = dfs(node, 0)
+        if len(pattern) >= 2 and pattern[0].inputsAt(0).node().kind() not in [
+            "aten::slice",
+            "aten::select",
+        ]:
+            valid_patterns.append(pattern)
+
+    def post_process(slice_scatter, graph):
+        # NOTE: `pattern` is the loop variable of the rewrite loop below,
+        # captured by closure at call time. Redirect every later use of the
+        # original tensor to the slice_scatter result, then erase the
+        # now-dead view chain.
+        slice_scatter_use = None
+        sc_input1 = slice_scatter.inputsAt(0)
+        for use in sc_input1.uses():
+            if use.user == slice_scatter:
+                slice_scatter_use = use
+
+        if slice_scatter_use:
+            for use in sc_input1.uses():
+                user = use.user
+                if user == pattern[0] or not use.isAfter(slice_scatter_use):
+                    continue
+
+                for idx, value in enumerate(user.inputs()):
+                    if value == sc_input1:
+                        user.replaceInput(idx, slice_scatter.output())
+
+        for old_node in reversed(pattern):
+            if len(old_node.output().uses()) == 0:
+                old_node.destroy()
+        graph.lint()
+
+    # 2. do rewrite, in graph order.
+    sorted_patterns = sorted(valid_patterns, key=functools.cmp_to_key(node_compare))
+    for pattern in sorted_patterns:
+        for rewrite_func in copy_fill_rewrites:
+            slice_scatter, graph = rewrite_func(graph, pattern)
+            if slice_scatter is None:
+                continue
+            post_process(slice_scatter, graph)
+            break
+
+    return graph
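
As a usage sketch, the new pass can be driven the same way the added tests drive it: trace a module, run the rewrite over the TorchScript graph, and check numerics. This is a minimal sketch assuming the trace-based workflow from test_jit_transforms.py; the module name TwoSliceFill is illustrative only, not part of the change.

import sys
import torch
from torch_frontend.utils import replace_copy_fill_with_slice_scatter


class TwoSliceFill(torch.nn.Module):
    def forward(self, x):
        # A full-range slice, a sub-slice, then an in-place fill_:
        # exactly the slice + ... + slice + fill_ chain the pass targets.
        x1 = torch.ops.aten.slice(x, 0, 0, sys.maxsize, 1)
        x2 = torch.ops.aten.slice(x1, 0, 0, 2, 1)
        _ = torch.ops.aten.fill_(x2, 2.0)
        return x + 1


inputs = [torch.randn(4, 4, 4)]
model = TwoSliceFill()
golden = model(*inputs)

ts_model = torch.jit.trace(model, inputs, check_trace=False)
replace_copy_fill_with_slice_scatter(ts_model.graph)

# The rewritten graph contains no copy_/fill_ but computes the same result.
torch.testing.assert_close(golden, ts_model(*inputs))
print(ts_model.graph)

The same rewrite is exposed on the command line through the new --enable_jit_rewrite flag added to tools/compiler.py above.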