Skip to content

Commit 7aaeda2

Browse files
authored
Merge pull request #30 from asmeurer/flags
array-api-strict flags
2 parents d9a4fe5 + f92b497 commit 7aaeda2

15 files changed

+744
-70
lines changed

Diff for: array_api_strict/__init__.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
1717
"""
1818

19-
__array_api_version__ = "2022.12"
19+
# Warning: __array_api_version__ could change globally with
20+
# set_array_api_strict_flags(). This should always be accessed as an
21+
# attribute, like xp.__array_api_version__, or using
22+
# array_api_strict.get_array_api_strict_flags()['api_version'].
23+
from ._flags import API_VERSION as __array_api_version__
2024

2125
__all__ = ["__array_api_version__"]
2226

@@ -244,7 +248,7 @@
244248

245249
__all__ += ["linalg"]
246250

247-
from .linalg import matmul, tensordot, matrix_transpose, vecdot
251+
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
248252

249253
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
250254

@@ -284,6 +288,17 @@
284288

285289
__all__ += ["all", "any"]
286290

291+
# Helper functions that are not part of the standard
292+
293+
from ._flags import (
294+
set_array_api_strict_flags,
295+
get_array_api_strict_flags,
296+
reset_array_api_strict_flags,
297+
ArrayAPIStrictFlags,
298+
)
299+
300+
__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags']
301+
287302
from . import _version
288303
__version__ = _version.get_versions()['version']
289304
del _version

Diff for: array_api_strict/_array_object.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import operator
1919
from enum import IntEnum
20-
import warnings
2120

2221
from ._creation_functions import asarray
2322
from ._dtypes import (
@@ -32,6 +31,7 @@
3231
_result_type,
3332
_dtype_categories,
3433
)
34+
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
3535

3636
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
3737
import types
@@ -427,13 +427,17 @@ def _validate_index(self, key):
427427
"the Array API)"
428428
)
429429
elif isinstance(i, Array):
430-
if i.dtype in _boolean_dtypes and len(_key) != 1:
431-
assert isinstance(key, tuple) # sanity check
432-
raise IndexError(
433-
f"Single-axes index {i} is a boolean array and "
434-
f"{len(key)=}, but masking is only specified in the "
435-
"Array API when the array is the sole index."
436-
)
430+
if i.dtype in _boolean_dtypes:
431+
if len(_key) != 1:
432+
assert isinstance(key, tuple) # sanity check
433+
raise IndexError(
434+
f"Single-axes index {i} is a boolean array and "
435+
f"{len(key)=}, but masking is only specified in the "
436+
"Array API when the array is the sole index."
437+
)
438+
if not get_array_api_strict_flags()['data_dependent_shapes']:
439+
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
440+
437441
elif i.dtype in _integer_dtypes and i.ndim != 0:
438442
raise IndexError(
439443
f"Single-axes index {i} is a non-zero-dimensional "
@@ -482,10 +486,21 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
482486
def __array_namespace__(
483487
self: Array, /, *, api_version: Optional[str] = None
484488
) -> types.ModuleType:
485-
if api_version is not None and api_version not in ["2021.12", "2022.12"]:
486-
raise ValueError(f"Unrecognized array API version: {api_version!r}")
487-
if api_version == "2021.12":
488-
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
489+
"""
490+
Return the array_api_strict namespace corresponding to api_version.
491+
492+
The default API version is '2022.12'. Note that '2021.12' is supported,
493+
but currently identical to '2022.12'.
494+
495+
For array_api_strict, calling this function with api_version will set
496+
the API version for the array_api_strict module globally. This can
497+
also be achieved with the
498+
{func}`array_api_strict.set_array_api_strict_flags` function. If you
499+
want to only set the version locally, use the
500+
{class}`array_api_strict.ArrayApiStrictFlags` context manager.
501+
502+
"""
503+
set_array_api_strict_flags(api_version=api_version)
489504
import array_api_strict
490505
return array_api_strict
491506

0 commit comments

Comments
 (0)