[Inductor][Optimus] Fix group fusion stride layout (pytorch#134696)
Summary:
Pull Request resolved: pytorch#134696

X-link: pytorch/benchmark#2442

Context: https://fb.workplace.com/groups/1075192433118967/permalink/1401282167176657/

Moving the changes to the group gemm op causes compilation errors; see D55606636 for details.

Test Plan:
# local reproduce

```
CUDA_LAUNCH_BLOCKING=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split-group --model_type "afoc" --flow_id 544109991
```

Counter({'pattern_matcher_nodes': 1215, 'pattern_matcher_count': 1090, 'normalization_pass': 430, 'remove_split_with_size_one_pass': 416, 'batch_aten_mul': 13, 'scmerge_split_sections_removed': 11, 'scmerge_cat_removed': 5, 'scmerge_cat_added': 4, 'batch_linear_post_grad': 4, 'scmerge_split_removed': 3, 'batch_aten_sub': 2, 'batch_layernorm': 1, 'group_linear': 1})

```
CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode group-batch-split --model_type "cmf_shrink" --flow_id 587303213
```

P1551948670

Counter({'pattern_matcher_nodes': 2244, 'pattern_matcher_count': 1738, 'normalization_pass': 404, 'extern_calls': 370, 'benchmarking.TritonBenchmarker.benchmark_gpu': 293, 'remove_split_with_size_one_pass': 269, 'merge_splits_pass': 74, 'normalization_aten_pass': 56, 'batch_aten_mul': 11, 'fxgraph_cache_miss': 10, 'group_linear': 9, 'scmerge_split_sections_removed': 5, 'scmerge_split_removed': 4, 'scmerge_cat_removed': 4, 'unbind_stack_pass': 4, 'batch_sigmoid': 2, 'batch_linear': 2, 'move_reshape_out_of_split_stack_pass': 2, 'batch_aten_sub': 2, 'batch_aten_add': 2, 'batch_layernorm': 1, 'scmerge_split_added': 1, 'scmerge_cat_added': 1, 'split_stack_to_cats_pass': 1, 'split_cat_to_slices_pass': 1, 'benchmarking.TritonBenchmarker.triton_do_bench': 1, 'batch_relu': 1})

# e2e

### AFOC

baseline: f545589474
proposal: f545589302

{F1474302182}

### cmf shrink

ads_dper3:0e442d2994ad1421377489d53ef99593
training_platform:be4b7015f1582fb1760bd72cf83ff38d

baseline: f635512197
baseline + group_fusion: f635975547

{F1832419326} {F1832419319} {F1832419401}

Group fusion can be enabled, but enabling it causes a QPS regression; a deep-dive study will follow.

Differential Revision: D61888433
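
The two Counter dumps in the test plan are easier to compare side by side. The following is a minimal sketch, not part of the change itself: it copies only the `batch_*` and `group_*` entries from the dumps above and lists the keys whose counts differ between the two runs. The helper name `diff_counters` and the variable names are hypothetical.

```python
from collections import Counter

# Fusion-related entries copied from the two Counter dumps in the test plan.
afoc = Counter({
    "batch_aten_mul": 13,
    "batch_linear_post_grad": 4,
    "batch_aten_sub": 2,
    "batch_layernorm": 1,
    "group_linear": 1,
})
cmf_shrink = Counter({
    "batch_aten_mul": 11,
    "group_linear": 9,
    "batch_sigmoid": 2,
    "batch_linear": 2,
    "batch_aten_sub": 2,
    "batch_aten_add": 2,
    "batch_layernorm": 1,
    "batch_relu": 1,
})


def diff_counters(left, right):
    """Hypothetical helper: map each key whose count differs to (left, right)."""
    keys = sorted(set(left) | set(right))
    # Counter returns 0 for missing keys, so absent passes compare as 0.
    return {k: (left[k], right[k]) for k in keys if left[k] != right[k]}


for name, (a, b) in diff_counters(afoc, cmf_shrink).items():
    print(f"{name}: afoc={a} cmf_shrink={b}")
```

The comparison only reports how often each pass fired in each run; the two runs use different models and test modes, so it says nothing about the QPS regression.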