Clarify definitions of "default device" and "current device" #835

Open
asmeurer opened this issue Aug 15, 2024 · 9 comments
Labels: Narrative Content, topic: Device Handling
Comments

@asmeurer
Member

The inspection API mentions "default device" in several places, for instance

However, as far as I can tell, this term is never actually defined anywhere. I thought it might be defined at https://data-apis.org/array-api/latest/design_topics/device_support.html#device-support (or at https://data-apis.org/array-api/latest/purpose_and_scope.html#terms-and-definitions), but it doesn't seem to be. Are there APIs where the default device should be used? I thought this would be the case for creation functions, but the phrase "default device" never appears (e.g., at https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html#array_api.empty). Presumably this is what is meant by default device, but are there other instances where "default device" should be used?

Another related concept that appears in some places but is never really defined is "current device" (for example, dtypes() should use the "current device" when device=None https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.dtypes.html#array_api.info.dtypes). It would be helpful to clarify the difference between "current device" and "default device" and when one should be used and when the other should be used.

Actually the creation functions never really state clearly what should happen when device=None, except for asarray and the *_like functions, which say the device should be inferred from x.

My understanding is that for libraries with a context manager, the "current device" is the device set in the current context, whereas the "default device" is the device used when no context is set (the "current device" would be the same as the "default device" in this case). Is this correct? By this reasoning creation functions should actually use the "current device" when device=None, not the "default device". But is it also true that default_device() should always return the "default device" regardless of the current context? It seems like this would be less useful than "current device". Do we need a separate current_device() inspection API?
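To make the distinction in the paragraph above concrete, here is a minimal sketch of a toy library with a context manager. Everything in it is hypothetical (the `use_device` context manager, the string device names, the dict standing in for an array); it is not taken from the standard or from any real library, and only illustrates the reading where "default device" ignores the context while "current device" follows it.

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical library-wide default: the device used when no context is active.
_DEFAULT_DEVICE = "cpu"

# Holds the device selected by the innermost active context, or None.
_context_device = ContextVar("context_device", default=None)


def default_device():
    """Return the device used when no context has been set."""
    return _DEFAULT_DEVICE


def current_device():
    """Return the context's device, falling back to the default device."""
    selected = _context_device.get()
    return _DEFAULT_DEVICE if selected is None else selected


@contextmanager
def use_device(device):
    """Temporarily make `device` the current device."""
    token = _context_device.set(device)
    try:
        yield
    finally:
        _context_device.reset(token)


def empty(shape, device=None):
    """Toy creation function: device=None resolves to the current device."""
    resolved = current_device() if device is None else device
    return {"shape": shape, "device": resolved}
```

Under this reading, `empty((1,))` inside `with use_device("gpu:0"):` lands on `"gpu:0"`, while `default_device()` keeps returning `"cpu"`.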

Or is it actually the case that these terms were both meant to mean the same thing (i.e., both mean the current device used by default in the current context)? If that's the case, then the note at https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_dtypes.html#array_api.info.default_dtypes doesn't make much sense to me. And if so, we should unify this terminology to avoid confusion (probably the term "default device" should be preferred since that is the name of the inspection function).

@kgryte added the "Narrative Content" and "topic: Device Handling" labels Aug 15, 2024
@kgryte added this to the v2024 milestone Aug 15, 2024
@asmeurer
Member Author

Actually maybe current_device() isn't really necessary, since None means "current device" in all APIs that accept a device keyword.

@rgommers
Member

> Are there APIs where the default device should be used? I thought this would be the case for creation functions, but the phrase "default device" never appears

Yes, creation functions indeed. The empty docs you linked are a good example, the description for the device keyword now says: "device on which to place the created array. Default: None." The "what does None mean" part is implicit here. I agree it'd be better to spell out what should happen.

> My understanding is that for libraries with a context manager, the "current device" is the device set in the current context, whereas the "default device" is the device used when no context is set (the "current device" would be the same as the "default device" in this case). Is this correct?

Correct indeed. May also be global state rather than a context manager (e.g., PyTorch has both - the global one is torch.set_default_device).

> By this reasoning creation functions should actually use the "current device" when device=None, not the "default device". But is it also true that default_device() should always return the "default device" regardless of the current context? It seems like this would be less useful than "current device". Do we need a separate current_device() inspection API?

Yes agreed, creation functions will use the current device. Both may be useful to inspect indeed. I believe we didn't consider this in detail because the standard has no way to set the current device, however I think that shouldn't stop us from adding an inspection function.

> Actually maybe current_device() isn't really necessary, since None means "current device" in all APIs that accept a device keyword.

It's not necessary indeed, it'd just be nicer syntax for empty((1,)).device.
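As a rough sketch of that equivalence, a `current_device` helper could simply wrap the `empty((1,)).device` expression. The helper name and the `toy_xp` stand-in namespace below are hypothetical; with a real array API namespace you would pass that namespace as `xp` instead.

```python
from types import SimpleNamespace


def current_device(xp):
    """Return the device a creation function picks when device=None."""
    return xp.empty((1,)).device


def _toy_empty(shape, device=None):
    # Hypothetical creation function: places arrays on "cpu" unless told otherwise.
    return SimpleNamespace(shape=shape, device="cpu" if device is None else device)


# Hypothetical stand-in for an array namespace, only so the sketch runs on its own.
toy_xp = SimpleNamespace(empty=_toy_empty)
```

The cost of this approach is a throwaway allocation on every call, which is one argument for a dedicated inspection function.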

@asmeurer
Member Author

It's also worth noting that the "default real floating-point dtype" (another notion from the same APIs) can change at runtime in PyTorch using the set_default_dtype function. At data-apis/array-api-compat#166 I've made default_dtypes()['real floating'] return torch.get_default_dtype(), i.e., get the value dynamically. That seems the most obviously useful, but it also disagrees with the notion of "default" for devices.

@kgryte modified the milestones: v2024, v2025 Jan 23, 2025
@crusaderky
Contributor

crusaderky commented Mar 31, 2025

While the value of knowing the current device used when one passes device=None to creation functions is obvious, to me the concept of "the device when the library was initialised, before any runtime changes, environment variables, or context managers were ever applied" feels useless and confusing. I can't think of any use case where it is useful to know it?

FWIW, I just fell into this misunderstanding while writing scipy/scipy#22756.

@crusaderky
Contributor

crusaderky commented Mar 31, 2025

JAX has interpreted default_device() as "whatever the device= kwarg will accept", and returns None.
Which notably would cause this to fail:

d = xp.__array_namespace_info__().default_device()
a = xp.asarray(0, device=d)
assert a.device == d

This will also fail:

assert d in xp.__array_namespace_info__().devices()

I think it's reasonable to say that a backend offers no guarantee of using the same device when one isn't explicitly pinned, but such a quirk should be spelled out by the Standard.
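The two expectations above can be bundled into a small check. This is only a sketch: `check_default_device` and `make_toy_namespace` are hypothetical names, and the toy namespace is a stand-in that mimics a backend whose `default_device()` returns either a concrete device or `None`; it is not how any real library is implemented.

```python
from types import SimpleNamespace


def check_default_device(xp):
    """Report whether the round-trip and membership expectations hold."""
    info = xp.__array_namespace_info__()
    d = info.default_device()
    a = xp.asarray(0, device=d)
    return {
        "round_trips": a.device == d,
        "listed_in_devices": d in info.devices(),
    }


def make_toy_namespace(default):
    """Build a hypothetical namespace whose default_device() returns `default`."""

    def asarray(obj, device=None):
        # device=None is resolved to a concrete device at creation time.
        return SimpleNamespace(value=obj, device="cpu" if device is None else device)

    info = SimpleNamespace(default_device=lambda: default, devices=lambda: ["cpu"])
    return SimpleNamespace(asarray=asarray, __array_namespace_info__=lambda: info)
```

With `make_toy_namespace("cpu")` both checks pass; with `make_toy_namespace(None)` both fail, which mirrors the situation described for a backend that returns `None`.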

@jakevdp

jakevdp commented Mar 31, 2025

> JAX has interpreted default_device() as "whatever the device= kwarg will accept", and returns None.

We were doing our best to interpret https://data-apis.org/array-api/2023.12/design_topics/device_support.html in a way that was compatible with JAX's pre-existing array creation and device placement semantics, where the default device=None means "uncommitted to any particular device and the compiler is free to choose appropriate device placement for the computation". Given this context, returning None from default_device seemed more accurate than returning some device object that may not actually be used in the default case if the compiler chose differently.

Please let me know if there's something in the specification that we've misinterpreted.

@crusaderky
Contributor

crusaderky commented Apr 1, 2025

@jakevdp, when you say

> uncommitted to any particular device and the compiler is free to choose appropriate device placement for the computation

I'm reading the JAX documentation and it gives me the impression that when you call a creation function (empty, asarray, ones, etc.) the array will be created on whatever device the user set with jax.default_device, but it is uncommitted, which means that if it is put in a binop with another array on a different device, it is free to be quietly transferred over. Did I interpret it correctly? Or are there cases where jnp.empty(..., device=None) will not honour jax.default_device, e.g. in case of high memory pressure? If yes, does it mean it can appear on another device on the same backend, or on any backend?

@jakevdp

jakevdp commented Apr 1, 2025

> I'm reading the JAX documentation and it gives me the impression that when you call a creation function (empty, asarray, ones, etc.) the array will be created on whatever the user set with jax.default_device

This is true to an extent: when executing eagerly (i.e. outside JIT) the buffer behind the array has to be allocated somewhere, and if device is not specified it will be allocated on the default device. It is still "uncommitted", however, in the sense that the compiler is free to move that buffer later if it is advantageous to do so. This is different than if you had explicitly specified device=default_device.
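The difference between "allocated on the default device" and "committed to it" can be illustrated with a toy model. This is a sketch in plain Python, not JAX code and not JAX's implementation: `ToyArray`, `create`, `add`, and the device strings are all hypothetical, and the rule that two operands committed to different devices raise an error is an assumption of the model.

```python
from dataclasses import dataclass

DEFAULT_DEVICE = "cpu:0"  # hypothetical default device


@dataclass(frozen=True)
class ToyArray:
    value: float
    device: str
    committed: bool


def create(value, device=None):
    """device=None allocates on the default device but leaves the array uncommitted."""
    if device is None:
        return ToyArray(value, DEFAULT_DEVICE, committed=False)
    return ToyArray(value, device, committed=True)


def add(x, y):
    """Add two toy arrays, letting uncommitted operands follow committed ones."""
    committed = {a.device for a in (x, y) if a.committed}
    if len(committed) > 1:
        raise ValueError("operands are committed to different devices")
    if committed:
        (device,) = committed
        return ToyArray(x.value + y.value, device, committed=True)
    # Neither operand is pinned, so the result stays on the default device.
    return ToyArray(x.value + y.value, DEFAULT_DEVICE, committed=False)
```

In this model, `create(1.0)` and `create(1.0, device="cpu:0")` report the same device, yet only the first one may be moved when combined with an array pinned to `"gpu:0"`.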

> Or are there cases where jnp.empty(..., device=None) will not honour jax.default_device, e.g. in case of high memory pressure?

Within JIT-compiled code, there is not necessarily any physical buffer associated with a python array variable, and so the device specification is meaningless and will be ignored. If a buffer is in fact needed, the compiler will allocate bytes on the device that makes sense in the context of the overall compiled computation. That could be the same backend as the default or different, depending on the nature of the compiled code, the parameters (e.g. sharding specifications) passed to the jax.jit call itself, and the device placement of the arrays passed to the compiled function. Is that what you have in mind?

@jakevdp

jakevdp commented Apr 1, 2025

I should add: this is one reason that prior to the array API standard, we never implemented a device argument for array creation functions at the jax.numpy level. In general it's the wrong API level for that kind of specification when it comes to JAX code. But we wanted to be compliant with the standard, so we added device logic that was (to the best of our ability) consistent with both (1) how JAX works and (2) how the Array API specification conceives things.

5 participants