Skip to content

Commit 490a3e1

Browse files
authored
Merge pull request #126 from asmeurer/use_compat
Don't wrap NumPy 2.0 at all
2 parents dbea61e + 3b90b87 commit 490a3e1

7 files changed

+70
-50
lines changed

array_api_compat/common/_helpers.py

+40-14
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _check_api_version(api_version):
178178
elif api_version is not None and api_version != '2022.12':
179179
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
180180

181-
def array_namespace(*xs, api_version=None, _use_compat=True):
181+
def array_namespace(*xs, api_version=None, use_compat=None):
182182
"""
183183
Get the array API compatible namespace for the arrays `xs`.
184184
@@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
191191
The newest version of the spec that you need support for (currently
192192
the compat library wrapped APIs support v2022.12).
193193
194+
use_compat: bool or None
195+
If None (the default), the native namespace will be returned if it is
196+
already array API compatible, otherwise a compat wrapper is used. If
197+
True, the compat library wrapped library will be returned. If False,
198+
the native library namespace is returned.
199+
194200
Returns
195201
-------
196202
@@ -234,46 +240,66 @@ def your_function(x, y):
234240
is_jax_array
235241
236242
"""
243+
if use_compat not in [None, True, False]:
244+
raise ValueError("use_compat must be None, True, or False")
245+
246+
_use_compat = use_compat in [None, True]
247+
237248
namespaces = set()
238249
for x in xs:
239250
if is_numpy_array(x):
240-
_check_api_version(api_version)
241-
if _use_compat:
242-
from .. import numpy as numpy_namespace
251+
from .. import numpy as numpy_namespace
252+
import numpy as np
253+
if use_compat is True:
254+
_check_api_version(api_version)
243255
namespaces.add(numpy_namespace)
244-
else:
245-
import numpy as np
256+
elif use_compat is False:
246257
namespaces.add(np)
258+
else:
259+
# numpy 2.0 has __array_namespace__ and is fully array API
260+
# compatible.
261+
if hasattr(x, '__array_namespace__'):
262+
namespaces.add(x.__array_namespace__(api_version=api_version))
263+
else:
264+
namespaces.add(numpy_namespace)
247265
elif is_cupy_array(x):
248-
_check_api_version(api_version)
249266
if _use_compat:
267+
_check_api_version(api_version)
250268
from .. import cupy as cupy_namespace
251269
namespaces.add(cupy_namespace)
252270
else:
253271
import cupy as cp
254272
namespaces.add(cp)
255273
elif is_torch_array(x):
256-
_check_api_version(api_version)
257274
if _use_compat:
275+
_check_api_version(api_version)
258276
from .. import torch as torch_namespace
259277
namespaces.add(torch_namespace)
260278
else:
261279
import torch
262280
namespaces.add(torch)
263281
elif is_dask_array(x):
264-
_check_api_version(api_version)
265282
if _use_compat:
283+
_check_api_version(api_version)
266284
from ..dask import array as dask_namespace
267285
namespaces.add(dask_namespace)
268286
else:
269-
raise TypeError("_use_compat cannot be False if input array is a dask array!")
287+
import dask.array as da
288+
namespaces.add(da)
270289
elif is_jax_array(x):
271-
_check_api_version(api_version)
272-
# jax.experimental.array_api is already an array namespace. We do
273-
# not have a wrapper submodule for it.
274-
import jax.experimental.array_api as jnp
290+
if use_compat is True:
291+
_check_api_version(api_version)
292+
raise ValueError("JAX does not have an array-api-compat wrapper")
293+
elif use_compat is False:
294+
import jax.numpy as jnp
295+
else:
296+
# jax.experimental.array_api is already an array namespace. We do
297+
# not have a wrapper submodule for it.
298+
import jax.experimental.array_api as jnp
275299
namespaces.add(jnp)
276300
elif hasattr(x, '__array_namespace__'):
301+
if use_compat is True:
302+
raise ValueError("The given array does not have an array-api-compat wrapper")
277303
namespaces.add(x.__array_namespace__(api_version=api_version))
278304
else:
279305
# TODO: Support Python scalars?

array_api_compat/numpy/_aliases.py

-5
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,6 @@ def asarray(
9494
See the corresponding documentation in the array library and/or the array API
9595
specification for more details.
9696
"""
97-
if np.__version__[0] >= '2':
98-
# NumPy 2.0 asarray() is completely array API compatible. No need for
99-
# the complicated logic below
100-
return np.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
101-
10297
if device not in ["cpu", None]:
10398
raise ValueError(f"Unsupported device for NumPy: {device!r}")
10499

numpy-dev-xfails.txt

-11
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
# asarray(copy=False) is not yet implemented
2-
array_api_tests/test_creation_functions.py::test_asarray_arrays
3-
41
# finfo(float32).eps returns float32 but should return float
52
array_api_tests/test_data_type_functions.py::test_finfo[float32]
63

7-
# Array methods and attributes not already on np.ndarray cannot be wrapped
8-
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
9-
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
10-
11-
# linalg tests require cleanups
12-
# https://github.com/data-apis/array-api-tests/pull/101
13-
array_api_tests/test_linalg.py::test_solve
14-
154
# NumPy deviates in some special cases for floordiv
165
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
176
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

tests/_helpers.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import pytest
44

5-
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["jax.numpy"]
5+
wrapped_libraries = ["cupy", "torch", "dask.array"]
6+
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
7+
import numpy as np
8+
if np.__version__[0] == '1':
9+
wrapped_libraries.append("numpy")
710

811
def import_(library, wrapper=False):
912
if library == 'cupy':

tests/test_array_namespace.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,29 @@
99
import array_api_compat
1010
from array_api_compat import array_namespace
1111

12-
from ._helpers import import_, all_libraries
12+
from ._helpers import import_, all_libraries, wrapped_libraries
1313

14-
@pytest.mark.parametrize("library", all_libraries)
15-
@pytest.mark.parametrize("api_version", [None, "2021.12"])
16-
def test_array_namespace(library, api_version):
14+
@pytest.mark.parametrize("use_compat", [True, False, None])
15+
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
16+
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
17+
def test_array_namespace(library, api_version, use_compat):
1718
xp = import_(library)
1819

1920
array = xp.asarray([1.0, 2.0, 3.0])
20-
namespace = array_api_compat.array_namespace(array, api_version=api_version)
21+
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22+
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
23+
return
24+
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
2125

22-
if "array_api" in library:
23-
assert namespace == xp
26+
if use_compat is False or use_compat is None and library not in wrapped_libraries:
27+
if library == "jax.numpy" and use_compat is None:
28+
import jax.experimental.array_api
29+
assert namespace == jax.experimental.array_api
30+
else:
31+
assert namespace == xp
2432
else:
2533
if library == "dask.array":
2634
assert namespace == array_api_compat.dask.array
27-
elif library == "jax.numpy":
28-
import jax.experimental.array_api
29-
assert namespace == jax.experimental.array_api
3035
else:
3136
assert namespace == getattr(array_api_compat, library)
3237

@@ -64,14 +69,14 @@ def test_array_namespace_errors_torch():
6469
pytest.raises(TypeError, lambda: array_namespace(x, y))
6570

6671
def test_api_version():
67-
x = np.asarray([1, 2])
68-
np_ = import_("numpy", wrapper=True)
69-
assert array_namespace(x, api_version="2022.12") == np_
70-
assert array_namespace(x, api_version=None) == np_
71-
assert array_namespace(x) == np_
72+
x = torch.asarray([1, 2])
73+
torch_ = import_("torch", wrapper=True)
74+
assert array_namespace(x, api_version="2022.12") == torch_
75+
assert array_namespace(x, api_version=None) == torch_
76+
assert array_namespace(x) == torch_
7277
# Should issue a warning
7378
with warnings.catch_warnings(record=True) as w:
74-
assert array_namespace(x, api_version="2021.12") == np_
79+
assert array_namespace(x, api_version="2021.12") == torch_
7580
assert len(w) == 1
7681
assert "2021.12" in str(w[0].message)
7782

tests/test_no_dependencies.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import sys
1212
import subprocess
1313

14-
from ._helpers import import_
15-
1614
import pytest
1715

1816
class Array:
@@ -54,6 +52,9 @@ def _test_dependency(mod):
5452
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
5553
"jax.numpy", "array_api_strict"])
5654
def test_numpy_dependency(library):
55+
# This import is here because it imports numpy
56+
from ._helpers import import_
57+
5758
# This unfortunately won't go through any of the pytest machinery. We
5859
# reraise the exception as an AssertionError so that pytest will show it
5960
# in a semi-reasonable way

torch-xfails.txt

+1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ array_api_tests/test_statistical_functions.py::test_var
183183
# The test suite is incorrectly checking sums that have loss of significance
184184
# (https://github.com/data-apis/array-api-tests/issues/168)
185185
array_api_tests/test_statistical_functions.py::test_sum
186+
array_api_tests/test_statistical_functions.py::test_prod
186187

187188
# These functions do not yet support complex numbers
188189
array_api_tests/test_operators_and_elementwise_functions.py::test_sign

0 commit comments

Comments
 (0)