[ROCm] Very slow Triton performance vs PyTorch for the same kernel #23574
Comments
Hi @steeve, thanks for asking. May I ask which version of jax/xla you're using? We have disabled Triton on the ROCm side until the PR for the gemm fusion autotuner lands: #19754. But we have already backported it to 0.4.35, which you can use from https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.4.35. Regarding warp size, we are going to land this PR upstream as well; you can cherry-pick it for now: ROCm@89cd01b
Thank you for responding!
@steeve Hi Steeve, how are you running the Triton kernel (https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/pa_decode.py#L309) with XLA? Did you use JAX/jax-triton (https://github.com/jax-ml/jax-triton), or did you call the Triton kernel through PyTorch/XLA (something like this: https://pytorch.org/xla/release/r2.6/features/triton.html)?
@rahulbatra85, we dumped the Triton MLIR and leverage XLA's own Triton integration directly, wrapping the kernel like this:

```mlir
tt.func public @_paged_attn_decode_v1_w_dot_kernel_tt_load(%ptr1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr3: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr4: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr5: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %ptr6: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr7: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr8: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %ptr9: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr10: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr11: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr12: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr13: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr14: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr15: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr16: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %ptr0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
%0 = tt.load %ptr6 : !tt.ptr<f32>
%1 = tt.load %ptr7 : !tt.ptr<f32>
%2 = tt.load %ptr8 : !tt.ptr<f32>
%3 = tt.load %ptr9 : !tt.ptr<i32>
%4 = tt.load %ptr10 : !tt.ptr<i32>
%5 = tt.load %ptr11 : !tt.ptr<i32>
%6 = tt.load %ptr12 : !tt.ptr<i32>
%7 = tt.load %ptr13 : !tt.ptr<i32>
%8 = tt.load %ptr14 : !tt.ptr<i32>
%9 = tt.load %ptr15 : !tt.ptr<i32>
%10 = tt.load %ptr16 : !tt.ptr<i32>
tt.call @_paged_attn_decode_v1_w_dot_kernel(%ptr0, %ptr1, %ptr2, %ptr3, %ptr4, %ptr5, %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10) : (!tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<bf16>, !tt.ptr<i32>, !tt.ptr<i64>, f32, f32, f32, i32, i32, i32, i32, i32, i32, i32, i32) -> ()
tt.return
}
```

So it's a bit different from jax-triton, which has its own custom call target.
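For illustration, here is the same wrapping idea sketched in Triton-Python (the `_inner`/`_wrapper` names are hypothetical, nothing from the actual aiter kernel): scalars that XLA can only hand over as buffers are `tl.load`-ed in a thin wrapper, which then calls the real kernel with plain scalar arguments.

```python
# Hypothetical sketch of the wrapper pattern above, written in Triton-Python
# instead of Triton MLIR. Names are illustrative only.
import triton
import triton.language as tl

@triton.jit
def _inner(x_ptr, out_ptr, scale, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, tl.load(x_ptr + offs) * scale)

@triton.jit
def _wrapper(x_ptr, out_ptr, scale_ptr, BLOCK: tl.constexpr):
    # XLA passes every operand as a buffer, so the scalar arrives as a 0-d buffer
    # and is loaded here before calling the real kernel.
    scale = tl.load(scale_ptr)
    _inner(x_ptr, out_ptr, scale, BLOCK)
```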
Also @rahulbatra85, this kernel does run in ~5ms on our MI300X on XLA, but that's still a far cry from the ~650us ballpark on aiter/PyTorch. We're not sure why, as we're literally using the MLIR dumped from torch. Params:
Hope that helps!
We'll try that tomorrow first thing! Thanks @i-chaochen!
Wondering if you could post the dumped HLO for us to try, so we can rerun it with XLA and see what's missing?
Sure, here is the rocm/aiter one:
@steeve In one case it's using a mix of the XLA and Triton compilers, whereas in the other (aiter/PyTorch) it's using just the Triton compiler. By the way, it would help to know which framework/compiler you eventually want this to work with.
The latter (pure Triton compilation) is what you get with jax-triton: it compiles the Triton kernel purely with the Triton compiler. A minimal example is sketched below.
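To illustrate that path, here is a minimal sketch along the lines of the jax-triton README's toy add kernel; it is not the paged-attention kernel, and the shapes, grid, and block size are illustrative only.

```python
# Minimal jax-triton sketch: a toy elementwise-add kernel compiled purely by the
# Triton compiler and invoked from JAX via triton_call.
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def add(x, y):
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(triton.cdiv(x.size, 128),),
        n=x.size, BLOCK=128)

x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x)[:4])
```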
We're planning to use XLA's Triton compiler "integration", whose pipeline is at https://github.com/openxla/xla/blob/main/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc#L60. I'd wager it should be pretty similar, since it pulls its passes from the Triton compiler? At least hopefully not 10x slower, haha. Also, we're not using JAX, but our own ML framework: https://github.com/zml/zml
Thanks for trying this out!
Wow, that is one way to do it. Can you share more about the steps to reproduce it this way? This code path is still pretty new on the AMD side and under active development; the focus so far has been on getting it working well for GEMM fusions in XLA, but we do want to evolve it into a general-purpose Triton emitter.
Let us give you a JAX reproducer tomorrow (Tuesday); it should be fairly easy.
Ok, great, thanks! A repro using the AITER kernel you are working with would be especially interesting.
Hello, thanks for the support :) We wrote a reproduction gist here: https://gist.github.com/hugomano/1b3fbe00cb4a6e77f438627a4c829234. We ran it with the ROCm Docker image. Happy to move forward!
So I cherry-picked both commits on top of current XLA main and, sadly, the AITER kernel still runs in 5ms :(
And another:
Oh wow, I think I found why it was so fast on Torch: generating the block table with the following method had the strides completely wrong.

```python
import torch

def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor:
    # Build a view shape that is 1 everywhere except along `dim`.
    dimensions = []
    for index, _ in enumerate(shape):
        if index != dim:
            dimension = 1
        else:
            dimension = shape[index]
        dimensions = [*dimensions, dimension]
    # expand() broadcasts without copying, so the expanded dims end up with stride 0.
    return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)
```

And now XLA runs at the same speed as PyTorch! But... it runs in 3ms (where ~700us should be a rough target). So the kernel might still be slower than what we want.
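For context, a small sketch of the stride issue (hypothetical 8x4 block-table shape, purely illustrative): the broadcasted view produced by `expand()` has stride 0 along the expanded dimension, whereas a materialized copy has ordinary row-major strides.

```python
import torch

# Hypothetical block-table shape (num_seqs=8, max_blocks_per_seq=4), for illustration.
bt = torch.arange(4).view(1, 4).expand(8, 4)   # same result as iota((8, 4), dim=1)
print(bt.stride())                # (0, 1): zero stride along the broadcasted dim
print(bt.contiguous().stride())   # (4, 1): real row-major strides after copying
```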
Thanks for the update @steeve. IIUC, you mean that after you used the stride as
It is.
Okay, big update, we found the reason: mfma instructions were not being emitted for `tt.dot`. Applying these patches brings us from 5ms to 625us on MI300X:
Knowing that XLA uses Triton for fusions and this code path is shared, I think it might improve Triton performance a lot on AMD. Oh, and Torch was fine; I must have been too tired.
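For anyone wanting to double-check mfma emission on their side, a rough check (our assumption, not part of the patch) is to run with XLA's standard `--xla_dump_to` flag and grep the dumped LLVM IR for the mfma intrinsics:

```python
# Sketch: after running with XLA_FLAGS="--xla_dump_to=/tmp/xla_dump", the dump
# directory should contain LLVM IR modules (*.ll); mfma shows up there as
# llvm.amdgcn.mfma.* intrinsics. Path and flag usage are assumptions on our side.
from pathlib import Path

hits = [p.name for p in Path("/tmp/xla_dump").glob("*.ll")
        if "llvm.amdgcn.mfma" in p.read_text()]
print(f"{len(hits)} dumped modules contain mfma intrinsics:", hits)
```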
Aaaand, PR: #23654
[ROCm] Enable mfma instructions by passing the correct arch name

Imported from GitHub PR #23654

This PR enables the Triton pipeline to emit `#triton_gpu.amd_mfma` annotations during the Triton-to-TritonGPU lowering. This is done by the `TritonAMDGPUAccelerateMatmulPass`, which checks the GFX version to do that. Correctly passing the `gfx_version` reduces our kernel runtime from **~5ms** to **~620us** on MI300X, matching the performance of the Python Triton compiler used in Torch. We expect this change to radiate quite a bit, given that this pipeline is shared by the IR fusion emitter used widely across XLA when `tt.dot` ops are emitted.

Closes #23574

Copybara import of the project:

--
c256551 by Steeve Morin <[email protected]>:

[ROCm] Enable mfma instructions by passing the correct arch name

Without this commit, mfma instructions would not be emitted by this pass.

Merging this change closes #23654

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23654 from zml:zml/rocm/mfma c256551
PiperOrigin-RevId: 736471444
Hello folks,
When running a custom paged-attention kernel, we see about a ~10x performance hit on Triton/XLA compared to Triton/PyTorch for the same kernel. We are using the kernel at https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/pa_decode.py#L309
We're not sure what could cause this; the grids are the same. We saw that ThreadsPerWarp is hardcoded to 32 in XLA (it should be 64 on AMD), but we're not sure that explains it.
When trying the Pallas implementation at https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/paged_attention.py in JAX, we get ~750us on H200 but 20ms (!) on MI300X.
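For reference, numbers like these can be gathered with a generic JAX timing harness along these lines (a sketch only; `f` stands in for the jitted paged-attention call, not our exact benchmark):

```python
# Generic JAX timing sketch: block_until_ready keeps async dispatch from
# skewing the measurements; the first call warms up / compiles.
import time
import jax

def bench(f, *args, iters=100):
    jax.block_until_ready(f(*args))            # compile + warm up once
    t0 = time.perf_counter()
    for _ in range(iters):
        out = f(*args)
    jax.block_until_ready(out)
    return (time.perf_counter() - t0) / iters  # seconds per call
```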
We have kind of exhausted the obvious stuff and are starting to wonder whether this might be due to a problem with ROCm support in XLA.
Thank you!