I noticed PR #22225, which deleted passes and rewriters related to cuDNN FMHA.
I am curious about the reason for doing this. If I need to use flash attention with TensorFlow + XLA, how should I do it?
Hi @Jinghao14. The main reasons why we decided to delete these passes and rewriters are as follows:
- Such pattern matchers, which try to recover complicated structure from a soup of ops, tend to be fairly brittle and lead to annoying performance cliffs as the compiler makes even slightly different decisions. We thought it wiser to dispatch attention directly to a custom call at the framework level, so that XLA would not have to recover lost structure.
- These matchers matched general formulations of attention in HLO and dispatched them to a "Flash" implementation under the hood. While that is beneficial in terms of performance, the rewrite from the "classical" formulation of attention into a "Flash" variant does not preserve numerics, so we would never have been able to turn this behaviour on by default.
JAX now has a dedicated scaled dot product attention ("SDPA") API, which can be used to dispatch to cuDNN's implementation. I am less familiar with TensorFlow's APIs, so I'm not sure what the canonical way of doing this is there, but note that the custom calls are still available in XLA's runtime, so having such custom calls in the HLO you feed to the compiler will still dispatch to cuDNN. See this test for an example.
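For reference, here is a minimal sketch of the JAX route using `jax.nn.dot_product_attention` with `implementation="cudnn"`; the shapes and dtypes are illustrative, and the exact keyword arguments may differ slightly across JAX versions, so treat this as a sketch rather than canonical usage.

```python
# Minimal sketch (not from this thread): dispatch attention to cuDNN's
# fused flash-attention kernel via JAX's SDPA API. Requires a compatible
# NVIDIA GPU and cuDNN; shapes/dtypes here are illustrative only.
import jax
import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)

# cuDNN's fused attention expects half-precision inputs (bf16/fp16).
q = jax.random.normal(q_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(v_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)

# implementation="cudnn" asks XLA to emit the cuDNN fused-attention custom
# call directly, instead of relying on pattern matching over plain HLO ops.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="cudnn")
print(out.shape)  # (2, 1024, 8, 64)
```

This is exactly the framework-level dispatch described above: the framework emits the custom call itself, so the compiler never has to rediscover the attention structure.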
Hi @bchetioui, thank you for your response. However, I can only use TensorFlow, and I am looking for a way to use Flash Attention in TensorFlow. Would implementing a custom call to invoke the Flash Attention kernel be a good approach?