
"Stricter" option? #233

Open
awf opened this issue Jan 9, 2025 · 5 comments

Comments

@awf

awf commented Jan 9, 2025

I would love to write code that is within the array API, but which still uses the GPU.

I get the "Avoid Restricting Behavior that is Outside the Scope of the Standard" clause in https://data-apis.org/array-api-compat/dev/special-considerations.html, but I would really prefer to replace xp.blah(..) with xp.back.blah(..) in order to mark places where I am going outside the standard.

I also understand that I could/should use array-api-strict in unit tests, and then use array-api-compat in the main code, but some tests are too slow to run on CPU, and for machine learning code, it's often quite hard to get good test coverage - it still seems valuable to get early indication that code I develop in PyTorch, say, has a good chance of porting to JAX.

For my use case, I would like a mode where torch/__init__.py replaces

from torch import * # noqa: F403

with

from torch import (  # noqa: F403
    abs, # Explicit list of known-compliant functions, excluding those defined in _aliases, _linalg etc
    acos,
    acosh,
    argmax,
...
    tile,
    trunc,
    uint8,
    zeros_like,
)

I understand that I might still inadvertently use getattr methods which are not in the API, but that is relatively easy to overcome (just avoid chaining and dot methods).

@ev-br
Member

ev-br commented Jan 10, 2025

IIUC the classic approach would be to test on torch CPU + array-api-strict in some CI. This will ensure that the array api compliant code runs on torch GPU and jax GPU. Could you explain how this is limiting in your use case?

@awf
Author

awf commented Jan 10, 2025

I'm mainly thinking of AI/ML use cases where conventional unit testing is hard - scaling down to a meaningful CPU-scale test can be time-consuming, perhaps as costly in time as just maintaining the port to another framework. Of course, one might argue that people who aren't willing to write proper unit tests will not care about cross-compatibility, but I don't think that's true. There is huge value in array-api, even for people who are writing "research code", which is undergoing high churn, and for which tests might lag the main codebase.

Secondly, I like to test the code I'm actually going to run - testing for accuracy on CPU when the code will run on GPU is fraught with difficulty. In such cases, we often have a GPU machine available for CI, so running a separate CPU test is an additional effort.

Thirdly, right now, I may want to use array_api_compat.size. It's the best solution to a problem which will probably not be solved for some time. Naturally, array-api-strict doesn't offer size (nor should it, but it means I can't write cross-framework code that tests against array-api-strict).

I might invert the question: assume that I have taken my current code, in framework X:

# foo, not portable
import torch

def foo(x):
  r = (x * x).sum(1).tanh()
  r = torch.max(r, -r)
  return (x.max() > 0) * r

And I have converted it to array-api, so that I can run it on multiple frameworks (and get the readability advantages that a single source of documentation confers):

# foo, portable
import array_api_compat

def foo_portable(x):
  xp = array_api_compat.array_namespace(x)
  r = (x * x)
  r = xp.sum(r, axis=1)
  r = xp.tanh(r)
  r = xp.maximum(r, -r)
  return xp.astype(xp.max(r) > 0, xp.float32) * r

That was a certain amount of effort, which I achieved with some side-by-side running using array-api-strict.

Now, while experimenting, I decide to replace tanh with exp2:

import array_api_compat

def foo_v2(x):
  xp = array_api_compat.array_namespace(x)
  r = (x * x)
  r = xp.sum(r, axis=1)
  r = xp.exp2(r)
  r = xp.maximum(r, -r)
  return xp.astype(xp.max(r) > 0, xp.float32) * r

In this proposed array-api-compat "stricter" mode, I will get an error telling me that xp.exp2 is not defined, which is helpful: I simply replace it with xp.backend.exp2 and proceed. I still have mostly-compatible code, which happens to work on PyTorch and JAX, and when I want to port it to "proper" array-api, I can easily see places where I have used backend.

Given the effort already expended in porting foo, the additional effort of adding .backend to some methods is tiny. It means I can continue to work until it's a good time to start polishing my pull request, at which point I can figure out how best to multi-framework exp2.

@ev-br
Member

ev-br commented Feb 21, 2025

Sorry for dropping the ball here!

All in all, the workflows you're discussing do sound exactly like a use case for array-api-strict. It indeed currently does two things: 1) check strict spec compliance (what you're after), and 2) delegate to numpy (which you say is a blocker).
The second part can indeed be generalized: the array_api_strict object can grow a notion of a backend and use, say, torch GPU instead of numpy. Would that work for your use case?

(Now that I think of this, I suspect that if you use torch.compile on code which uses array-api-strict, it will compile these internal np calls into triton kernels automatically :-)).

@awf
Author

awf commented Feb 24, 2025

All in all, the workflows you're discussing do sound exactly like a use case for array-api-strict. It indeed currently does two things: 1) check strict spec compliance (what you're after), and 2) delegate to numpy (which you say is a blocker).

I also want (3) a way to temporarily bypass strictness on a single call, while testing and porting. So, from the example above:

  r = xp.exp2(r)  # <--- Will (correctly) fail today on AAS

I would like to temporarily say

  r = xp.backend.exp2(r)  # <--- Passes the call through to whatever backend is running

The second part can indeed be generalized: array_api_strict object can grow a notion of a backend, and use, say, torch GPU instead of numpy. Would that work for your use case?

Yes, adding a "backend" option to array-api-strict would also work, assuming this means I could pass a torch.Tensor to array-api-strict code.

I guess it seems easier to me to adapt AAC (I can post a very small PR) than to quite radically change AAS.

(Now that I think of this, I suspect that if you use torch.compile on code which uses array-api-strict, it will compile these internal np calls into triton kernels automatically :-)).

Yes, and that's nice, but doesn't answer the JAX question, and doesn't yet work as well as the handwritten kernels already in torch.

@ev-br
Member

ev-br commented Feb 26, 2025

I also want (3) a way to temporarily bypass strictness on a single call, while testing and porting. So, from the example above: ...

At the risk of sounding like "your workflow is wrong" (which I don't mean at all!), I am genuinely puzzled by the need to temporarily bypass strictness on a single call, as you say. I mean, if I'm looking at a given function call, it's either in the standard or not, so I can just look it up in the docs. What would be the workflow where this is essential?

I guess it seems easier to me to adapt AAC (I can post a very small PR) than to quite radically change AAS.

Sure, let's take a look at either of those PRs!
If it's a small patch, we'll see whether it's something we want to carry in -compat itself or keep as a separate patch. Looking forward to seeing it!


One note about making -compat stricter: we can only do so much, because we do not modify the array object itself. So its methods, dunders and all, are whatever they are in the library, even if they deviate from the standard. In -strict, though, we have our own array object, and we explicitly prohibit things which are not in the spec.
