diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 46cbb359..351b5bd6 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -18,7 +18,7 @@ # These functions are modified from the NumPy versions. -# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) def arange( start: Union[int, float], diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 790621e4..78e48a33 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. @@ -49,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, @@ -94,13 +95,13 @@ def capabilities(self): >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. See Also diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index fc70b5a2..614f43d9 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -50,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -103,10 +103,11 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { @@ -130,12 +131,12 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new Dask arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() 'cpu' @@ -173,7 +174,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -239,7 +240,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, @@ -335,7 +336,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by Dask. See Also @@ -347,7 +348,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index e706d118..365855b8 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -94,13 +94,13 @@ def capabilities(self): >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -119,7 +119,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -326,7 +326,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. See Also diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..818e5d37 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -34,7 +34,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, @@ -76,16 +76,16 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -102,15 +102,24 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new PyTorch arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() - 'cpu' + device(type='cpu') + Notes + ----- + This method returns the static default device when PyTorch is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed at runtime. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return torch.device("cpu") @@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None): Parameters ---------- - device : str, optional - The device to get the default data types for. For PyTorch, only - ``'cpu'`` is allowed. + device : Device, optional + The device to get the default data types for. + Unused for PyTorch, as all devices use the same default dtypes. Returns ------- @@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. @@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -322,7 +332,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"