Summary:
We are interested in adding a few JAX Pallas operator implementations to tritonbench, starting with flash_attention.
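As background, a Pallas operator is a kernel written against `jax.experimental.pallas` and launched through `pl.pallas_call`; on GPU it lowers to Triton, which is what makes these operators a natural fit for tritonbench. The sketch below is an illustrative elementwise kernel of my own, not the planned flash_attention operator or any tritonbench code:

```
# Minimal Pallas sketch (illustrative names, not part of tritonbench).
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each program instance reads its block from the input refs,
    # adds elementwise, and writes the result to the output ref.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    # On GPU this lowers to a Triton kernel; pass interpret=True to
    # pallas_call to run it on CPU for testing.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add(x, y))  # [1. 2. 3. 4. 5. 6. 7. 8.]
```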
To install JAX:
```
python install.py --userbenchmark triton --jax
```
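As an optional sanity check after installation (my own snippet, not part of the install script), the following confirms that the installed JAX build is importable and can see an accelerator:

```
import jax
print(jax.__version__)
print(jax.devices())  # should list GPU devices on a CUDA machine
```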
Working on #2328
Pull Request resolved: #2331
Reviewed By: chuanhaozhuge
Differential Revision: D59012756
Pulled By: xuzhao9
fbshipit-source-id: 7052e544535f38bd7c8a1086e39cf80f7d518ae2
Add the flash_attention example from JAX (https://github.com/google/jax/blob/main/jax/experimental/mosaic/gpu/examples/flash_attention.py) to tritonbench.
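For reference, flash_attention computes standard scaled dot-product attention; the sketch below is a plain `jax.numpy` baseline of that computation (assuming a `[batch, seq_len, num_heads, head_dim]` layout), not the Mosaic GPU kernel linked above, which tiles the work so the full attention matrix is never materialized:

```
import jax
import jax.numpy as jnp

def attention_reference(q, k, v):
    # q, k, v: [batch, seq_len, num_heads, head_dim] (assumed layout for this sketch)
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (2, 128, 4, 64))  # float32 by default
out = attention_reference(q, k, v)
print(out.shape)  # (2, 128, 4, 64)
```

A baseline like this is also the kind of eager implementation a benchmark would compare the fused kernel against.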