Reenable test_all, fix _aliases.__all__ #286

Merged: 6 commits, Mar 31, 2025
18 changes: 13 additions & 5 deletions array_api_compat/common/_aliases.py
@@ -7,8 +7,14 @@
import inspect
from typing import NamedTuple, Optional, Sequence, Tuple, Union

-from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
from ._typing import Array, Device, DType, Namespace
+from ._helpers import (
+    array_namespace,
+    _check_device,
+    device as _get_device,
+    is_cupy_namespace as _is_cupy_namespace
+)


# These functions are modified from the NumPy versions.

@@ -298,7 +304,7 @@ def cumulative_sum(
        initial_shape = list(x.shape)
        initial_shape[axis] = 1
        res = xp.concatenate(
-            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
            axis=axis,
        )
    return res
@@ -328,7 +334,7 @@ def cumulative_prod(
        initial_shape = list(x.shape)
        initial_shape[axis] = 1
        res = xp.concatenate(
-            [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
            axis=axis,
        )
    return res
@@ -381,7 +387,7 @@ def _isscalar(a):
    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
        max = None

-    dev = device(x)
+    dev = _get_device(x)
    if out is None:
        out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
    out[()] = x
@@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
    out = xp.sign(x, **kwargs)
    # CuPy sign() does not propagate nans. See
    # https://github.com/data-apis/array-api-compat/issues/136
-    if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
+    if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
        out[xp.isnan(x)] = xp.nan
    return out[()]
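For context on the behaviour this branch restores: NumPy's sign propagates NaN, while CuPy historically returned a non-NaN value for NaN inputs, so the wrapper re-inserts NaN at the NaN positions. A minimal standalone sketch of the same masking pattern, using NumPy as a stand-in for the wrapped namespace (illustrative, not part of the diff):

```python
import numpy as np

def sign_propagating_nan(x, xp=np):
    # Compute sign, then force NaN wherever the input was NaN, mirroring
    # the workaround applied above for CuPy's real floating dtypes.
    out = xp.sign(x)
    out[xp.isnan(x)] = xp.nan
    return out

print(sign_propagating_nan(np.array([-2.0, 0.0, np.nan, 3.0])))
# [-1.  0. nan  1.]
```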

@@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
           'unstack', 'sign']
+
+_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']
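The underscore aliases (`_get_device`, `_is_cupy_namespace`) keep the re-exported helpers out of the module's public attribute set, and `_all_ignore` whitelists the remaining imports that are not part of the wrapper API, so the re-enabled `test_all` can reconcile `__all__` with what the module actually exposes. A minimal sketch of that kind of consistency check (illustrative only; the real test in `tests/test_all.py` differs in detail):

```python
import types

def check_public_names(mod):
    # Every public, non-module attribute should be accounted for either in
    # __all__ or in the module's explicit _all_ignore list.
    listed = set(getattr(mod, "__all__", []))
    ignored = set(getattr(mod, "_all_ignore", []))
    public = {
        name for name in dir(mod)
        if not name.startswith("_")
        and not isinstance(getattr(mod, name), types.ModuleType)
    }
    leaked = public - listed - ignored
    assert not leaked, f"names missing from __all__: {sorted(leaked)}"
```

Before this change, plain `device` and `is_cupy_namespace` imports showed up as public names of `_aliases` without being listed in `__all__`; the aliased imports avoid that failure mode.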
1 change: 1 addition & 0 deletions numpy-1-21-xfails.txt
@@ -192,6 +192,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars

# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
1 change: 1 addition & 0 deletions numpy-1-26-xfails.txt
@@ -46,6 +46,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars

# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
4 changes: 2 additions & 2 deletions tests/test_all.py
@@ -26,15 +26,15 @@
"SupportsBufferProtocol",
))

-@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277")
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
def test_all(library):
if library == "common":
import array_api_compat.common # noqa: F401
else:
import_(library, wrapper=True)

-    for mod_name in sys.modules:
+    # NB: iterate over a copy to avoid a "dictionary size changed" error
+    for mod_name in sys.modules.copy():
        if not mod_name.startswith('array_api_compat.' + library):
            continue

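The `RuntimeError: dictionary changed size during iteration` that the new comment refers to occurs because importing a submodule inside the loop registers new entries in `sys.modules`. A small self-contained illustration of why iterating over a snapshot is safe (the fake module name is purely for demonstration):

```python
import sys
import types

def register_fake_module():
    # Simulates what an import inside the loop does: add a new sys.modules entry.
    sys.modules.setdefault("fake_pkg.fake_mod", types.ModuleType("fake_pkg.fake_mod"))

# Unsafe: mutating the dict while iterating it raises RuntimeError.
# for name in sys.modules:
#     register_fake_module()

# Safe: iterate over a snapshot, so later insertions cannot invalidate the loop.
for name in sys.modules.copy():
    register_fake_module()

sys.modules.pop("fake_pkg.fake_mod", None)  # clean up the demo entry
```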
6 changes: 4 additions & 2 deletions tests/test_array_namespace.py
@@ -2,10 +2,8 @@
import sys
import warnings

-import jax
import numpy as np
import pytest
-import torch

import array_api_compat
from array_api_compat import array_namespace
@@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat):
    subprocess.run([sys.executable, "-c", code], check=True)

def test_jax_zero_gradient():
jax = import_("jax")
jx = jax.numpy.arange(4)
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
assert array_namespace(jax_zero) is array_namespace(jx)
@@ -89,11 +88,13 @@ def test_array_namespace_errors():
    pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))

def test_array_namespace_errors_torch():
torch = import_("torch")
y = torch.asarray([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))

def test_api_version_torch():
torch = import_("torch")
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2023.12") == torch_
@@ -118,6 +119,7 @@ def test_get_namespace():
    assert array_api_compat.get_namespace is array_namespace

def test_python_scalars():
torch = import_("torch")
a = torch.asarray([1, 2])
xp = import_("torch", wrapper=True)

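Moving the `jax` and `torch` imports out of module scope and into `import_(...)` calls means this file can still be collected, and its other tests still run, when one of those backends is not installed. A rough sketch of what such a helper might look like (hypothetical; the real helper lives in the test suite's utilities and may differ):

```python
import importlib

import pytest

def import_(library: str, wrapper: bool = False):
    # Skip the calling test if the backend is missing; otherwise return either
    # the backend itself or its array-api-compat wrapper namespace.
    pytest.importorskip(library)
    if wrapper:
        return importlib.import_module(f"array_api_compat.{library}")
    return importlib.import_module(library)
```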
8 changes: 6 additions & 2 deletions tests/test_dask.py
@@ -1,10 +1,14 @@
from contextlib import contextmanager

import array_api_strict
-import dask
import numpy as np
import pytest
-import dask.array as da

+try:
+    import dask
+    import dask.array as da
+except ImportError:
+    pytestmark = pytest.skip(allow_module_level=True, reason="dask not found")

from array_api_compat import array_namespace

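With the imports wrapped in `try/except ImportError`, a missing dask no longer breaks collection: `pytest.skip(..., allow_module_level=True)` raises immediately inside the `except` block and skips the whole module, so the assignment to `pytestmark` never actually completes and the name binding is mostly documentation. Where no aliasing of several imports is needed, `pytest.importorskip` achieves the same thing more compactly; a small sketch of that alternative:

```python
import pytest

# Skips this entire module at collection time when dask is absent,
# and returns the imported modules otherwise.
dask = pytest.importorskip("dask", reason="dask not found")
da = pytest.importorskip("dask.array", reason="dask not found")
```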
8 changes: 6 additions & 2 deletions tests/test_jax.py
@@ -1,10 +1,14 @@
-import jax
-import jax.numpy as jnp
from numpy.testing import assert_equal
import pytest

from array_api_compat import device, to_device

+try:
+    import jax
+    import jax.numpy as jnp
+except ImportError:
+    pytestmark = pytest.skip(allow_module_level=True, reason="jax not found")

HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"


6 changes: 5 additions & 1 deletion tests/test_torch.py
@@ -3,7 +3,11 @@
import itertools

import pytest
-import torch

+try:
+    import torch
+except ImportError:
+    pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found")

from array_api_compat import torch as xp

2 changes: 2 additions & 0 deletions tests/test_vendoring.py
@@ -16,11 +16,13 @@ def test_vendoring_cupy():


def test_vendoring_torch():
pytest.importorskip("torch")
from vendor_test import uses_torch

    uses_torch._test_torch()


def test_vendoring_dask():
pytest.importorskip("dask")
from vendor_test import uses_dask
uses_dask._test_dask()
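Unlike the module-level pattern used in `test_dask.py` and friends, `pytest.importorskip("torch")` here skips only the individual vendoring test, so the remaining tests in the file still run. It also returns the imported module, which allows the skip and the import to be a single statement; a short illustrative example (hypothetical test, not part of the diff):

```python
import pytest

def test_vendored_torch_version():
    # importorskip both guards against a missing backend and hands back the module.
    torch = pytest.importorskip("torch")
    assert hasattr(torch, "__version__")
```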
6 changes: 4 additions & 2 deletions torch-xfails.txt
@@ -144,10 +144,12 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]

+# https://github.com/pytorch/pytorch/issues/149815
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal]
-array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less]
-array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater]
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal]
