
Why remove the rewriters related to cuDNN FMHA? #23389

Open
Jinghao14 opened this issue Mar 5, 2025 · 2 comments
Labels
question Further information is requested

Comments

@Jinghao14

I noticed PR #22225, which deleted the passes and rewriters related to cuDNN FMHA.
I am curious about the reason for this change. If I need to use flash attention with TensorFlow + XLA, how should I do it?

@bchetioui
Member

bchetioui commented Mar 7, 2025

Hi @Jinghao14. The main reasons why we decided to delete these passes and rewriters are as follows:

  1. such pattern matchers, which try to recover complicated structure from a soup of ops, tend to be fairly brittle and lead to annoying performance cliffs as the compiler makes even slightly different decisions. We thought it wiser to dispatch attention directly to a custom call at the framework level, so that XLA would not have to recover lost structure;
  2. these matchers were matching general formulations of attention in HLO and dispatching them to a "Flash" implementation under the hood. While that is beneficial in terms of performance, the rewrite from the "classical" formulation of attention to a "Flash" variant does not preserve numerics, so we would never have been able to turn this behaviour on by default.

JAX now has a dedicated scaled dot product attention ("SDPA") API, which can be used to dispatch to cuDNN's implementation. I am less familiar with TensorFlow's APIs, so I'm not sure what the canonical way of doing this is there, but note that the custom calls are still available in XLA's runtime, so having such custom calls in the HLO you feed to the compiler will still dispatch to cuDNN. See this test for an example.
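For reference, here is a minimal sketch of dispatching to the cuDNN kernel through JAX's `jax.nn.dot_product_attention` API. This assumes a sufficiently recent JAX release with CUDA support and a cuDNN-capable GPU; exact keyword arguments may differ across versions.

```python
# Sketch: requesting the cuDNN flash-attention kernel via JAX's SDPA API.
import jax
import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
rng = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(rng, 3)

# Inputs use the (batch, seq_len, num_heads, head_dim) layout; cuDNN expects
# half-precision inputs such as bfloat16.
q = jax.random.normal(kq, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)

@jax.jit
def sdpa(q, k, v):
    # implementation="cudnn" requests the cuDNN flash-attention custom call;
    # leaving it as None (the default) lets XLA pick a backend.
    return jax.nn.dot_product_attention(q, k, v, is_causal=True,
                                        implementation="cudnn")

out = sdpa(q, k, v)
print(out.shape)  # (2, 1024, 8, 64)
```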

Does that help?

@Jinghao14
Author

Hi @bchetioui, thank you for your response. However, I can only use TensorFlow, and I am looking for a way to use Flash Attention in TensorFlow. Would implementing a custom call to invoke the Flash Attention kernel be a good approach?

@aniruthraj added the question (Further information is requested) label on Mar 13, 2025