From 26845bd904ee66bb830463f46bb39f1cc5392275 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:12:31 +0100 Subject: [PATCH 1/6] Revert "TST: skip test_all" This reverts commit 5473d84d5c36b23e091b880279c863c32f41b828. --- tests/test_all.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index 598fab62..eeb67e4b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -26,7 +26,6 @@ "SupportsBufferProtocol", )) -@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": From 07a3cd41e1c5804b7c11d358400431e8a53a984a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:40:02 +0100 Subject: [PATCH 2/6] MAINT: run self-tests even if a library is missing --- tests/test_array_namespace.py | 6 ++++-- tests/test_dask.py | 8 ++++++-- tests/test_jax.py | 8 ++++++-- tests/test_torch.py | 6 +++++- tests/test_vendoring.py | 2 ++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 605c69a1..cdb80007 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,10 +2,8 @@ import sys import warnings -import jax import numpy as np import pytest -import torch import array_api_compat from array_api_compat import array_namespace @@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat): subprocess.run([sys.executable, "-c", code], check=True) def test_jax_zero_gradient(): + jax = import_("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) @@ -89,11 +88,13 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) def test_array_namespace_errors_torch(): + torch = import_("torch") y = torch.asarray([1, 2]) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) def test_api_version_torch(): + torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) assert array_namespace(x, api_version="2023.12") == torch_ @@ -118,6 +119,7 @@ def test_get_namespace(): assert array_api_compat.get_namespace is array_namespace def test_python_scalars(): + torch = import_("torch") a = torch.asarray([1, 2]) xp = import_("torch", wrapper=True) diff --git a/tests/test_dask.py b/tests/test_dask.py index be2b1e39..69c738f6 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,10 +1,14 @@ from contextlib import contextmanager import array_api_strict -import dask import numpy as np import pytest -import dask.array as da + +try: + import dask + import dask.array as da +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="dask not found") from array_api_compat import array_namespace diff --git a/tests/test_jax.py b/tests/test_jax.py index e33cec02..285958d4 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,10 +1,14 @@ -import jax -import jax.numpy as jnp from numpy.testing import assert_equal import pytest from array_api_compat import device, to_device +try: + import jax + import jax.numpy as jnp +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="jax not found") + HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" diff --git a/tests/test_torch.py b/tests/test_torch.py index 75b3a136..e8340f31 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -3,7 +3,11 @@ import itertools import pytest -import torch + +try: + import torch +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") from array_api_compat import torch as xp diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 70083b49..8b561551 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -16,11 +16,13 @@ def test_vendoring_cupy(): def test_vendoring_torch(): + pytest.importorskip("torch") from vendor_test import uses_torch uses_torch._test_torch() def test_vendoring_dask(): + pytest.importorskip("dask") from vendor_test import uses_dask uses_dask._test_dask() From 89466a6b43672b9a4a2dbdaea2896c24e4dcdd76 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 14:01:44 +0100 Subject: [PATCH 3/6] MAINT: common._aliases.__all__ --- array_api_compat/common/_aliases.py | 18 +++++++++++++----- tests/test_all.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d123b99..0d1ecfbc 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -7,8 +7,14 @@ import inspect from typing import NamedTuple, Optional, Sequence, Tuple, Union -from ._helpers import array_namespace, _check_device, device, is_cupy_namespace from ._typing import Array, Device, DType, Namespace +from ._helpers import ( + array_namespace, + _check_device, + device as _get_device, + is_cupy_namespace as _is_cupy_namespace +) + # These functions are modified from the NumPy versions. @@ -298,7 +304,7 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -328,7 +334,7 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -381,7 +387,7 @@ def _isscalar(a): if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None - dev = device(x) + dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) out[()] = x @@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack', 'sign'] + +_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/tests/test_all.py b/tests/test_all.py index eeb67e4b..4df4a361 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,7 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) - for mod_name in sys.modules: + for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue From 23841dfdb319fbb66a4065e0c138235c56e611f0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:28:03 +0100 Subject: [PATCH 4/6] TST: update the torch skiplist --- torch-xfails.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index 6e8f7dc6..f8333d90 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,10 +144,12 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] + +# https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] From 3b4ea593d43c3d522aa1e601a93781774606bbc3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:33:26 +0100 Subject: [PATCH 5/6] TST: update numpy<2 skiplists --- numpy-1-21-xfails.txt | 1 + numpy-1-26-xfails.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 7c7a0757..30cde668 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -192,6 +192,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 57259b6f..1ce28ef4 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -46,6 +46,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] From 71d90ead399c03f5fcbc15d205d7cedb6bc9825c Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 30 Mar 2025 09:19:56 +0100 Subject: [PATCH 6/6] Update test_all.py Co-authored-by: Evgeni Burovski --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index 4df4a361..271cd189 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,6 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) + # NB: iterate over a copy to avoid a "dictionary size changed" error for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue