Skip to content

Commit 65fd5f4

Browse files
committed
BUG: Don't import helpers in namespaces
1 parent b6900df commit 65fd5f4

File tree

7 files changed

+19
-22
lines changed

7 files changed

+19
-22
lines changed

array_api_compat/common/_linalg.py

+2
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,5 @@ def trace(
174174
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
175175
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
176176
'trace']
177+
178+
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']

array_api_compat/cupy/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88

99
# See the comment in the numpy __init__.py
1010
__import__(__package__ + '.linalg')
11-
1211
__import__(__package__ + '.fft')
1312

14-
from ..common._helpers import * # noqa: F401,F403
15-
1613
__array_api_version__ = '2024.12'

array_api_compat/dask/array/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55

66
__array_api_version__ = '2024.12'
77

8+
# See the comment in the numpy __init__.py
89
__import__(__package__ + '.linalg')
910
__import__(__package__ + '.fft')

array_api_compat/numpy/__init__.py

-9
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,8 @@
1414
# It doesn't overwrite np.linalg from above. The import is generated
1515
# dynamically so that the library can be vendored.
1616
__import__(__package__ + '.linalg')
17-
1817
__import__(__package__ + '.fft')
1918

2019
from .linalg import matrix_transpose, vecdot # noqa: F401
2120

22-
from ..common._helpers import * # noqa: F403
23-
24-
try:
25-
# Used in asarray(). Not present in older versions.
26-
from numpy import _CopyMode # noqa: F401
27-
except ImportError:
28-
pass
29-
3021
__array_api_version__ = '2024.12'

array_api_compat/numpy/_aliases.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010

1111
import numpy as np
1212

13+
try:
14+
# Note: On NumPy 1.26, without this line `hasattr(np, '_CopyMode') returns False
15+
from numpy import _CopyMode
16+
except ImportError:
17+
_CopyMode = None
18+
1319
bool = np.bool_
1420

1521
# Basic renames
@@ -86,7 +92,7 @@ def asarray(
8692
*,
8793
dtype: Optional[DType] = None,
8894
device: Optional[Device] = None,
89-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
95+
copy: Optional[Union[bool, np._CopyMode]] = None,
9096
**kwargs,
9197
) -> Array:
9298
"""
@@ -98,13 +104,13 @@ def asarray(
98104
if device not in ["cpu", None]:
99105
raise ValueError(f"Unsupported device for NumPy: {device!r}")
100106

101-
if hasattr(np, '_CopyMode'):
107+
if _CopyMode is not None:
102108
if copy is None:
103-
copy = np._CopyMode.IF_NEEDED
109+
copy = _CopyMode.IF_NEEDED
104110
elif copy is False:
105-
copy = np._CopyMode.NEVER
111+
copy = _CopyMode.NEVER
106112
elif copy is True:
107-
copy = np._CopyMode.ALWAYS
113+
copy = _CopyMode.ALWAYS
108114
else:
109115
# Not present in older NumPys. In this case, we cannot really support
110116
# copy=False.

array_api_compat/torch/__init__.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@
99
or 'cpu' in n
1010
or 'backward' in n):
1111
continue
12-
exec(n + ' = torch.' + n)
12+
exec(f"{n} = torch.{n}")
13+
del n
1314

1415
# These imports may overwrite names from the import * above.
1516
from ._aliases import * # noqa: F403
1617

1718
# See the comment in the numpy __init__.py
1819
__import__(__package__ + '.linalg')
19-
2020
__import__(__package__ + '.fft')
2121

22-
from ..common._helpers import * # noqa: F403
23-
2422
__array_api_version__ = '2024.12'

tests/test_common.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
)
2020
from ._helpers import all_libraries, import_, wrapped_libraries, xfail
2121

22+
from array_api_compat.numpy._aliases import _CopyMode
23+
2224

2325
is_array_functions = {
2426
'numpy': 'is_numpy_array',
@@ -276,7 +278,7 @@ def test_asarray_copy(library):
276278
is_lib_func = globals()[is_array_functions[library]]
277279
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
278280

279-
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
281+
if library == 'numpy' and xp.__version__[0] < '2' and _CopyMode is None:
280282
supports_copy_false_other_ns = False
281283
supports_copy_false_same_ns = False
282284
elif library == 'cupy':

0 commit comments

Comments
 (0)