|
17 | 17 |
|
18 | 18 | import operator
|
19 | 19 | from enum import IntEnum
|
20 |
| -import warnings |
21 | 20 |
|
22 | 21 | from ._creation_functions import asarray
|
23 | 22 | from ._dtypes import (
|
|
32 | 31 | _result_type,
|
33 | 32 | _dtype_categories,
|
34 | 33 | )
|
| 34 | +from ._flags import get_array_api_strict_flags, set_array_api_strict_flags |
35 | 35 |
|
36 | 36 | from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
|
37 | 37 | import types
|
@@ -427,13 +427,17 @@ def _validate_index(self, key):
|
427 | 427 | "the Array API)"
|
428 | 428 | )
|
429 | 429 | elif isinstance(i, Array):
|
430 |
| - if i.dtype in _boolean_dtypes and len(_key) != 1: |
431 |
| - assert isinstance(key, tuple) # sanity check |
432 |
| - raise IndexError( |
433 |
| - f"Single-axes index {i} is a boolean array and " |
434 |
| - f"{len(key)=}, but masking is only specified in the " |
435 |
| - "Array API when the array is the sole index." |
436 |
| - ) |
| 430 | + if i.dtype in _boolean_dtypes: |
| 431 | + if len(_key) != 1: |
| 432 | + assert isinstance(key, tuple) # sanity check |
| 433 | + raise IndexError( |
| 434 | + f"Single-axes index {i} is a boolean array and " |
| 435 | + f"{len(key)=}, but masking is only specified in the " |
| 436 | + "Array API when the array is the sole index." |
| 437 | + ) |
| 438 | + if not get_array_api_strict_flags()['data_dependent_shapes']: |
| 439 | + raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") |
| 440 | + |
437 | 441 | elif i.dtype in _integer_dtypes and i.ndim != 0:
|
438 | 442 | raise IndexError(
|
439 | 443 | f"Single-axes index {i} is a non-zero-dimensional "
|
@@ -482,10 +486,21 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
|
482 | 486 | def __array_namespace__(
|
483 | 487 | self: Array, /, *, api_version: Optional[str] = None
|
484 | 488 | ) -> types.ModuleType:
|
485 |
| - if api_version is not None and api_version not in ["2021.12", "2022.12"]: |
486 |
| - raise ValueError(f"Unrecognized array API version: {api_version!r}") |
487 |
| - if api_version == "2021.12": |
488 |
| - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") |
| 489 | + """ |
| 490 | + Return the array_api_strict namespace corresponding to api_version. |
| 491 | +
|
| 492 | + The default API version is '2022.12'. Note that '2021.12' is supported, |
| 493 | + but currently identical to '2022.12'. |
| 494 | +
|
| 495 | + For array_api_strict, calling this function with api_version will set |
| 496 | + the API version for the array_api_strict module globally. This can |
| 497 | + also be achieved with the |
| 498 | + {func}`array_api_strict.set_array_api_strict_flags` function. If you |
| 499 | + want to only set the version locally, use the |
| 500 | + {class}`array_api_strict.ArrayApiStrictFlags` context manager. |
| 501 | +
|
| 502 | + """ |
| 503 | + set_array_api_strict_flags(api_version=api_version) |
489 | 504 | import array_api_strict
|
490 | 505 | return array_api_strict
|
491 | 506 |
|
|
0 commit comments