
[ROCm] Very slow Triton performance vs PyTorch for the same kernel #23574

Closed
steeve opened this issue Mar 10, 2025 · 24 comments · Fixed by #23691

@steeve
Contributor

steeve commented Mar 10, 2025

Hello folks,

When running a custom paged attention kernel, we see roughly a 10x performance hit for the same Triton kernel when running it through Triton/XLA compared to Triton/PyTorch. We are using the kernel at https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/pa_decode.py#L309

We're not sure what could be causing this; the grids are the same. We did notice that ThreadsPerWarp is hardcoded to 32 in XLA (it should be 64 on this hardware), but I'm not sure that alone explains it.

When trying the Pallas implementation at https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/paged_attention.py in JAX, we get ~750us on H200 but 20ms (!) on MI300X.

We have mostly exhausted the obvious causes and are starting to wonder whether this might be due to a problem with the ROCm support in XLA.
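For reference, the PyTorch-side number is a straightforward kernel timing; here is a minimal sketch of the methodology (the softmax below is only a placeholder standing in for the actual aiter pa_decode launch, which we won't reproduce here):

import torch
import triton

# Timing sketch only: substitute the real aiter pa_decode launch for `op`.
q = torch.randn(256, 32, 128, dtype=torch.bfloat16, device="cuda")

def op():
    return torch.nn.functional.softmax(q.float(), dim=-1)  # placeholder workload

ms = triton.testing.do_bench(op, warmup=25, rep=100)  # mean time in milliseconds
print(f"mean kernel time: {ms * 1000:.0f} us")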

Thank you!

@i-chaochen
Contributor

i-chaochen commented Mar 10, 2025

Hi @steeve

Thanks for asking.

May I ask which version of JAX/XLA you're using? We have disabled Triton on the ROCm side until this PR for the GEMM fusion autotuner lands: #19754

But we have already backported it to 0.4.35, which you should be able to use from https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.4.35

Regarding warp size, we are going to land this PR upstream as well. For now, you can cherry-pick ROCm@89cd01b

@steeve
Contributor Author

steeve commented Mar 10, 2025

Thank you for responding!
I should add that I'm using the __gpu$xla.gpu.triton stablehlo.custom_call and emitting TTIR directly.
Do you think it might be linked to the GEMM autotuner?

@rahulbatra85
Contributor

rahulbatra85 commented Mar 10, 2025

@steeve Hi Steeve, how are you running the Triton kernel (https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/pa_decode.py#L309) with XLA? Did you use JAX/jax-triton (https://github.com/jax-ml/jax-triton), or are you calling the Triton kernel with PyTorch/XLA (something like this: https://pytorch.org/xla/release/r2.6/features/triton.html)?

@steeve
Contributor Author

steeve commented Mar 10, 2025

@rahulbatra85, we dumped the Triton MLIR and leveraged __gpu$xla.gpu.triton. We also created a trampoline to satisfy XLA's calling convention:

  • load the scalars from tt.ptr
  • pass the output as the last argument
  tt.func public @_paged_attn_decode_v1_w_dot_kernel_tt_load(%ptr1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr3: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr4: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr5: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr6: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr7: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr8: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr9: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr10: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr11: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr12: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr13: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr14: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr15: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr16: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
    %0 = tt.load %ptr6 : !tt.ptr<f32>
    %1 = tt.load %ptr7 : !tt.ptr<f32>
    %2 = tt.load %ptr8 : !tt.ptr<f32>
    %3 = tt.load %ptr9 : !tt.ptr<i32>
    %4 = tt.load %ptr10 : !tt.ptr<i32>
    %5 = tt.load %ptr11 : !tt.ptr<i32>
    %6 = tt.load %ptr12 : !tt.ptr<i32>
    %7 = tt.load %ptr13 : !tt.ptr<i32>
    %8 = tt.load %ptr14 : !tt.ptr<i32>
    %9 = tt.load %ptr15 : !tt.ptr<i32>
    %10 = tt.load %ptr16 : !tt.ptr<i32>
    tt.call @_paged_attn_decode_v1_w_dot_kernel(%ptr0, %ptr1, %ptr2, %ptr3, %ptr4, %ptr5, %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10) : (!tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<i32>, !tt.ptr<i64>, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32) -> ()
    tt.return
  }

So it's a bit different from jax-triton, which has its own custom call target.

@steeve
Contributor Author

steeve commented Mar 10, 2025

May I ask which version of JAX/XLA you're using? We have disabled Triton on the ROCm side until this PR for the GEMM fusion autotuner lands: #19754

Also, it seems to have landed in e92a0f9, no?

We are on this XLA/PJRT ROCm version from Feb 6th: 6789523

Also using ROCm 6.3

@steeve
Contributor Author

steeve commented Mar 10, 2025

@steeve Hi Steeve, how are you running the Triton kernel (https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/pa_decode.py#L309) with XLA?

Also @rahulbatra85, this kernel does run in ~5ms on our MI300X with XLA, but that's still a far cry from the ~650us ballpark on aiter/PyTorch. We're not sure why, as we're literally using the MLIR dumped from Torch.

Params:

    batch_size = 256
    num_heads = 32
    num_kv_heads = 8
    q_kv_head_ratio = 4 # (32/8)
    head_size = 128
    block_size = 16
    num_blocks = 16384
    max_seq_len = 2048
    attn_scale = 0.08838834765
    grid = (batch_size, num_kv_heads, 1)
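
For a rough sense of the target with these parameters: the decode pass has to stream the referenced K and V cache once, so HBM bandwidth puts a floor on the runtime. A back-of-envelope sketch (assuming every sequence sits at max_seq_len and taking ~5.3 TB/s as the nominal MI300X peak, both assumptions):

batch_size, max_seq_len, num_kv_heads, head_size = 256, 2048, 8, 128
bytes_per_elem = 2  # bf16

# K and V are each read once per decode step (upper bound: full max_seq_len per sequence)
kv_bytes = batch_size * max_seq_len * num_kv_heads * head_size * bytes_per_elem * 2
print(f"KV traffic: {kv_bytes / 1e9:.2f} GB")                     # ~2.15 GB

peak_hbm_bw = 5.3e12  # bytes/s, nominal MI300X peak (assumption)
print(f"bandwidth floor: {kv_bytes / peak_hbm_bw * 1e6:.0f} us")  # ~405 us

So something in the 400-700us range is roughly what this kernel should cost, which is why 5ms looks so far off.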

Hope that helps!

@steeve
Contributor Author

steeve commented Mar 10, 2025

Regarding warp size, we are going to land this PR upstream as well. For now, you can cherry-pick ROCm@89cd01b

We'll try that first thing tomorrow! Thanks @i-chaochen!

@i-chaochen
Contributor

Wondering, could you post the dumped HLO for us to try, so we can rerun it with XLA and see what's missing?

@steeve
Contributor Author

steeve commented Mar 10, 2025

Sure, here is the ROCm/aiter one:

HloModule zml, entry_computation_layout={(bf16[256,32,128]{2,1,0}, bf16[16384,8,16,128]{3,2,1,0}, bf16[16384,8,16,128]{3,2,1,0}, s32[256,128]{1,0}, s64[256]{0})->bf16[256,32,128]{2,1,0}}

ENTRY %main.14 (Arg_0.1: bf16[256,32,128], Arg_1.2: bf16[16384,8,16,128], Arg_2.3: bf16[16384,8,16,128], Arg_3.4: s32[256,128], Arg_4.5: s64[256]) -> bf16[256,32,128] {
  %Arg_0.1 = bf16[256,32,128]{2,1,0} parameter(0)
  %Arg_1.2 = bf16[16384,8,16,128]{3,2,1,0} parameter(1)
  %Arg_2.3 = bf16[16384,8,16,128]{3,2,1,0} parameter(2)
  %Arg_3.4 = s32[256,128]{1,0} parameter(3)
  %Arg_4.5 = s64[256]{0} parameter(4)
  %constant.6 = f32[] constant(0.0883883461)
  %constant.7 = s32[] constant(8192)
  %constant.8 = s32[] constant(256)
  %constant.9 = s32[] constant(2)
  %constant.10 = s32[] constant(32768)
  %constant.11 = s32[] constant(4096)
  %constant.12 = s32[] constant(512)
  ROOT %custom-call.13 = bf16[256,32,128]{2,1,0} custom-call(bf16[256,32,128]{2,1,0} %Arg_0.1, bf16[16384,8,16,128]{3,2,1,0} %Arg_1.2, bf16[16384,8,16,128]{3,2,1,0} %Arg_2.3, s32[256,128]{1,0} %Arg_3.4, s64[256]{0} %Arg_4.5, /*index=5*/f32[] %constant.6, s32[] %constant.7, s32[] %constant.8, s32[] %constant.7, s32[] %constant.8, /*index=10*/s32[] %constant.9, s32[] %constant.10, s32[] %constant.11, s32[] %constant.8, s32[] %constant.9, /*index=15*/s32[] %constant.12), custom_call_target="__gpu$xla.gpu.triton", api_version=API_VERSION_TYPED_FFI, metadata={source_file="external/zml+/zml/ops.zig" source_line=1212}, backend_config={grid_x = 256 : i32, grid_y = 8 : i32, grid_z = 1 : i32, ir = "#loc = loc(\22/workspace/pa_decode_rework.py\22:159:0)\0A#loc1 = loc(unknown)\0A#loc54 = loc(\22/workspace/pa_decode_rework.py\22:273:42)\0A#loc71 = loc(\22/workspace/pa_decode_rework.py\22:288:43)\0A#loc89 = loc(callsite(#loc1 at #loc54))\0A#loc93 = loc(callsite(#loc1 at #loc71))\0Amodule {\0A  tt.func public @_paged_attn_decode_v1_w_dot_kernel_tt_load_only(%ptr1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr3: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr4: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr5: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr6: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr7: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr8: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr9: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr10: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr11: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr12: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr13: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr14: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr15: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr16: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {\0A\0A    %0 = tt.load %ptr6 : !tt.ptr<f32>\0A    %1 = tt.load %ptr7 : !tt.ptr<f32>  \0A    %2 = tt.load %ptr8 : !tt.ptr<f32>\0A    %3 = tt.load %ptr9 : !tt.ptr<i32>\0A    %4 = tt.load %ptr10 : !tt.ptr<i32>\0A    %5 = tt.load %ptr11 : !tt.ptr<i32>\0A    %6 = tt.load %ptr12 : !tt.ptr<i32>\0A    %7 = tt.load %ptr13 : !tt.ptr<i32>\0A    %8 = tt.load %ptr14 : !tt.ptr<i32>\0A    %9 = tt.load %ptr15 : !tt.ptr<i32>\0A    %10 = tt.load %ptr16 : !tt.ptr<i32>\0A    tt.call @_paged_attn_decode_v1_w_dot_kernel(%ptr0, %ptr1, %ptr2, %ptr3, %ptr4, %ptr5, %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10) : (!tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<i32>, !tt.ptr<i64>, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32) -> ()\0A    tt.return\0A  }\0A\0A  tt.func private @_paged_attn_decode_v1_w_dot_kernel(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, 
tt.pointer_range = 32 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg5: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg6: f32 loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg7: f32 loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg8: f32 loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg9: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg10: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg11: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg12: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg13: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg14: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg15: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0), %arg16: i32 {tt.divisibility = 16 : i32} loc(\22/workspace/pa_decode_rework.py\22:159:0)) attributes {noinline = false} {\0A    %cst = arith.constant dense<0xFF800000> : tensor<16xf32> loc(#loc1)\0A    %cst_0 = arith.constant dense<0.000000e+00> : tensor<16xf32> loc(#loc1)\0A    %c15_i64 = arith.constant 15 : i64 loc(#loc1)\0A    %c1_i64 = arith.constant 1 : i64 loc(#loc1)\0A    %c0_i64 = arith.constant 0 : i64 loc(#loc1)\0A    %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x128xbf16> loc(#loc1)\0A    %cst_2 = arith.constant dense<1.44269502> : tensor<16xf32> loc(#loc1)\0A    %cst_3 = arith.constant dense<1.44269502> : tensor<16x16xf32> loc(#loc1)\0A    %cst_4 = arith.constant dense<0xFF800000> : tensor<16x16xf32> loc(#loc1)\0A    %cst_5 = arith.constant dense<0.000000e+00> : tensor<16x16xf32> loc(#loc1)\0A    %cst_6 = arith.constant dense<16> : tensor<16x1xi32> loc(#loc1)\0A    %c16_i64 = arith.constant 16 : i64 loc(#loc1)\0A    %cst_7 = arith.constant dense<0.000000e+00> : tensor<16x128xf32> loc(#loc1)\0A    %cst_8 = arith.constant dense<128> : tensor<1x128xi32> loc(#loc1)\0A    %cst_9 = arith.constant dense<4> : tensor<16x1xi32> loc(#loc1)\0A    %c4_i32 = arith.constant 4 : i32 loc(#loc1)\0A    %0 = tt.get_program_id x : i32 loc(#loc2)\0A    %1 = tt.get_program_id y : i32 loc(#loc3)\0A    %2 = tt.addptr %arg5, %0 : !tt.ptr<i64>, i32 loc(#loc4)\0A    %3 = tt.load %2 : !tt.ptr<i64> loc(#loc5)\0A    %4 = arith.addi %3, %c15_i64 : i64 loc(#loc86)\0A    %5 = arith.divsi %4, %c16_i64 : i64 loc(#loc87)\0A    %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc9)\0A    %7 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc10)\0A    %8 = arith.muli %0, %arg11 : i32 loc(#loc11)\0A    %9 = arith.muli %1, %c4_i32 : i32 loc(#loc12)\0A    %10 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> loc(#loc13)\0A    %11 = tt.splat %9 : i32 -> tensor<16x1xi32> loc(#loc14)\0A    %12 = arith.addi %11, %10 : tensor<16x1xi32> loc(#loc14)\0A    %13 = tt.splat %arg12 : i32 -> tensor<16x1xi32> loc(#loc15)\0A    %14 = arith.muli %12, %13 : tensor<16x1xi32> loc(#loc15)\0A    %15 = tt.splat %8 : i32 -> tensor<16x1xi32> loc(#loc16)\0A    %16 = arith.addi %15, %14 : tensor<16x1xi32> 
loc(#loc16)\0A    %17 = tt.expand_dims %7 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc17)\0A    %18 = tt.broadcast %16 : tensor<16x1xi32> -> tensor<16x128xi32> loc(#loc18)\0A    %19 = tt.broadcast %17 : tensor<1x128xi32> -> tensor<16x128xi32> loc(#loc18)\0A    %20 = arith.addi %18, %19 : tensor<16x128xi32> loc(#loc18)\0A    %21 = arith.cmpi slt, %10, %cst_9 : tensor<16x1xi32> loc(#loc19)\0A    %22 = arith.cmpi slt, %17, %cst_8 : tensor<1x128xi32> loc(#loc20)\0A    %23 = tt.broadcast %21 : tensor<16x1xi1> -> tensor<16x128xi1> loc(#loc21)\0A    %24 = tt.broadcast %22 : tensor<1x128xi1> -> tensor<16x128xi1> loc(#loc21)\0A    %25 = arith.andi %23, %24 : tensor<16x128xi1> loc(#loc21)\0A    %26 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<16x128x!tt.ptr<bf16>> loc(#loc22)\0A    %27 = tt.addptr %26, %20 : tensor<16x128x!tt.ptr<bf16>>, tensor<16x128xi32> loc(#loc22)\0A    %28 = tt.load %27, %25, %cst_1 : tensor<16x128x!tt.ptr<bf16>> loc(#loc23)\0A    %29 = arith.extf %28 : tensor<16x128xbf16> to tensor<16x128xf32> loc(#loc24)\0A    %30 = tt.splat %arg6 : f32 -> tensor<16x128xf32> loc(#loc24)\0A    %31 = arith.mulf %29, %30 : tensor<16x128xf32> loc(#loc24)\0A    %32 = arith.truncf %31 : tensor<16x128xf32> to tensor<16x128xbf16> loc(#loc25)\0A    %33 = arith.muli %1, %arg14 : i32 loc(#loc26)\0A    %34 = tt.splat %arg15 : i32 -> tensor<16x1xi32> loc(#loc27)\0A    %35 = arith.muli %10, %34 : tensor<16x1xi32> loc(#loc27)\0A    %36 = tt.splat %33 : i32 -> tensor<16x1xi32> loc(#loc28)\0A    %37 = arith.addi %36, %35 : tensor<16x1xi32> loc(#loc28)\0A    %38 = tt.broadcast %37 : tensor<16x1xi32> -> tensor<16x128xi32> loc(#loc29)\0A    %39 = arith.addi %38, %19 : tensor<16x128xi32> loc(#loc29)\0A    %40 = arith.muli %0, %arg16 : i32 loc(#loc30)\0A    %41 = tt.addptr %arg4, %40 : !tt.ptr<i32>, i32 loc(#loc31)\0A    %42 = arith.extsi %6 : tensor<16xi32> to tensor<16xi64> loc(#loc32)\0A    %43 = tt.splat %3 : i64 -> tensor<16x1xi64> loc(#loc33)\0A    %44 = arith.cmpi slt, %10, %cst_6 : tensor<16x1xi32> loc(#loc34)\0A    %45 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<16x128x!tt.ptr<bf16>> loc(#loc35)\0A    %46 = tt.splat %3 : i64 -> tensor<1x16xi64> loc(#loc36)\0A    %47 = tt.broadcast %21 : tensor<16x1xi1> -> tensor<16x16xi1> loc(#loc37)\0A    %48 = tt.splat %arg3 : !tt.ptr<bf16> -> tensor<16x128x!tt.ptr<bf16>> loc(#loc38)\0A    %49:3 = scf.for %arg17 = %c0_i64 to %5 step %c1_i64 iter_args(%arg18 = %cst_7, %arg19 = %cst_0, %arg20 = %cst) -> (tensor<16x128xf32>, tensor<16xf32>, tensor<16xf32>)  : i64 {\0A      %63 = tt.addptr %41, %arg17 : !tt.ptr<i32>, i64 loc(#loc40)\0A      %64 = tt.load %63 : !tt.ptr<i32> loc(#loc41)\0A      %65 = arith.muli %64, %arg13 : i32 loc(#loc42)\0A      %66 = tt.splat %65 : i32 -> tensor<16x128xi32> loc(#loc43)\0A      %67 = arith.addi %66, %39 : tensor<16x128xi32> loc(#loc43)\0A      %68 = arith.muli %arg17, %c16_i64 : i64 loc(#loc44)\0A      %69 = tt.splat %68 : i64 -> tensor<16xi64> loc(#loc32)\0A      %70 = arith.addi %69, %42 : tensor<16xi64> loc(#loc32)\0A      %71 = tt.expand_dims %70 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64> loc(#loc45)\0A      %72 = arith.cmpi slt, %71, %43 : tensor<16x1xi64> loc(#loc33)\0A      %73 = arith.andi %72, %44 : tensor<16x1xi1> loc(#loc46)\0A      %74 = tt.broadcast %73 : tensor<16x1xi1> -> tensor<16x128xi1> loc(#loc47)\0A      %75 = arith.andi %74, %24 : tensor<16x128xi1> loc(#loc47)\0A      %76 = tt.addptr %45, %67 : tensor<16x128x!tt.ptr<bf16>>, tensor<16x128xi32> loc(#loc35)\0A      %77 = tt.load %76, %75, 
%cst_1 : tensor<16x128x!tt.ptr<bf16>> loc(#loc48)\0A      %78 = tt.trans %77 {order = array<i32: 1, 0>} : tensor<16x128xbf16> -> tensor<128x16xbf16> loc(#loc49)\0A      %79 = tt.dot %32, %78, %cst_5 : tensor<16x128xbf16> * tensor<128x16xbf16> -> tensor<16x16xf32> loc(#loc49)\0A      %80 = tt.expand_dims %70 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64> loc(#loc50)\0A      %81 = arith.cmpi slt, %80, %46 : tensor<1x16xi64> loc(#loc36)\0A      %82 = tt.broadcast %81 : tensor<1x16xi1> -> tensor<16x16xi1> loc(#loc37)\0A      %83 = arith.andi %47, %82 : tensor<16x16xi1> loc(#loc37)\0A      %84 = arith.select %83, %79, %cst_4 : tensor<16x16xi1>, tensor<16x16xf32> loc(#loc51)\0A      %85 = arith.select %83, %84, %cst_4 : tensor<16x16xi1>, tensor<16x16xf32> loc(#loc52)\0A      %86 = \22tt.reduce\22(%85) <{axis = 1 : i32}> ({\0A      ^bb0(%arg21: f32 loc(callsite(#loc1 at #loc54)), %arg22: f32 loc(callsite(#loc1 at #loc54))):\0A        %107 = arith.maxnumf %arg21, %arg22 : f32 loc(#loc95)\0A        tt.reduce.return %107 : f32 loc(#loc88)\0A      }) : (tensor<16x16xf32>) -> tensor<16xf32> loc(#loc88)\0A      %87 = arith.maxnumf %86, %arg20 : tensor<16xf32> loc(#loc56)\0A      %88 = tt.expand_dims %87 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> loc(#loc57)\0A      %89 = tt.broadcast %88 : tensor<16x1xf32> -> tensor<16x16xf32> loc(#loc58)\0A      %90 = arith.subf %85, %89 : tensor<16x16xf32> loc(#loc58)\0A      %91 = arith.mulf %90, %cst_3 : tensor<16x16xf32> loc(#loc59)\0A      %92 = math.exp2 %91 : tensor<16x16xf32> loc(#loc60)\0A      %93 = arith.subf %arg20, %87 : tensor<16xf32> loc(#loc61)\0A      %94 = arith.mulf %93, %cst_2 : tensor<16xf32> loc(#loc62)\0A      %95 = math.exp2 %94 : tensor<16xf32> loc(#loc63)\0A      %96 = tt.expand_dims %95 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> loc(#loc64)\0A      %97 = tt.broadcast %96 : tensor<16x1xf32> -> tensor<16x128xf32> loc(#loc65)\0A      %98 = arith.mulf %arg18, %97 : tensor<16x128xf32> loc(#loc65)\0A      %99 = tt.addptr %48, %67 : tensor<16x128x!tt.ptr<bf16>>, tensor<16x128xi32> loc(#loc38)\0A      %100 = tt.load %99, %75, %cst_1 : tensor<16x128x!tt.ptr<bf16>> loc(#loc66)\0A      %101 = arith.truncf %92 : tensor<16x16xf32> to tensor<16x16xbf16> loc(#loc67)\0A      %102 = tt.dot %101, %100, %98 : tensor<16x16xbf16> * tensor<16x128xbf16> -> tensor<16x128xf32> loc(#loc68)\0A      %103 = arith.mulf %arg19, %95 : tensor<16xf32> loc(#loc69)\0A      %104 = arith.extf %101 : tensor<16x16xbf16> to tensor<16x16xf32> loc(#loc91)\0A      %105 = \22tt.reduce\22(%104) <{axis = 1 : i32}> ({\0A      ^bb0(%arg21: f32 loc(callsite(#loc1 at #loc71)), %arg22: f32 loc(callsite(#loc1 at #loc71))):\0A        %107 = arith.addf %arg21, %arg22 : f32 loc(#loc96)\0A        tt.reduce.return %107 : f32 loc(#loc92)\0A      }) : (tensor<16x16xf32>) -> tensor<16xf32> loc(#loc92)\0A      %106 = arith.addf %103, %105 : tensor<16xf32> loc(#loc74)\0A      scf.yield %102, %106, %87 : tensor<16x128xf32>, tensor<16xf32>, tensor<16xf32> loc(#loc75)\0A    } loc(#loc39)\0A    %50 = tt.expand_dims %49#1 {axis = 1 : i32} : tensor<16xf32> -> tensor<16x1xf32> loc(#loc76)\0A    %51 = tt.broadcast %50 : tensor<16x1xf32> -> tensor<16x128xf32> loc(#loc77)\0A    %52 = arith.divf %49#0, %51 : tensor<16x128xf32> loc(#loc77)\0A    %53 = arith.muli %0, %arg9 : i32 loc(#loc78)\0A    %54 = tt.splat %arg10 : i32 -> tensor<16x1xi32> loc(#loc79)\0A    %55 = arith.muli %12, %54 : tensor<16x1xi32> loc(#loc79)\0A    %56 = tt.splat %53 : i32 -> tensor<16x1xi32> loc(#loc80)\0A    
%57 = arith.addi %56, %55 : tensor<16x1xi32> loc(#loc80)\0A    %58 = tt.broadcast %57 : tensor<16x1xi32> -> tensor<16x128xi32> loc(#loc81)\0A    %59 = arith.addi %58, %19 : tensor<16x128xi32> loc(#loc81)\0A    %60 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<16x128x!tt.ptr<bf16>> loc(#loc82)\0A    %61 = tt.addptr %60, %59 : tensor<16x128x!tt.ptr<bf16>>, tensor<16x128xi32> loc(#loc82)\0A    %62 = arith.truncf %52 : tensor<16x128xf32> to tensor<16x128xbf16> loc(#loc83)\0A    tt.store %61, %62, %25 : tensor<16x128x!tt.ptr<bf16>> loc(#loc84)\0A    tt.return loc(#loc85)\0A  } loc(#loc)\0A} loc(#loc)\0A#loc2 = loc(\22/workspace/pa_decode_rework.py\22:194:28)\0A#loc3 = loc(\22/workspace/pa_decode_rework.py\22:195:32)\0A#loc4 = loc(\22/workspace/pa_decode_rework.py\22:199:37)\0A#loc5 = loc(\22/workspace/pa_decode_rework.py\22:199:22)\0A#loc6 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:40:22)\0A#loc7 = loc(\22/workspace/pa_decode_rework.py\22:201:35)\0A#loc8 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:40:28)\0A#loc9 = loc(\22/workspace/pa_decode_rework.py\22:203:28)\0A#loc10 = loc(\22/workspace/pa_decode_rework.py\22:204:32)\0A#loc11 = loc(\22/workspace/pa_decode_rework.py\22:218:18)\0A#loc12 = loc(\22/workspace/pa_decode_rework.py\22:219:25)\0A#loc13 = loc(\22/workspace/pa_decode_rework.py\22:219:51)\0A#loc14 = loc(\22/workspace/pa_decode_rework.py\22:219:40)\0A#loc15 = loc(\22/workspace/pa_decode_rework.py\22:219:63)\0A#loc16 = loc(\22/workspace/pa_decode_rework.py\22:219:10)\0A#loc17 = loc(\22/workspace/pa_decode_rework.py\22:220:23)\0A#loc18 = loc(\22/workspace/pa_decode_rework.py\22:220:10)\0A#loc19 = loc(\22/workspace/pa_decode_rework.py\22:224:36)\0A#loc20 = loc(\22/workspace/pa_decode_rework.py\22:224:77)\0A#loc21 = loc(\22/workspace/pa_decode_rework.py\22:224:53)\0A#loc22 = loc(\22/workspace/pa_decode_rework.py\22:226:24)\0A#loc23 = loc(\22/workspace/pa_decode_rework.py\22:226:16)\0A#loc24 = loc(\22/workspace/pa_decode_rework.py\22:227:13)\0A#loc25 = loc(\22/workspace/pa_decode_rework.py\22:227:23)\0A#loc26 = loc(\22/workspace/pa_decode_rework.py\22:234:22)\0A#loc27 = loc(\22/workspace/pa_decode_rework.py\22:235:30)\0A#loc28 = loc(\22/workspace/pa_decode_rework.py\22:235:10)\0A#loc29 = loc(\22/workspace/pa_decode_rework.py\22:236:10)\0A#loc30 = loc(\22/workspace/pa_decode_rework.py\22:238:51)\0A#loc31 = loc(\22/workspace/pa_decode_rework.py\22:238:41)\0A#loc32 = loc(\22/workspace/pa_decode_rework.py\22:243:39)\0A#loc33 = loc(\22/workspace/pa_decode_rework.py\22:245:37)\0A#loc34 = loc(\22/workspace/pa_decode_rework.py\22:246:35)\0A#loc35 = loc(\22/workspace/pa_decode_rework.py\22:251:36)\0A#loc36 = loc(\22/workspace/pa_decode_rework.py\22:258:76)\0A#loc37 = loc(\22/workspace/pa_decode_rework.py\22:258:52)\0A#loc38 = loc(\22/workspace/pa_decode_rework.py\22:281:36)\0A#loc39 = loc(\22/workspace/pa_decode_rework.py\22:240:19)\0A#loc40 = loc(\22/workspace/pa_decode_rework.py\22:241:50)\0A#loc41 = loc(\22/workspace/pa_decode_rework.py\22:241:30)\0A#loc42 = loc(\22/workspace/pa_decode_rework.py\22:242:36)\0A#loc43 = loc(\22/workspace/pa_decode_rework.py\22:242:49)\0A#loc44 = loc(\22/workspace/pa_decode_rework.py\22:243:27)\0A#loc45 = loc(\22/workspace/pa_decode_rework.py\22:245:26)\0A#loc46 = loc(\22/workspace/pa_decode_rework.py\22:246:15)\0A#loc47 = loc(\22/workspace/pa_decode_rework.py\22:247:15)\0A#loc48 = loc(\22/workspace/pa_decode_rework.py\22:251:22)\0A#loc49 = loc(\22/workspace/pa_decode_rework.py\22:256:23)\0A#loc50 = 
loc(\22/workspace/pa_decode_rework.py\22:258:65)\0A#loc51 = loc(\22/workspace/pa_decode_rework.py\22:260:12)\0A#loc52 = loc(\22/workspace/pa_decode_rework.py\22:271:12)\0A#loc53 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:184:40)\0A#loc55 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:163:27)\0A#loc56 = loc(\22/workspace/pa_decode_rework.py\22:273:55)\0A#loc57 = loc(\22/workspace/pa_decode_rework.py\22:276:45)\0A#loc58 = loc(\22/workspace/pa_decode_rework.py\22:276:31)\0A#loc59 = loc(\22/workspace/pa_decode_rework.py\22:276:57)\0A#loc60 = loc(\22/workspace/pa_decode_rework.py\22:276:25)\0A#loc61 = loc(\22/workspace/pa_decode_rework.py\22:277:42)\0A#loc62 = loc(\22/workspace/pa_decode_rework.py\22:277:59)\0A#loc63 = loc(\22/workspace/pa_decode_rework.py\22:277:29)\0A#loc64 = loc(\22/workspace/pa_decode_rework.py\22:278:21)\0A#loc65 = loc(\22/workspace/pa_decode_rework.py\22:278:15)\0A#loc66 = loc(\22/workspace/pa_decode_rework.py\22:281:22)\0A#loc67 = loc(\22/workspace/pa_decode_rework.py\22:285:17)\0A#loc68 = loc(\22/workspace/pa_decode_rework.py\22:286:25)\0A#loc69 = loc(\22/workspace/pa_decode_rework.py\22:288:28)\0A#loc70 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:266:46)\0A#loc72 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:267:36)\0A#loc73 = loc(\22/usr/local/lib/python3.12/dist-packages/triton/language/standard.py\22:256:15)\0A#loc74 = loc(\22/workspace/pa_decode_rework.py\22:288:36)\0A#loc75 = loc(\22/workspace/pa_decode_rework.py\22:289:8)\0A#loc76 = loc(\22/workspace/pa_decode_rework.py\22:291:24)\0A#loc77 = loc(\22/workspace/pa_decode_rework.py\22:291:16)\0A#loc78 = loc(\22/workspace/pa_decode_rework.py\22:294:18)\0A#loc79 = loc(\22/workspace/pa_decode_rework.py\22:295:63)\0A#loc80 = loc(\22/workspace/pa_decode_rework.py\22:295:10)\0A#loc81 = loc(\22/workspace/pa_decode_rework.py\22:296:10)\0A#loc82 = loc(\22/workspace/pa_decode_rework.py\22:300:23)\0A#loc83 = loc(\22/workspace/pa_decode_rework.py\22:300:40)\0A#loc84 = loc(\22/workspace/pa_decode_rework.py\22:300:33)\0A#loc85 = loc(\22/workspace/pa_decode_rework.py\22:300:4)\0A#loc86 = loc(callsite(#loc6 at #loc7))\0A#loc87 = loc(callsite(#loc8 at #loc7))\0A#loc88 = loc(callsite(#loc53 at #loc54))\0A#loc90 = loc(callsite(#loc55 at #loc53))\0A#loc91 = loc(callsite(#loc70 at #loc71))\0A#loc92 = loc(callsite(#loc72 at #loc71))\0A#loc94 = loc(callsite(#loc73 at #loc72))\0A#loc95 = loc(callsite(#loc90 at #loc54))\0A#loc96 = loc(callsite(#loc94 at #loc71))\0A", name = "_paged_attn_decode_v1_w_dot_kernel_tt_load_only", num_stages = 3 : i32, num_warps = 4 : i32}
}

@rahulbatra85
Contributor

@steeve In one case it's using a mix of the XLA and Triton compilers, whereas in the other (aiter/PyTorch) it's using just the Triton compiler.

By the way, it would be helpful to know the exact framework/compiler you eventually want this to work with.
Do you want:

  1. Pallas/JAX/XLA
  2. Triton/JAX/XLA
  3. Triton/PyTorch/XLA
  4. Triton/PyTorch

You can do 2) by using jax-triton. This will compile the Triton kernel purely with the Triton compiler.
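
For reference, a minimal jax-triton call looks roughly like this (a toy add kernel adapted from the jax-triton README, not the paged attention kernel; exact signatures may differ across versions):

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, BLOCK: tl.constexpr):
    # One program handles one BLOCK-sized chunk; the input length is a multiple of BLOCK here.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(output_ptr + offsets, tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    block = 128
    # jax-triton appends the output buffer(s) after the inputs when invoking the kernel.
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
        grid=(x.size // block,),
        BLOCK=block)

print(add(jnp.arange(1024, dtype=jnp.float32), jnp.ones(1024, dtype=jnp.float32))[:4])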

@steeve
Contributor Author

steeve commented Mar 10, 2025

We're planning to use XLA's Triton compiler "integration", whose pipeline is at: https://github.com/openxla/xla/blob/main/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc#L60

I'd wager it should be pretty similar, since it pulls the passes from the Triton compiler? At least hopefully not 10x slower, haha

Also, we're not using JAX, but our own ML framework: https://github.com/zml/zml

@jayfurmanek
Contributor

Thanks for trying this out!

@rahulbatra85, we dumped the Triton MLIR and leveraged __gpu$xla.gpu.triton. We also created a trampoline to satisfy XLA's calling convention:

  • load the scalars from tt.ptr

  • pass the output as the last argument

    tt.func public @_paged_attn_decode_v1_w_dot_kernel_tt_load(%ptr1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr3: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr4: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr5: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr6: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr7: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr8: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr9: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr10: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr11: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr12: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr13: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr14: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr15: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr16: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
    %0 = tt.load %ptr6 : !tt.ptr<f32>
    %1 = tt.load %ptr7 : !tt.ptr<f32>
    %2 = tt.load %ptr8 : !tt.ptr<f32>
    %3 = tt.load %ptr9 : !tt.ptr<i32>
    %4 = tt.load %ptr10 : !tt.ptr<i32>
    %5 = tt.load %ptr11 : !tt.ptr<i32>
    %6 = tt.load %ptr12 : !tt.ptr<i32>
    %7 = tt.load %ptr13 : !tt.ptr<i32>
    %8 = tt.load %ptr14 : !tt.ptr<i32>
    %9 = tt.load %ptr15 : !tt.ptr<i32>
    %10 = tt.load %ptr16 : !tt.ptr<i32>
    tt.call @_paged_attn_decode_v1_w_dot_kernel(%ptr0, %ptr1, %ptr2, %ptr3, %ptr4, %ptr5, %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10) : (!tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<i32>, !tt.ptr<i64>, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32) -> ()
    tt.return
    }
    So a bit different from jax-triton, which has its own custom call target

Wow - that is one way to do it. Can you share more on the steps to reproduce it this way? This code path is still pretty new on the AMD side and in active development. The focus thus far has been on getting it working well for GEMM fusions in XLA. We do want to get it to a general-purpose Triton emitter, however.
The warpSize fix above is needed.
There are also some annotations we need. This commit adds a few: ROCm@de4dbab
We'll also need AMD-specific parameters like waves_per_eu plumbed into the launcher, and num_stages set appropriately. We are working on those.
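
For comparison, on the plain Triton/PyTorch side those knobs are just launch keywords; a toy sketch (it assumes the current ROCm Triton backend accepts waves_per_eu as a launch kwarg, which is exactly the plumbing the XLA custom-call path still needs):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(1 << 20, device="cuda")  # "cuda" maps to the ROCm device in PyTorch
grid = (triton.cdiv(x.numel(), 256),)
# num_warps/num_stages are generic Triton launch options; waves_per_eu is the
# AMD-specific one mentioned above (assumed accepted here on the ROCm backend).
scale_kernel[grid](x, x.numel(), BLOCK=256, num_warps=4, num_stages=2, waves_per_eu=2)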

@steeve
Contributor Author

steeve commented Mar 10, 2025

Let us give you a JAX reproducer tomorrow (Tuesday); it should be fairly easy to put together.
Do you think those missing params might explain the performance gap?

@jayfurmanek
Contributor

Ok, great, thanks!
Note that the built-in Pallas kernels are pretty NVIDIA-specific, and Pallas doesn't offer a kernel-level autotuner to pick the right settings for num_stages or block sizes, which can often have a large effect on performance as well.

So a repro using the AITER kernel you are working with would be more interesting.
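
For contrast, on the Triton side that tuning is usually expressed with triton.autotune; a minimal sketch with a toy copy kernel (not the paged attention one):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
    ],
    key=["n"],  # re-benchmark the configs whenever n changes
)
@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)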

steeve added a commit to zml/pjrt-artifacts that referenced this issue Mar 11, 2025
@hugomano
Contributor

Hello, thanks for the support :)

We wrote a reproduction gist here: https://gist.github.com/hugomano/1b3fbe00cb4a6e77f438627a4c829234

We ran it with ROCm docker image rocm/jax-community@sha256:8bab484be1713655f74da51a191ed824bb9d03db1104fd63530a1ac3c37cf7b and JAX version 0.5.2.dev20250302+15255dd69

Happy to move forward!

@steeve
Contributor Author

steeve commented Mar 11, 2025

So I cherry-picked both commits on top of current XLA main and, sadly, the AITER kernel still runs in 5ms :(

@steeve
Contributor Author

steeve commented Mar 11, 2025

More info, from rocprof:

[image: rocprof trace]

@steeve
Contributor Author

steeve commented Mar 11, 2025

And another:

2025-03-11 20:55:21.565194: I xla/stream_executor/rocm/rocm_stream.cc:336] launching kernel: _paged_attn_decode_v1_w_dot_kernel_tt_load_only__1; gdx: 256 gdy: 8 gdz: 1 bdx: 256 bdy: 1 bdz: 1 smem: 8192 func: 0x788e948b5f70
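
A quick sanity check on that launch line (a sketch; num_warps comes from the custom call's backend_config in the HLO dump above): the block dimension is consistent with 64-wide wavefronts.

num_warps = 4        # from the custom call's backend_config
wavefront_size = 64  # MI300X (gfx9) wave size
print(num_warps * wavefront_size)  # 256, matching bdx in the launch log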

@steeve
Contributor Author

steeve commented Mar 11, 2025

More info from rocprof-sys-run on hipModuleLaunchKernel:

[image: rocprof-sys-run trace of hipModuleLaunchKernel]

@steeve
Contributor Author

steeve commented Mar 11, 2025

Oh wow, I think I found why it was so fast on Torch: generating the block table with the following method produced completely wrong strides, (0, 1) instead of (128, 1):

def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor:
    dimensions = []

    for index, _ in enumerate(shape):
        if index != dim:
            dimension = 1
        else:
            dimension = shape[index]

        dimensions = [*dimensions, dimension]

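    # note: .expand() broadcasts without copying, so the returned tensor has
    # stride 0 along the expanded dimensions instead of a contiguous layout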
    return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)

And now XLA runs at the same speed as PyTorch! But... it runs in 3ms (where ~700us should be a rough target). So the kernel might be slower than what we want.
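
For reference, a variant that materializes a contiguous table (so a (256, 128) block table gets strides (128, 1)) is a one-line change; a minimal sketch:

import torch

def iota_contiguous(shape: tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor:
    dimensions = [shape[i] if i == dim else 1 for i in range(len(shape))]
    # .contiguous() copies the broadcast view, giving ordinary row-major strides
    # instead of 0 along the expanded dimensions.
    return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape).contiguous()

print(iota_contiguous((256, 128), dim=1).stride())  # (128, 1)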

@i-chaochen
Contributor

i-chaochen commented Mar 11, 2025

Oh wow, I think I found why it was so fast on Torch: generating the block table with the following method produced completely wrong strides, (0, 1) instead of (128, 1):

def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor:
    dimensions = []

    for index, _ in enumerate(shape):
        if index != dim:
            dimension = 1
        else:
            dimension = shape[index]

        dimensions = [*dimensions, dimension]

    return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)

And now XLA runs at the same speed as PyTorch! But... it runs in 3ms (where ~700us should be a rough target). So the kernel might be slower than what we want.

Thanks for the update @steeve

IIUC, you mean your iota, which uses torch to generate the block table, produced the wrong strides?

And after you used the (128, 1) strides, XLA/Triton has the same speed as PyTorch/Triton... which is 3ms... so the bad performance is also happening in PyTorch/Triton?

@steeve
Contributor Author

steeve commented Mar 12, 2025

It is.
Note that the above JAX snippet (with the aiter/pa_decode.py kernel) runs in 980us on H200 with num_stages=2 and num_warps=2.

@steeve
Contributor Author

steeve commented Mar 12, 2025

Okay, big update, we found the reason: mfma instructions were not being emitted for tt.dot.

Applying this patch brings us from 5ms to 625us on MI300X:

From 5ae7c57db92cedd196417dbbda7e756f127bf1ae Mon Sep 17 00:00:00 2001
From: Steeve Morin <[email protected]>
Date: Wed, 12 Mar 2025 15:05:58 +0000
Subject: [PATCH] [ROCm] Enable mfma instructions by passing the correct arch
 name

Without this commit, mfma instructions would not be emitted
by this pass.
---
 xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
index f6bd98aedae96..e45c854ba962d 100644
--- a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
+++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
@@ -89,7 +89,7 @@ absl::Status CreateTritonPipeline(
   pm->addPass(mt::gpu::createTritonGPUCoalesce());
   pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
   pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
-  pm->addPass(mlir::createTritonAMDGPUAccelerateMatmulPass());
+  pm->addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(cc.gfx_version()));
   pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
   // TODO ROCm Check if we want to compare MI100 and greater
   pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());

Knowing that XLA uses Triton for fusions and that this code path is shared, I think this might improve Triton performance on AMD quite a bit in general.

Oh, and Torch was fine, I must have been too tired.

@steeve
Contributor Author

steeve commented Mar 12, 2025

Aaaand, PR: #23654

copybara-service bot pushed a commit that referenced this issue Mar 13, 2025
[ROCm] Enable mfma instructions by passing the correct arch name

Imported from GitHub PR #23654

This PR enables the Triton pipeline to emit `#triton_gpu.amd_mfma` annotations during the Triton to TritonGPU lowering.

This is done by the `TritonAMDGPUAccelerateMatmulPass`, which checks the GFX version to do that.

Correctly passing the `gfx_version` reduces our kernel runtime from **~5ms** to **~620us** on MI300X, matching the performance of the Python Triton Compiler used in Torch.

We expect this change to radiate quite a bit given that this pipeline is shared by the IR Fusion Emitter used widely across XLA if `tt.dot` ops are emitted.

Closes #23574
Copybara import of the project:

--
c256551 by Steeve Morin <[email protected]>:

[ROCm] Enable mfma instructions by passing the correct arch name

Without this commit, mfma instructions would not be emitted
by this pass.

Merging this change closes #23654

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23654 from zml:zml/rocm/mfma c256551
PiperOrigin-RevId: 736471444