[Inductor][Optimus] Fix group fusion stride layout (pytorch#134696)
Summary:
Pull Request resolved: pytorch#134696

X-link: pytorch/benchmark#2442

Context:
https://fb.workplace.com/groups/1075192433118967/permalink/1401282167176657/

Moving these changes into the group gemm op itself causes compilation errors; see details in D55606636.

Test Plan:
# local reproduce
```
CUDA_LAUNCH_BLOCKING=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split-group --model_type "afoc" --flow_id 544109991
```

```
Counter({'pattern_matcher_nodes': 1215, 'pattern_matcher_count': 1090, 'normalization_pass': 430, 'remove_split_with_size_one_pass': 416, 'batch_aten_mul': 13, 'scmerge_split_sections_removed': 11, 'scmerge_cat_removed': 5, 'scmerge_cat_added': 4, 'batch_linear_post_grad': 4, 'scmerge_split_removed': 3, 'batch_aten_sub': 2, 'batch_layernorm': 1, 'group_linear': 1})
```

```
CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode group-batch-split --model_type "cmf_shrink" --flow_id 587303213
```
P1551948670
```
Counter({'pattern_matcher_nodes': 2244, 'pattern_matcher_count': 1738, 'normalization_pass': 404, 'extern_calls': 370, 'benchmarking.TritonBenchmarker.benchmark_gpu': 293, 'remove_split_with_size_one_pass': 269, 'merge_splits_pass': 74, 'normalization_aten_pass': 56, 'batch_aten_mul': 11, 'fxgraph_cache_miss': 10, 'group_linear': 9, 'scmerge_split_sections_removed': 5, 'scmerge_split_removed': 4, 'scmerge_cat_removed': 4, 'unbind_stack_pass': 4, 'batch_sigmoid': 2, 'batch_linear': 2, 'move_reshape_out_of_split_stack_pass': 2, 'batch_aten_sub': 2, 'batch_aten_add': 2, 'batch_layernorm': 1, 'scmerge_split_added': 1, 'scmerge_cat_added': 1, 'split_stack_to_cats_pass': 1, 'split_cat_to_slices_pass': 1, 'benchmarking.TritonBenchmarker.triton_do_bench': 1, 'batch_relu': 1})
```

# e2e

### AFOC
baseline: f545589474
proposal: f545589302

 {F1474302182}

### cmf shrink

ads_dper3:0e442d2994ad1421377489d53ef99593
training_platform:be4b7015f1582fb1760bd72cf83ff38d

baseline: f635512197
baseline + group_fusion: f635975547

{F1832419326}
{F1832419319}
{F1832419401}

Group fusion can be enabled, but enabling it currently causes a QPS regression; we will follow up with a deep-dive study.

Differential Revision: D61888433
mengluy0125 authored and facebook-github-bot committed Aug 28, 2024
1 parent f997b2b commit 8af38bc
Showing 3 changed files with 49 additions and 11 deletions.
6 changes: 6 additions & 0 deletions torch/_dynamo/utils.py
```diff
@@ -3242,3 +3242,9 @@ def record(cls):
         finally:
             if config.record_compile_time_instruction_count:
                 cls.end()
+
+
+def realize_inputs(inputs: List[torch.fx.Node]):
+    for inp in inputs:
+        if isinstance(inp, torch.fx.node.Node):
+            inp.meta["inductor_realize_to_strides"] = True
```
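As a rough illustration of how this helper is meant to be used (a standalone sketch built only on public `torch.fx` APIs; the traced function `f` is made up), it tags FX nodes with the `inductor_realize_to_strides` meta flag so that, per the comments added in the fusion pass below, Inductor's lowering does not change those inputs' strides:

```python
import torch
from typing import List


def realize_inputs(inputs: List[torch.fx.Node]):
    # Same helper as above: mark each FX node so Inductor realizes it
    # to fixed strides instead of re-deciding its layout during lowering.
    for inp in inputs:
        if isinstance(inp, torch.fx.node.Node):
            inp.meta["inductor_realize_to_strides"] = True


# Hypothetical toy graph: trace a function and tag both operands of its mm call.
def f(a, b):
    return torch.mm(a, b)


gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.mm:
        realize_inputs(list(node.args))
```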
8 changes: 1 addition & 7 deletions torch/_inductor/fx_passes/decompose_mem_bound_mm.py
```diff
@@ -4,7 +4,7 @@
 
 import torch
 from torch import Tensor
-from torch._dynamo.utils import counters
+from torch._dynamo.utils import counters, realize_inputs
 
 from .. import config
 from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
@@ -33,12 +33,6 @@ def check_device(a: Tensor, b: Tensor) -> bool:
     return a.is_cuda and b.is_cuda
 
 
-def realize_inputs(inputs: List[torch.fx.Node]):
-    for inp in inputs:
-        if isinstance(inp, torch.fx.node.Node):
-            inp.meta["inductor_realize_to_strides"] = True
-
-
 def should_decompose_bmm(mat1, mat2) -> bool:
     if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
         mat1 = mat1.meta["val"]
```
46 changes: 42 additions & 4 deletions torch/_inductor/fx_passes/group_batch_fusion.py
```diff
@@ -17,7 +17,7 @@
 )
 
 import torch
-from torch._dynamo.utils import counters, optimus_scuba_log
+from torch._dynamo.utils import counters, optimus_scuba_log, realize_inputs
 from torch._utils_internal import upload_graph
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 
@@ -299,6 +299,31 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
 
 @register_fusion("group_linear", pre_grad=False)
 class GroupLinearFusion(GroupFusion):
+    def get_stride_type(self, node):
+        node_shape = node.meta["tensor_meta"].shape  # type: ignore[union-attr]
+
+        def col_major_stride():
+            return (
+                node.meta["tensor_meta"].stride[0] == 1
+                and node.meta["tensor_meta"].stride[1] > 1
+                and node.meta["tensor_meta"].stride[1] == node_shape[0]
+            )
+
+        def row_major_stride():
+            return (
+                node.meta["tensor_meta"].stride[1] == 1
+                and node.meta["tensor_meta"].stride[0] > 1
+                and node.meta["tensor_meta"].stride[0] == node_shape[1]
+            )
+
+        stride = None
+        if row_major_stride():
+            stride = "row"
+        if col_major_stride():
+            stride = "col"
+
+        return stride
+
     def _addmm_node_can_be_fused(self, node: torch.fx.Node):
         input_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
         weight_shape = node.args[2].meta["val"].shape  # type: ignore[union-attr]
```
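To make the row/col classification concrete, here is a self-contained sketch of the same predicates applied to real tensors rather than FX `tensor_meta` (the `stride_type` helper is ours, for illustration only): a contiguous 2-D tensor classifies as "row", and its transposed view as "col".

```python
import torch


def stride_type(t: torch.Tensor):
    # Mirrors get_stride_type above, but reads strides off a real tensor.
    m, n = t.shape
    if t.stride(1) == 1 and t.stride(0) > 1 and t.stride(0) == n:
        return "row"  # row-major: elements of a row are contiguous
    if t.stride(0) == 1 and t.stride(1) > 1 and t.stride(1) == m:
        return "col"  # column-major: elements of a column are contiguous
    return None


a = torch.randn(4, 8)  # contiguous, shape (4, 8), stride (8, 1)
b = a.t()              # transposed view, shape (8, 4), stride (1, 8)
assert stride_type(a) == "row"
assert stride_type(b) == "col"
```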
```diff
@@ -331,15 +356,28 @@ def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
         if CallFunctionVarArgs(aten.mm.default).match(
             node
         ) and self._mm_node_can_be_fused(node):
-            group_key = ("group_linear", True)
+            # don't allow inductor lowering to change the stride for the nodes
+            realize_inputs([node.args[0], node.args[1]])  # type: ignore[list-item, possibly-undefined]
+            input_stride = self.get_stride_type(node.args[0])
+            weight_stride = self.get_stride_type(node.args[1])
+            group_key = ("group_linear", str(input_stride), str(weight_stride))
         elif CallFunctionVarArgs(aten.addmm.default).match(
             node
         ) and self._addmm_node_can_be_fused(node):
+            # don't allow inductor lowering to change the stride for the nodes
+            realize_inputs([node.args[0], node.args[1], node.args[2]])  # type: ignore[list-item, possibly-undefined]
+            input_stride = self.get_stride_type(node.args[1])
+            weight_stride = self.get_stride_type(node.args[2])
             bias = node.args[0]
-            group_key = ("group_linear", bias is None)
+            group_key = (
+                "group_linear",
+                bias is None,
+                str(input_stride),
+                str(weight_stride),
+            )  # type: ignore[assignment]
         else:
             group_key = None
-        return group_key
+        return group_key  # type: ignore[return-value]
 
     def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
         group_inputs = []
```
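The reason the stride layouts belong in `group_key` is that group fusion only fuses nodes that share a key, so mm calls whose inputs disagree on layout now land in separate buckets instead of being fused into one group gemm with mismatched strides. A hypothetical standalone sketch of the bucketing (names made up, not Inductor code):

```python
from collections import defaultdict

# Candidate mm "nodes" with their input/weight stride layouts, as returned
# by a get_stride_type-style classification.
candidates = [
    {"name": "mm_1", "input_stride": "row", "weight_stride": "row"},
    {"name": "mm_2", "input_stride": "row", "weight_stride": "row"},
    {"name": "mm_3", "input_stride": "row", "weight_stride": "col"},
]

groups = defaultdict(list)
for node in candidates:
    key = ("group_linear", node["input_stride"], node["weight_stride"])
    groups[key].append(node["name"])

# mm_1 and mm_2 share a bucket and may be fused; mm_3 gets its own bucket,
# so row- and col-major weights are never mixed in one fused group gemm.
print(dict(groups))
```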
