diff --git a/src/_array_api_conf.py b/src/_array_api_conf.py index e183082ec..539b55465 100644 --- a/src/_array_api_conf.py +++ b/src/_array_api_conf.py @@ -57,6 +57,7 @@ ('py:obj', "typing.Union[int, float, typing.Literal[inf, - inf]]"), ('py:class', 'enum.Enum'), ('py:class', 'ellipsis'), + ("py:class", "ArrayAPINamespace"), ] nitpick_ignore_regex = [ ('py:class', '.*array'), diff --git a/src/array_api_stubs/_draft/_namespace.py b/src/array_api_stubs/_draft/_namespace.py new file mode 100644 index 000000000..d68780b22 --- /dev/null +++ b/src/array_api_stubs/_draft/_namespace.py @@ -0,0 +1,11 @@ +__all__ = ["ArrayAPINamespace"] + +from typing import Protocol + +from .creation_functions import arange as ArangeCallable + + +class ArrayAPINamespace(Protocol): + """Protocol for the array API namespace itself.""" + + arange: ArangeCallable diff --git a/src/array_api_stubs/_draft/array_object.py b/src/array_api_stubs/_draft/array_object.py index cf6adcf3c..cba6d2f98 100644 --- a/src/array_api_stubs/_draft/array_object.py +++ b/src/array_api_stubs/_draft/array_object.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from ._types import ( array, dtype as Dtype, @@ -13,6 +15,9 @@ ellipsis, ) +if TYPE_CHECKING: + from ._namespace import ArrayAPINamespace + class _array: def __init__(self: array) -> None: @@ -195,7 +200,7 @@ def __and__(self: array, other: Union[int, bool, array], /) -> array: def __array_namespace__( self: array, /, *, api_version: Optional[str] = None - ) -> Any: + ) -> ArrayAPINamespace: """ Returns an object that has all the array API functions on it. diff --git a/src/array_api_stubs/_draft/creation_functions.py b/src/array_api_stubs/_draft/creation_functions.py index 42d6f9420..ee629142f 100644 --- a/src/array_api_stubs/_draft/creation_functions.py +++ b/src/array_api_stubs/_draft/creation_functions.py @@ -1,3 +1,4 @@ +from typing import Protocol from ._types import ( List, NestedSequence, @@ -11,15 +12,7 @@ ) -def arange( - start: Union[int, float], - /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, - *, - dtype: Optional[dtype] = None, - device: Optional[device] = None, -) -> array: +class arange(Protocol): """ Returns evenly spaced values within the half-open interval ``[start, stop)`` as a one-dimensional array. @@ -46,6 +39,18 @@ def arange( a one-dimensional array containing evenly spaced values. The length of the output array must be ``ceil((stop-start)/step)`` if ``stop - start`` and ``step`` have the same sign, and length ``0`` otherwise. """ + def __call__( + self, + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[dtype] = None, + device: Optional[device] = None, + ) -> array: + ... + def asarray( obj: Union[