Skip to content

Commit ef1a142

Browse files
steeveGoogle-ML-Automation
authored andcommitted
PR #23654: [ROCm] Enable mfma instructions by passing the correct arch name
Imported from GitHub PR #23654 This PR enables the Triton pipeline to emit `#triton_gpu.amd_mfma` annotations during the Triton to TritonGPU lowering. This is done by the `TritonAMDGPUAccelerateMatmulPass`, which checks the GFX version to do that. Correctly passing the `gfx_version` reduces our kernel runtime from **~5ms** to **~620us** on MI300X, matching the performance of the Python Triton Compiler used in Torch. We expect this change to radiate quite a bit given that this pipeline is shared by the IR Fusion Emitter used widely across XLA if `tt.dot` ops are emitted. Closes #23574 Copybara import of the project: -- c256551 by Steeve Morin <[email protected]>: [ROCm] Enable mfma instructions by passing the correct arch name Without this commit, mfma instructions would not be emitted by this pass. Merging this change closes #23654 COPYBARA_INTEGRATE_REVIEW=#23654 from zml:zml/rocm/mfma c256551 PiperOrigin-RevId: 736519517
1 parent f26b03d commit ef1a142

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
9090
pm->addPass(mt::gpu::createTritonGPUCoalesce());
9191
pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
9292
pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
93-
pm->addPass(mlir::createTritonAMDGPUAccelerateMatmulPass());
93+
pm->addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(cc.gfx_version()));
9494
pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
9595
// TODO ROCm Check if we want to compare MI100 and greater
9696
pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());

0 commit comments

Comments
 (0)