Use JAX's AbstractMesh in distribution lib #21027
Comments
Is anyone working on this issue?
Not yet, no.
Oh, I'm a newbie in Keras, but here's my understanding:
@vedantag17, yes, that is correct.
@hertschuh, can you review the PR? There are some failing checks.
The truth is that it is more complicated than this. At some point you need to create a concrete Mesh object to actually do the distribution. Also, I'm realizing we're not ready to work on this right now because:
In the JAX distribution lib, use AbstractMesh instead of Mesh, since it doesn't cause JIT cache misses when the devices change. It may also simplify the distribution API.