Add JAX Pallas kernel examples for Tritonbench #2328

Open
xuzhao9 opened this issue Jun 24, 2024 · 0 comments

Comments

@xuzhao9
Contributor

xuzhao9 commented Jun 24, 2024

Add the flash attention example from JAX (https://github.com/google/jax/blob/main/jax/experimental/mosaic/gpu/examples/flash_attention.py) to Tritonbench.
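A benchmarked flash attention kernel generally needs a plain reference implementation to check its numerical output against. The following is a minimal sketch under stated assumptions: it is not taken from the linked JAX example or from Tritonbench. The function name `reference_attention` and the `[batch, heads, seq_len, head_dim]` layout are hypothetical choices made for illustration. It computes unfused softmax attention with `jax.numpy`:

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Unfused softmax attention (hypothetical helper, for correctness checks only).

    Assumes q, k, v all have shape [batch, heads, seq_len, head_dim].
    """
    # Scale scores by 1/sqrt(head_dim), as in standard scaled dot-product attention.
    scale = q.shape[-1] ** -0.5
    # Score every query position against every key position: [b, h, q_len, k_len].
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # Normalize over the key axis so each query's weights sum to 1.
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values: back to [b, h, q_len, head_dim].
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)
```

A fused kernel's output could then be compared against this with `jnp.allclose`, using a tolerance suited to the kernel's dtype (for example, looser for `float16` or `bfloat16`).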

facebook-github-bot pushed a commit that referenced this issue Jun 25, 2024
Summary:
We are interested in adding a few JAX Pallas operator implementations to Tritonbench, starting with flash_attention.

To install JAX:

```
python install.py --userbenchmark triton --jax
```

Working on #2328

Pull Request resolved: #2331

Reviewed By: chuanhaozhuge

Differential Revision: D59012756

Pulled By: xuzhao9

fbshipit-source-id: 7052e544535f38bd7c8a1086e39cf80f7d518ae2