Use JAX's AbstractMesh in distribution lib #21027

Open
hertschuh opened this issue Mar 13, 2025 · 6 comments

Comments

@hertschuh
Collaborator

In the JAX distribution lib, use AbstractMesh instead of Mesh, since it does not cause JIT cache misses when the devices change. It may also simplify the distribution API.

@vedantag17

Is anyone working on this issue?

@hertschuh
Collaborator Author

@vedantag17

Is anyone working on this issue?

Not yet, no.

@vedantag17

Oh, I'm new to Keras, but here's my understanding:
The current function builds a concrete device array. In contrast, AbstractMesh requires an immutable tuple of (axis_name, axis_size) pairs.
Is this correct?

@hertschuh
Collaborator Author

@vedantag17 , yes, that is correct.

@vedantag17

@hertschuh can you review the PR? There are some failing checks.

@hertschuh
Collaborator Author

@vedantag17

The truth is that it is more complicated than this. At some point you need to create a concrete Mesh object to actually do the distribution. The AbstractMesh part should only come into play in contexts where you use the Mesh.

Also, I'm realizing we're not ready to work on this right now because:

  • the JAX APIs related to AbstractMesh are still in flux, so we at least need to wait until the next JAX release.
  • we're not using shard_map in Keras right now, which is one of the main reasons for preferring AbstractMesh.
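
As a rough illustration of why a concrete Mesh is still needed to actually do the distribution, the sketch below places an array across devices with `NamedSharding`. It is not taken from the Keras distribution lib; the axis name `"batch"` and the array shape are assumptions made for the example, and `shard_map` is left out.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()

# Placing data requires real devices, hence a concrete Mesh.
mesh = Mesh(np.array(jax.devices()), ("batch",))

# Build an array whose row count is a multiple of the device count,
# then split its rows across the "batch" axis.
x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n * 2, 3)
sharded = jax.device_put(x, NamedSharding(mesh, P("batch")))

# Each device ends up holding one (2, 3) slice of the rows.
shard_shapes = [s.data.shape for s in sharded.addressable_shards]
print(shard_shapes)
```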


3 participants