diff --git a/.github/workflows/array-api-tests-paddle.yml b/.github/workflows/array-api-tests-paddle.yml new file mode 100644 index 00000000..d4f88b00 --- /dev/null +++ b/.github/workflows/array-api-tests-paddle.yml @@ -0,0 +1,11 @@ +name: Array API Tests (Paddle Latest) + +on: [push, pull_request] + +jobs: + array-api-tests-paddle: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: paddle + extra-env-vars: | + ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 diff --git a/README.md b/README.md index 4b0b0c9c..5c30919d 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want +NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want support for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). -See the documentation for more details https://data-apis.org/array-api-compat/ +See the documentation for more details diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b011f08d..ec6b3e0d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -120,6 +120,32 @@ def is_torch_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, torch.Tensor) +def is_paddle_array(x): + """ + Return True if `x` is a Paddle tensor. + + This function does not import Paddle if it has not already been imported + and is therefore cheap to use. + + See Also + -------- + + array_namespace + is_array_api_obj + is_numpy_array + is_cupy_array + is_dask_array + is_jax_array + is_pydata_sparse_array + """ + # Avoid importing paddle if it isn't already + if 'paddle' not in sys.modules: + return False + + import paddle + + return paddle.is_tensor(x) + def is_ndonnx_array(x): """ Return True if `x` is a ndonnx Array. @@ -252,6 +278,7 @@ def is_array_api_obj(x): or is_dask_array(x) \ or is_jax_array(x) \ or is_pydata_sparse_array(x) \ + or is_paddle_array(x) \ or hasattr(x, '__array_namespace__') def _compat_module_name(): @@ -319,6 +346,27 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} +def is_paddle_namespace(xp) -> bool: + """ + Returns True if `xp` is a Paddle namespace. + + This includes both Paddle itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'} + + def is_ndonnx_namespace(xp): """ Returns True if `xp` is an NDONNX namespace. @@ -543,6 +591,14 @@ def your_function(x, y): else: import jax.experimental.array_api as jnp namespaces.add(jnp) + elif is_paddle_array(x): + if _use_compat: + _check_api_version(api_version) + from .. 
import paddle as paddle_namespace + namespaces.add(paddle_namespace) + else: + import paddle + namespaces.add(paddle) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) @@ -660,6 +716,16 @@ def device(x: Array, /) -> Device: return "cpu" # Return the device of the constituent array return device(inner) + elif is_paddle_array(x): + raw_place_str = str(x.place) + if "gpu_pinned" in raw_place_str: + return "cpu" + elif "cpu" in raw_place_str: + return "cpu" + elif "gpu" in raw_place_str: + return "gpu" + raise ValueError(f"Unsupported Paddle device: {x.place}") + return x.device # Prevent shadowing, used below @@ -709,6 +775,14 @@ def _torch_to_device(x, device, /, stream=None): raise NotImplementedError return x.to(device) +def _paddle_to_device(x, device, /, stream=None): + if stream is not None: + raise NotImplementedError( + "paddle.Tensor.to() do not support stream argument yet" + ) + return x.to(device) + + def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -781,6 +855,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # In JAX v0.4.31 and older, this import adds to_device method to x. import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) + elif is_paddle_array(x): + return _paddle_to_device(x, device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing. @@ -819,6 +895,8 @@ def size(x): "is_torch_namespace", "is_ndonnx_array", "is_ndonnx_namespace", + "is_paddle_array", + "is_paddle_namespace", "is_pydata_sparse_array", "is_pydata_sparse_namespace", "size", diff --git a/array_api_compat/paddle/__init__.py b/array_api_compat/paddle/__init__.py new file mode 100644 index 00000000..1016312d --- /dev/null +++ b/array_api_compat/paddle/__init__.py @@ -0,0 +1,22 @@ +from paddle import * # noqa: F403 + +# Several names are not included in the above import * +import paddle + +for n in dir(paddle): + if n.startswith("_") or n.endswith("_") or "gpu" in n or "cpu" in n or "backward" in n: + continue + exec(f"{n} = paddle.{n}") + + +# These imports may overwrite names from the import * above. 
+from ._aliases import * # noqa: F403 + +# See the comment in the numpy __init__.py +__import__(__package__ + ".linalg") + +__import__(__package__ + ".fft") + +from ..common._helpers import * # noqa: F403 + +__array_api_version__ = "2023.12" diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py new file mode 100644 index 00000000..88f71e7d --- /dev/null +++ b/array_api_compat/paddle/_aliases.py @@ -0,0 +1,1361 @@ +from __future__ import annotations + +from typing import Literal +import numpy as np + +from functools import wraps as _wraps +from builtins import any as _builtin_any + +from ..common._aliases import ( + unstack as _aliases_unstack, +) +from ..common._typing import ( + SupportsBufferProtocol, + NestedSequence, +) +from .._internal import get_xp + +from ._info import __array_namespace_info__ + +import paddle + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import List, Optional, Sequence, Tuple, Union + from ..common._typing import Device + from paddle import dtype as Dtype + + array = paddle.Tensor + +_int_dtypes = { + paddle.uint8, + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, +} + +_array_api_dtypes = { + paddle.bool, + *_int_dtypes, + paddle.float32, + paddle.float64, + paddle.complex64, + paddle.complex128, +} + +# NOTE: Implicit promotion rules of Paddle is a bit strict than other frameworks, +# see details: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html +_promotion_table = { + # bool + (paddle.bool, paddle.bool): paddle.bool, + # ints + (paddle.int8, paddle.int8): paddle.int8, + (paddle.int16, paddle.int16): paddle.int16, + (paddle.int32, paddle.int32): paddle.int32, + (paddle.int64, paddle.int64): paddle.int64, + # uints + (paddle.uint8, paddle.uint8): paddle.uint8, + # floats + (paddle.float32, paddle.float32): paddle.float32, + (paddle.float32, paddle.float64): paddle.float64, + (paddle.float64, paddle.float32): paddle.float64, + (paddle.float64, paddle.float64): paddle.float64, + # complexes + (paddle.complex64, paddle.complex64): paddle.complex64, + (paddle.complex64, paddle.complex128): paddle.complex128, + (paddle.complex128, paddle.complex64): paddle.complex128, + (paddle.complex128, paddle.complex128): paddle.complex128, + # Mixed float and complex + (paddle.float32, paddle.complex64): paddle.complex64, + (paddle.float32, paddle.complex128): paddle.complex128, + (paddle.float64, paddle.complex64): paddle.complex128, + (paddle.float64, paddle.complex128): paddle.complex128, +} + + +def _two_arg(f): + @_wraps(f) + def _f(x1, x2, /, **kwargs): + x1, x2 = _fix_promotion(x1, x2) + return f(x1, x2, **kwargs) + + if _f.__doc__ is None: + _f.__doc__ = f"""\ +Array API compatibility wrapper for paddle.{f.__name__}. + +See the corresponding Paddle documentation and/or the array API specification +for more details. 
+ +""" + return _f + + +def _fix_promotion(x1, x2, only_scalar=True): + if not isinstance(x1, paddle.Tensor) or not isinstance(x2, paddle.Tensor): + return x1, x2 + if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes: + return x1, x2 + # If an argument is 0-D paddle downcasts the other argument + if not only_scalar or x1.shape == (): + dtype = result_type(x1, x2) + x2 = x2.to(dtype) + if not only_scalar or x2.shape == (): + dtype = result_type(x1, x2) + x1 = x1.to(dtype) + return x1, x2 + + +def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: + if len(arrays_and_dtypes) == 0: + raise TypeError("At least one array or dtype must be provided") + if len(arrays_and_dtypes) == 1: + x = arrays_and_dtypes[0] + return x if isinstance(x, paddle.dtype) else x.dtype + if len(arrays_and_dtypes) > 2: + return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) + + x, y = arrays_and_dtypes + xdt = x if isinstance(x, paddle.dtype) else x.dtype + ydt = y if isinstance(y, paddle.dtype) else y.dtype + + if (xdt, ydt) in _promotion_table: + return _promotion_table[(xdt, ydt)] + + type_order = { + paddle.bool: 0, + paddle.int8: 1, + paddle.uint8: 2, + paddle.int16: 3, + paddle.int32: 4, + paddle.int64: 5, + paddle.float16: 6, + paddle.float32: 7, + paddle.float64: 8, + paddle.complex64: 9, + paddle.complex128: 10 + } + + return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt + + +def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: + if paddle.is_tensor(from_): + from_ = from_.dtype + + assert isinstance(from_, paddle.dtype), from_.dtype + assert isinstance(to, paddle.dtype), to.dtype + + can_cast_dict = { + paddle.bfloat16: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.float16: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.float32: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.float64: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.complex64: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.complex128: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.uint8: { + 
paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.int8: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.int16: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.int32: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.int64: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + paddle.bool: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + }, + } + return can_cast_dict[from_][to] + + +# Basic renames +bitwise_invert = paddle.bitwise_not +newaxis = None +# paddle.conj sets the conjugation bit, which breaks conversion to other +# libraries. See https://github.com/data-apis/array-api-compat/issues/173 +conj = paddle.conj + +# Two-arg elementwise functions +# These require a wrapper to do the correct type promotion on 0-D tensors +add = _two_arg(paddle.add) +atan2 = _two_arg(paddle.atan2) +bitwise_and = _two_arg(paddle.bitwise_and) +bitwise_left_shift = _two_arg(paddle.bitwise_left_shift) +bitwise_or = _two_arg(paddle.bitwise_or) +bitwise_right_shift = _two_arg(paddle.bitwise_right_shift) +bitwise_xor = _two_arg(paddle.bitwise_xor) +copysign = _two_arg(paddle.copysign) +divide = _two_arg(paddle.divide) +# Also a rename. paddle.equal does not broadcast +equal = _two_arg(paddle.equal) +floor_divide = _two_arg(paddle.floor_divide) +greater = _two_arg(paddle.greater_than) +greater_equal = _two_arg(paddle.greater_equal) +hypot = _two_arg(paddle.hypot) +less = _two_arg(paddle.less) +less_equal = _two_arg(paddle.less_equal) +logaddexp = _two_arg(paddle.logaddexp) +# logical functions are not included here because they only accept bool in the +# spec, so type promotion is irrelevant. 
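+# Each _two_arg wrapper behaves like the underlying paddle function except that
+# both arguments are first passed through _fix_promotion, so that (per the note
+# at the top of this block) 0-D operands do not downcast the result dtype.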
+maximum = _two_arg(paddle.maximum) +minimum = _two_arg(paddle.minimum) +multiply = _two_arg(paddle.multiply) +not_equal = _two_arg(paddle.not_equal) +pow = _two_arg(paddle.pow) +remainder = _two_arg(paddle.remainder) +subtract = _two_arg(paddle.subtract) + + +def max( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> array: + if axis == (): + return paddle.clone(x) + return paddle.amax(x, axis, keepdim=keepdims) + + +def argmax( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> array: + return paddle.argmax(x, axis, keepdim=keepdims) + + +def min( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> array: + if axis == (): + return paddle.clone(x) + return paddle.min(x, axis, keepdim=keepdims) + + +def argmin( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> array: + return paddle.argmin(x, axis, keepdim=keepdims) + + +unstack = get_xp(paddle)(_aliases_unstack) + + +# paddle.sort also returns a tuple +def sort( + x: array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs, +) -> array: + return paddle.sort(x, axis=axis, descending=descending, stable=stable, **kwargs) + + +def _normalize_axes(axis, ndim): + axes = [] + if ndim == 0 and axis: + # Better error message in this case + raise IndexError(f"Dimension out of range: {axis[0]}") + lower, upper = -ndim, ndim - 1 + for a in axis: + if a < lower or a > upper: + # Match paddle error message (e.g., from sum()) + raise IndexError( + f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}" + ) + if a < 0: + a = a + ndim + if a in axes: + # Use IndexError instead of RuntimeError, and "axis" instead of "dim" + raise IndexError(f"Axis {a} appears multiple times in the list of axes") + axes.append(a) + return sorted(axes) + + +def _axis_none_keepdims(x, ndim, keepdims): + # Apply keepdims when axis=None + # Note that this is only valid for the axis=None case. + if keepdims: + for i in range(ndim): + x = paddle.unsqueeze(x, 0) + return x + + +def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): + # Some reductions don't support multiple axes + axes = _normalize_axes(axis, x.ndim) + for a in reversed(axes): + x = paddle.moveaxis(x, a, -1) + x = paddle.flatten(x, -len(axes)) + + out = f(x, -1, **kwargs) + + if keepdims: + for a in axes: + out = paddle.unsqueeze(out, a) + return out + + +def prod( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs, +) -> array: + if not paddle.is_tensor(x): + x = paddle.to_tensor(x) + ndim = x.ndim + + # below because it still needs to upcast. + if axis == (): + if dtype is None: + # We can't upcast uint8 according to the spec because there is no + # paddle.uint64, so at least upcast to int64 which is what sum does + # when axis=None. 
+ if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]: + return x.to(paddle.int64) + return x.clone() + return x.to(dtype) + + # paddle.prod doesn't support multiple axes + if isinstance(axis, tuple): + return _reduce_multiple_axes( + paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs + ) + if axis is None: + # paddle doesn't support keepdims with axis=None + res = paddle.prod(x, dtype=dtype, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res + + return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs) + + +def sum( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs, +) -> array: + if not paddle.is_tensor(x): + x = paddle.to_tensor(x) + ndim = x.ndim + + # Make sure it upcasts. + if axis == (): + if dtype is None: + # We can't upcast uint8 according to the spec because there is no + # paddle.uint64, so at least upcast to int64 which is what sum does + # when axis=None. + if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]: + return x.to(paddle.int64) + return x.clone() + return x.to(dtype) + + if axis is None: + # paddle doesn't support keepdims with axis=None + res = paddle.sum(x, dtype=dtype, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res + + return paddle.sum(x, axis, dtype=dtype, keepdim=keepdims, **kwargs) + + +def any( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs, +) -> array: + if not paddle.is_tensor(x): + x = paddle.to_tensor(x) + ndim = x.ndim + if axis == (): + return x.to(paddle.bool) + # paddle.any doesn't support multiple axes + if isinstance(axis, tuple): + res = _reduce_multiple_axes(paddle.any, x, axis, keepdim=keepdims, **kwargs) + return res.to(paddle.bool) + if axis is None: + # paddle doesn't support keepdims with axis=None + res = paddle.any(x, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res.to(paddle.bool) + + # paddle.any doesn't return bool for uint8 + return paddle.any(x, axis, keepdim=keepdims).to(paddle.bool) + + +def all( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs, +) -> array: + if not paddle.is_tensor(x): + x = paddle.to_tensor(x) + ndim = x.ndim + if axis == (): + return x.to(paddle.bool) + # paddle.all doesn't support multiple axes + if isinstance(axis, tuple): + res = _reduce_multiple_axes(paddle.all, x, axis, keepdim=keepdims, **kwargs) + return res.to(paddle.bool) + if axis is None: + # paddle doesn't support keepdims with axis=None + res = paddle.all(x, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res.to(paddle.bool) + + # paddle.all doesn't return bool for uint8 + return paddle.all(x, axis, keepdim=keepdims).to(paddle.bool) + + +def mean( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs, +) -> array: + if axis == (): + return paddle.clone(x) + if axis is None: + # paddle doesn't support keepdims with axis=None + res = paddle.mean(x, **kwargs) + res = _axis_none_keepdims(res, x.ndim, keepdims) + return res + return paddle.mean(x, axis, keepdim=keepdims, **kwargs) + + +def std( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, + **kwargs, +) -> array: + # Note, float correction is not supported + # implement it here for now. 
+
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError(
+                "float correction in paddle std() is not yet supported"
+            )
+    elif isinstance(correction, int):
+        if correction not in [0, 1]:
+            raise NotImplementedError("correction can only be 0 or 1")
+        _correction = correction
+    else:
+        raise NotImplementedError("correction must be an int or float equal to 0 or 1")
+
+    _correction = bool(_correction)
+
+    if axis == ():
+        return paddle.zeros_like(x)
+    if isinstance(axis, int):
+        axis = (axis,)
+    if axis is None:
+        # paddle doesn't support keepdims with axis=None
+        res = paddle.std(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
+        return res
+    return paddle.std(x, axis, unbiased=_correction, keepdim=keepdims, **kwargs)
+
+
+def var(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    correction: Union[int, float] = 0.0,
+    keepdims: bool = False,
+    **kwargs,
+) -> array:
+    # Note: float correction is not supported by paddle.var();
+    # implement the integer-valued cases here for now.
+
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError(
+                "float correction in paddle var() is not yet supported"
+            )
+    elif isinstance(correction, int):
+        if correction not in [0, 1]:
+            raise NotImplementedError("correction can only be 0 or 1")
+        _correction = correction
+    else:
+        raise NotImplementedError("correction must be an int or float equal to 0 or 1")
+
+    _correction = bool(_correction)
+
+    if axis == ():
+        return paddle.zeros_like(x)
+    if isinstance(axis, int):
+        axis = (axis,)
+    if axis is None:
+        # paddle doesn't support keepdims with axis=None
+        res = paddle.var(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
+        return res
+    return paddle.var(x, axis, unbiased=_correction, keepdim=keepdims, **kwargs)
+
+
+# paddle.concat doesn't support axis=None
+def concat(
+    arrays: Union[Tuple[array, ...], List[array]],
+    /,
+    *,
+    axis: Optional[int] = 0,
+    **kwargs,
+) -> array:
+    if axis is None:
+        arrays = tuple(ar.flatten() for ar in arrays)
+        axis = 0
+    return paddle.concat(arrays, axis, **kwargs)
+
+
+def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
+    if isinstance(axis, int):
+        axis = (axis,)
+    for a in axis:
+        if x.shape[a] != 1:
+            raise ValueError("squeezed dimensions must be equal to 1")
+    axes = _normalize_axes(axis, x.ndim)
+
+    sequence = [a - i for i, a in enumerate(axes)]
+    for a in sequence:
+        x = paddle.squeeze(x, a)
+    return x
+
+
+# paddle.broadcast_to already takes a shape argument; this thin wrapper just
+# keeps the signature consistent with the standard
+def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
+    return paddle.broadcast_to(x, shape, **kwargs)
+
+
+# The standard's permute_dims maps to paddle.transpose, whose second argument is perm
+def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
+    return paddle.transpose(x, axes)
+
+
+# flip() and roll() need to accept axis=None
+def flip(
+    x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
+) -> array:
+    if axis is None:
+        axis = tuple(range(x.ndim))
+    return x.flip(axis, **kwargs)
+
+
+def roll(
+    x: array,
+    /,
+    shift: Union[int, Tuple[int, ...]],
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    **kwargs,
+) -> array:
+    return paddle.roll(x, shift, axis, **kwargs)
+
+
+def 
nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: + if x.ndim == 0: + raise ValueError("nonzero() does not support zero-dimensional arrays") + return paddle.nonzero(x, as_tuple=True, **kwargs) + + +def where(condition: array, x1: array, x2: array, /) -> array: + x1, x2 = _fix_promotion(x1, x2) + return paddle.where(condition, x1, x2) + + +def empty_like( + x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> array: + out = paddle.empty_like(x, dtype=dtype) + if device is not None: + out = out.to(device) + return out + + +def zeros_like( + x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> array: + out = paddle.zeros_like(x, dtype=dtype) + if device is not None: + out = out.to(device) + return out + + +def ones_like( + x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> array: + out = paddle.ones_like(x, dtype=dtype) + if device is not None: + out = out.to(device) + return out + + +def full_like( + x: array, + /, + fill_value: bool | int | float | complex, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> array: + out = paddle.full_like(x, fill_value, dtype=dtype) + if device is not None: + out = out.to(device) + return out + + +# paddle.reshape doesn't have the copy keyword +def reshape( + x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs +) -> array: + return paddle.reshape(x, shape, **kwargs) + + +# paddle.arange doesn't support returning empty arrays +# keyword argument combinations +def arange( + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> array: + return paddle.arange(start, stop, step, dtype=dtype, **kwargs).to(device) + + +# paddle.eye does not accept None as a default for the second argument and +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> array: + if n_cols is None: + n_cols = n_rows + z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device) + if abs(k) <= n_rows + n_cols: + z.diagonal(k).fill_(1) + return z + + +# paddle.linspace doesn't have the endpoint parameter +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, + **kwargs, +) -> array: + if not endpoint: + return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[ + :-1 + ] + return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device) + + +# paddle.full does not accept an int size +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[bool, int, float, complex], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> array: + if isinstance(shape, int): + shape = (shape,) + + return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device) + + +# ones, zeros, and empty do not accept shape as a keyword argument +def ones( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> array: + return paddle.ones(shape, dtype=dtype, **kwargs).to(device) + + +def zeros( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> array: + 
return paddle.zeros(shape, dtype=dtype, **kwargs).to(device) + + +def empty( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs, +) -> array: + return paddle.empty(shape, dtype=dtype, **kwargs).to(device) + + +# tril and triu do not call the keyword argument k + + +def tril(x: array, /, *, k: int = 0) -> array: + return paddle.tril(x, k) + + +def triu(x: array, /, *, k: int = 0) -> array: + return paddle.triu(x, k) + + +def expand_dims(x: array, /, *, axis: int = 0) -> array: + return paddle.unsqueeze(x, axis) + + +def astype( + x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None +) -> array: + # if copy is not None: + # raise NotImplementedError("paddle.astype doesn't yet support the copy keyword") + t = x.to(dtype, device=device) + if copy: + t = t.detach().clone() + return t + + +def broadcast_arrays(*arrays: array) -> List[array]: + original_dtypes = [arr.dtype for arr in arrays] + if len(set(original_dtypes)) == 1: + return paddle.broadcast_tensors(arrays) + target_dtype = result_type(*arrays) + casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr + for arr in arrays] + broadcasted = paddle.broadcast_tensors(casted_arrays) + result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)] + return result + + +# Note that these named tuples aren't actually part of the standard namespace, +# but I don't see any issue with exporting the names here regardless. +from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult + + +def unique_all(x: array) -> UniqueAllResult: + return paddle.unique( + x, + return_index=True, + return_inverse=True, + return_counts=True, + ) + + +def unique_counts(x: array) -> UniqueCountsResult: + values, counts = paddle.unique(x, return_counts=True) + + # paddle.unique incorrectly gives a 0 count for nan values. + counts[paddle.isnan(values)] = 1 + return UniqueCountsResult(values, counts) + + +def unique_inverse(x: array) -> UniqueInverseResult: + values, inverse = paddle.unique(x, return_inverse=True) + return UniqueInverseResult(values, inverse) + + +def unique_values(x: array) -> array: + return paddle.unique(x) + + +def matmul(x1: array, x2: array, /, **kwargs) -> array: + # paddle.matmul doesn't type promote (but differently from _fix_promotion) + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return paddle.matmul(x1, x2, **kwargs) + + +def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]: + if indexing == "ij": + return paddle.meshgrid(*arrays) + else: + return [i.T for i in paddle.meshgrid(*arrays)] + + +matrix_transpose = paddle.linalg.matrix_transpose + + +def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return paddle.linalg.vecdot(x1, x2, axis=axis) + + +# paddle.tensordot uses dims instead of axes +def tensordot( + x1: array, + x2: array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> array: + # Note: paddle.tensordot fails with integer dtypes when there is only 1 + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return paddle.tensordot(x1, x2, axes=axes, **kwargs) + + +def isdtype( + dtype: Dtype, + kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + *, + _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. 
+ + Note that outside of this function, this compat library does not yet fully + support complex numbers. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + + def is_signed(dtype): + return dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64] + + def is_complex(dtype): + return dtype in [paddle.complex64, paddle.complex128] + + if isinstance(kind, tuple) and _tuple: + return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) + + elif isinstance(kind, str): + if kind == "bool": + return dtype == paddle.bool + elif kind == "signed integer": + return dtype in _int_dtypes and is_signed(dtype) + elif kind == "unsigned integer": + return dtype in _int_dtypes and not is_signed(dtype) + elif kind == "integral": + return dtype in _int_dtypes + elif kind == "real floating": + return dtype in [ + paddle.framework.core.VarDesc.VarType.FP32, + paddle.framework.core.VarDesc.VarType.FP64, + paddle.framework.core.VarDesc.VarType.FP16, + paddle.framework.core.VarDesc.VarType.BF16, + paddle.framework.core.DataType.FLOAT32, + paddle.framework.core.DataType.FLOAT64, + paddle.framework.core.DataType.FLOAT16, + paddle.framework.core.DataType.BFLOAT16, + ] + elif kind == "complex floating": + return is_complex(dtype) + elif kind == "numeric": + return isdtype(dtype, ("integral", "real floating", "complex floating")) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + return dtype == kind + + +def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: + if axis is None: + if x.ndim != 1: + raise ValueError("axis must be specified when ndim > 1") + axis = 0 + return paddle.index_select(x, axis, indices, **kwargs) + + +def sign(x: array, /) -> array: + # paddle sign() does not support complex numbers and does not propagate + # nans. See https://github.com/data-apis/array-api-compat/issues/136 + if paddle.is_complex(x): + out = x / paddle.abs(x) + # sign(0) = 0 but the above formula would give nan + out[x == 0 + 0j] = 0 + 0j + return out + else: + out = paddle.sign(x) + if paddle.is_floating_point(x): + out = paddle.where(paddle.isnan(x), paddle.full(x.shape, paddle.nan), out) + return out + + +def broadcast_shapes(*shapes: List[int]) -> List[int]: + out_shape = shapes[0] + for i, shape in enumerate(shapes): + if i == 0: + continue + out_shape = paddle.broadcast_shape(out_shape, shape) + + return out_shape + + +# asarray also adds the copy keyword, which is not present in numpy 1.0. +def asarray( + obj: Union[ + array, + bool, + int, + float, + NestedSequence[bool | int | float], + SupportsBufferProtocol, + ], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: Optional[bool] = None, + **kwargs, +) -> array: + """ + Array API compatibility wrapper for asarray(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. 
+ """ + if copy is False: + if hasattr(obj, "__dlpack__"): + obj = paddle.from_dlpack(obj.__dlpack__()) + if device is not None: + obj = obj.to(device) + if dtype is not None: + obj = obj.to(dtype) + return obj + else: + raise NotImplementedError( + "asarray(obj, ..., copy=False) is not supported " + "for obj do not has '__dlpack__()' method" + ) + elif copy is True: + obj = np.array(obj, copy=True) + if np.issubdtype(obj.dtype, np.floating) and dtype is None: + obj = obj.astype(paddle.get_default_dtype()) + return paddle.to_tensor(obj, dtype=dtype, place=device) + else: + if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype): + obj = np.array(obj, copy=False) + if np.issubdtype(obj.dtype, np.floating) and dtype is None: + obj = obj.astype(paddle.get_default_dtype()) + if dtype != paddle.bool and dtype != "bool": + obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype) + else: + obj = paddle.to_tensor(obj, dtype=dtype) + if device is not None: + obj = obj.to(device) + return obj + + return obj + + +def floor(x: array, /) -> array: + return paddle.floor(x).to(x.dtype) + + +def ceil(x: array, /) -> array: + return paddle.ceil(x).to(x.dtype) + + +def clip( + x: array, + /, + min: Optional[Union[int, float, array]] = None, + max: Optional[Union[int, float, array]] = None, +) -> array: + if min is None and max is None: + return x + + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + + min_shape = [] if _isscalar(min) else min.shape + max_shape = [] if _isscalar(max) else max.shape + + result_shape = broadcast_shapes(x.shape, min_shape, max_shape) + + # np.clip does type promotion but the array API clip requires that the + # output have the same dtype as x. We do this instead of just downcasting + # the result of xp.clip() to handle some corner cases better (e.g., + # avoiding uint64 -> float64 promotion). + + # Note: cases where min or max overflow (integer) or round (float) in the + # wrong direction when downcasting to x.dtype are unspecified. This code + # just does whatever NumPy does when it downcasts in the assignment, but + # other behavior could be preferred, especially for integers. For example, + # this code produces: + + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) + # -128 + + # but an answer of 0 might be preferred. See + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. + + # At least handle the case of Python integers correctly (see + # https://github.com/numpy/numpy/pull/26892). 
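+    # A Python int bound that already spans the full range of x's integer dtype
+    # cannot clip anything, so it is dropped here rather than converted to a
+    # tensor, where the out-of-range value could overflow.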
+ if type(min) is int and min <= paddle.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= paddle.iinfo(x.dtype).max: + max = None + + out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place) + if min is not None: + if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(min): + # Avoid loss of precision due to paddle defaulting to float32 + min = paddle.to_tensor(min, dtype=paddle.float64) + a = broadcast_to(paddle.to_tensor(min, place=x.place), result_shape) + ia = (out < a) | paddle.isnan(a) + # paddle requires an explicit cast here + out[ia] = astype(a[ia], out.dtype) + if max is not None: + if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(max): + max = paddle.to_tensor(max, dtype=paddle.float64) + b = broadcast_to(paddle.to_tensor(max, place=x.place), result_shape) + ib = (out > b) | paddle.isnan(b) + out[ib] = astype(b[ib], out.dtype) + # Return a scalar for 0-D + return out + + +def cumulative_sum( + x: array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> array: + if axis is None: + if x.ndim > 1: + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) + axis = 0 + + res = paddle.cumsum(x, axis=axis, dtype=dtype) + + # np.cumsum does not support include_initial + if include_initial: + initial_shape = list(x.shape) + initial_shape[axis] = 1 + res = paddle.concat( + [paddle.zeros(shape=initial_shape, dtype=res.dtype).to(res.place), res], + axis=axis, + ) + return res + + +def searchsorted( + x1: array, + x2: array, + /, + *, + side: Literal["left", "right"] = "left", + sorter: array | None = None, +) -> array: + if sorter is None: + return paddle.searchsorted(x1, x2, right=(side == "right")) + + return paddle.searchsorted( + x1.take_along_axis(axis=-1, indices=sorter), + x2, + right=(side == "right"), + ) + + +__all__ = [ + "__array_namespace_info__", + "result_type", + "can_cast", + "permute_dims", + "bitwise_invert", + "newaxis", + "conj", + "add", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "copysign", + "divide", + "equal", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "less", + "less_equal", + "logaddexp", + "maximum", + "minimum", + "multiply", + "not_equal", + "pow", + "remainder", + "subtract", + "max", + "min", + "clip", + "unstack", + "cumulative_sum", + "sort", + "prod", + "sum", + "any", + "all", + "mean", + "std", + "var", + "concat", + "squeeze", + "broadcast_to", + "flip", + "roll", + "nonzero", + "where", + "reshape", + "arange", + "eye", + "linspace", + "full", + "ones", + "zeros", + "empty", + "tril", + "triu", + "expand_dims", + "astype", + "broadcast_arrays", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "matmul", + "matrix_transpose", + "vecdot", + "tensordot", + "isdtype", + "take", + "sign", + "broadcast_shapes", + "argmax", + "argmin", + "searchsorted", + "empty_like", + "zeros_like", + "ones_like", + "full_like", + "asarray", + "ceil", + "floor", +] + +_all_ignore = ["paddle", "get_xp"] diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py new file mode 100644 index 00000000..6f079020 --- /dev/null +++ b/array_api_compat/paddle/_info.py @@ -0,0 +1,380 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. 
See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" + +import paddle + +from functools import cache + + +class __array_namespace_info__: + """ + Get the array API inspection namespace for Paddle. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for Paddle. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + + __module__ = "paddle" + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for Paddle. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + Paddle. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new Paddle arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new Paddle arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return paddle.device.get_device() + + def default_dtypes(self, *, device=None): + """ + The default data types used for new Paddle arrays. + + Parameters + ---------- + device : str, optional + The device to get the default data types for. For Paddle, only + ``'cpu'`` is allowed. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new Paddle + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': paddle.float32, + 'complex floating': paddle.complex64, + 'integral': paddle.int64, + 'indexing': paddle.int64} + + """ + # Note: if the default is set to float64, the devices like MPS that + # don't support float64 will error. We still return the default_dtype + # value here because this error doesn't represent a different default + # per-device. 
+ default_floating = paddle.get_default_dtype() + if default_floating in ["float16", "float32", "float64", "bfloat16"]: + default_floating = getattr(paddle, default_floating) + else: + raise ValueError(f"Unsupported default floating: {default_floating}") + default_complex = ( + paddle.complex64 + if default_floating == paddle.float32 + else paddle.complex128 + ) + default_integral = paddle.int64 + return { + "real floating": default_floating, + "complex floating": default_complex, + "integral": default_integral, + "indexing": default_integral, + } + + def _dtypes(self, kind): + bool = paddle.bool + int8 = paddle.int8 + int16 = paddle.int16 + int32 = paddle.int32 + int64 = paddle.int64 + uint8 = paddle.uint8 + # uint16, uint32, and uint64 are not fully supported in paddle, + # we omit them from this function. + float32 = paddle.float32 + float64 = paddle.float64 + complex64 = paddle.complex64 + complex128 = paddle.complex128 + + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + @cache + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by Paddle. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + Paddle data types. 
+ + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': numpy.int8, + 'int16': numpy.int16, + 'int32': numpy.int32, + 'int64': numpy.int64} + + """ + res = self._dtypes(kind) + for k, v in res.copy().items(): + try: + paddle.empty((0,), dtype=v, device=device) + except: + del res[k] + return res + + @cache + def devices(self): + """ + The devices supported by Paddle. + + Returns + ------- + devices : list of str + The devices supported by Paddle. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + [device(type='cpu'), device(type='mps', index=0), device(type='meta')] + + """ + # Paddle doesn't have a straightforward way to get the list of all + # currently supported devices. To do this, we first parse the error + # message of paddle.device to get the list of all possible types of + # device: + try: + paddle.set_device("notadevice") + except ValueError as e: + # The error message is something like: + # ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x + devices_names = ( + e.args[0] + .split("The device must be a string which is like ")[1] + .split(", ") + ) + devices_names = [ + name.strip("'") for name in devices_names if ":" not in name + ] + + # Next we need to check for different indices for different devices. + # device(device_name, index=index) doesn't actually check if the + # device name or index is valid. We have to try to create a tensor + # with it (which is why this function is cached). 
+ devices = [] + for device_name in devices_names: + i = 0 + while True: + try: + if device_name == "cpu": + a = paddle.empty((0,), place=paddle.CPUPlace()) + elif device_name == "gpu": + a = paddle.empty((0,), place=paddle.CUDAPlace(i)) + elif device_name == "xpu": + a = paddle.empty((0,), place=paddle.XPUPlace()) + else: + raise + if a.place in devices: + break + devices.append(a.device) + except: + break + i += 1 + + return devices diff --git a/array_api_compat/paddle/fft.py b/array_api_compat/paddle/fft.py new file mode 100644 index 00000000..1442aed8 --- /dev/null +++ b/array_api_compat/paddle/fft.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import paddle + from ..common._typing import Device + + array = paddle.Tensor + from typing import Optional, Union, Sequence, Literal + +from paddle.fft import * # noqa: F403 +import paddle.fft + + +def fftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs) + + +def ifftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs) + + +def rfftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs) + + +def irfftn( + x: array, + /, + *, + s: Sequence[int] = None, + axes: Sequence[int] = None, + norm: Literal["backward", "ortho", "forward"] = "backward", + **kwargs, +) -> array: + return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs) + + +def fftshift( + x: array, + /, + *, + axes: Union[int, Sequence[int]] = None, + **kwargs, +) -> array: + return paddle.fft.fftshift(x, axes=axes, **kwargs) + + +def ifftshift( + x: array, + /, + *, + axes: Union[int, Sequence[int]] = None, + **kwargs, +) -> array: + return paddle.fft.ifftshift(x, axes=axes, **kwargs) + + +def fftfreq( + n: int, + /, + *, + d: float = 1.0, + device: Optional[Device] = None, +) -> array: + out = paddle.fft.fftfreq(n, d) + if device is not None: + out = out.to(device) + return out + + +def rfftfreq( + n: int, + /, + *, + d: float = 1.0, + device: Optional[Device] = None, +) -> array: + out = paddle.fft.rfftfreq(n, d) + if device is not None: + out = out.to(device) + return out + + +__all__ = paddle.fft.__all__ + [ + "fftn", + "ifftn", + "rfftn", + "irfftn", + "fftshift", + "ifftshift", + "fftfreq", + "rfftfreq", +] + +_all_ignore = ["paddle"] diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py new file mode 100644 index 00000000..7dd1a266 --- /dev/null +++ b/array_api_compat/paddle/linalg.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import paddle + + array = paddle.Tensor + from paddle import dtype as Dtype + from typing import Optional, Union, Tuple, Literal + + inf = float("inf") + +from ._aliases import _fix_promotion, sum +from collections import namedtuple + +import paddle +from paddle.linalg import * # noqa: F403 + +# paddle.linalg doesn't define __all__ +# from paddle.linalg import __all__ as linalg_all +from paddle import linalg as paddle_linalg + +linalg_all = [i for i 
in dir(paddle_linalg) if not i.startswith("_")] + +# outer is implemented in paddle but aren't in the linalg namespace +from paddle import outer +import paddle + +# These functions are in both the main and linalg namespaces +from ._aliases import matmul, matrix_transpose, tensordot + +# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the +# first axis with size 3) + + +# paddle.cross also does not support broadcasting when it would add new +# dimensions +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): + raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") + + if not (x1.shape[axis] == x2.shape[axis] == 3): + raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") + + x1, x2 = paddle.broadcast_tensors([x1, x2]) + return paddle_linalg.cross(x1, x2, axis=axis) + + +def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: + from ._aliases import isdtype + + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + + # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension + if x1.shape[axis] != x2.shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + # paddle.linalg.vecdot doesn't support integer dtypes + if isdtype(x1.dtype, "integral") or isdtype(x2.dtype, "integral"): + if kwargs: + raise RuntimeError("vecdot kwargs not supported for integral dtypes") + + x1_ = paddle.moveaxis(x1, axis, -1) + x2_ = paddle.moveaxis(x2, axis, -1) + x1_, x2_ = paddle.broadcast_tensors([x1_, x2_]) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs) + + +def solve(x1: array, x2: array, /, **kwargs) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + + if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape: + x2 = x2[None] + return paddle.linalg.solve(x1, x2, **kwargs) + + +# paddle.trace doesn't support the offset argument and doesn't support stacking +def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: + # Use our wrapped sum to make sure it does upcasting correctly + return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype) + + +def vector_norm( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Union[int, float, Literal[inf, -inf]] = 2, + **kwargs, +) -> array: + # paddle.vector_norm incorrectly treats axis=() the same as axis=None + if axis == (): + out = kwargs.get("out") + if out is None: + dtype = None + if x.dtype == paddle.complex64: + dtype = paddle.float32 + elif x.dtype == paddle.complex128: + dtype = paddle.float64 + + out = paddle.zeros_like(x, dtype=dtype) + + # The norm of a single scalar works out to abs(x) in every case except + # for ord=0, which is x != 0. 
+        if ord == 0:
+            out[:] = x != 0
+        else:
+            out[:] = paddle.abs(x)
+        return out
+    return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+
+
+def matrix_norm(
+    x: array,
+    /,
+    *,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
+) -> array:
+    return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
+
+
+def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
+    if rtol is None:
+        return paddle.linalg.pinv(x)
+
+    return paddle.linalg.pinv(x, rcond=rtol)
+
+
+def slogdet(x: array):
+    det = paddle.linalg.det(x)
+    sign = paddle.sign(det)
+    log_det = paddle.log(det)
+
+    slogdet_result = namedtuple("slogdet", ["sign", "logabsdet"])
+    return slogdet_result(sign, log_det)
+
+
+__all__ = linalg_all + [
+    "outer",
+    "matmul",
+    "matrix_transpose",
+    "matrix_norm",
+    "tensordot",
+    "cross",
+    "vecdot",
+    "solve",
+    "trace",
+    "vector_norm",
+    "slogdet",
+]
+
+_all_ignore = ["paddle_linalg", "sum"]
+
+del linalg_all
diff --git a/docs/index.md b/docs/index.md
index ef18265e..874c3866 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -60,6 +60,10 @@ import array_api_compat.torch as torch
 import array_api_compat.dask as da
 ```
 
+```py
+import array_api_compat.paddle as paddle
+```
+
 ```{note}
 There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. These
 support for these libraries is contained in the libraries themselves (JAX
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index a016a636..26a1c1c5 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -137,3 +137,25 @@ The minimum supported Dask version is 2023.12.0.
 ## [Sparse](https://sparse.pydata.org/en/stable/)
 
 Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
+
+## [Paddle](https://www.paddlepaddle.org.cn/)
+
+- Like NumPy/CuPy, we do not wrap the `paddle.Tensor` object. It is missing the
+  `__array_namespace__` and `to_device` methods, so the corresponding helper
+  functions {func}`~.array_namespace()` and {func}`~.to_device()` in this
+  library should be used instead.
+
+- Paddle does not have unsigned integer types other than `uint8`, and no
+  attempt is made to implement them here.
+
+- [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std)
+  and
+  [`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var)
+  do not support floating-point `correction` except for `0.0` and `1.0`.
+
+- The `stream` argument of the {func}`~.to_device()` helper is not supported.
+
+- As with NumPy, type annotations and positional-only arguments may not
+  exactly match the spec for functions that are not wrapped at all.
+
+The minimum supported Paddle version is 3.0.0.
diff --git a/requirements-dev.txt b/requirements-dev.txt
index c9d10f71..7ad022d7 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,5 +4,6 @@ jax[cpu]
 numpy
 pytest
 torch
+paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
 sparse >=0.15.1
 ndonnx
diff --git a/tests/_helpers.py b/tests/_helpers.py
index e2a7e1d1..801cd32d 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
+wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
 all_libraries = wrapped_libraries + ["jax.numpy"]
 
 # `sparse` added array API support as of Python 3.10.
@@ -25,4 +25,9 @@ def import_(library, wrapper=False):
         else:
             library = 'array_api_compat.' + library
 
+    if library == 'paddle':
+        xp = import_module(library)
+        xp.asarray = xp.to_tensor
+        return xp
+
     return import_module(library)
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 9c26371c..4076c74c 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -6,6 +6,7 @@
 import numpy as np
 import pytest
 import torch
+import paddle
 
 import array_api_compat
 from array_api_compat import array_namespace
@@ -91,6 +92,12 @@ def test_array_namespace_errors_torch():
     x = np.asarray([1, 2])
     pytest.raises(TypeError, lambda: array_namespace(x, y))
 
+
+def test_array_namespace_errors_paddle():
+    y = paddle.to_tensor([1, 2])
+    x = np.asarray([1, 2])
+    pytest.raises(TypeError, lambda: array_namespace(x, y))
+
 def test_api_version():
     x = torch.asarray([1, 2])
     torch_ = import_("torch", wrapper=True)
@@ -115,7 +122,7 @@ def test_get_namespace():
     # Backwards compatible wrapper
     assert array_api_compat.get_namespace is array_api_compat.array_namespace
 
-def test_python_scalars():
+def test_python_scalars_torch():
     a = torch.asarray([1, 2])
     xp = import_("torch", wrapper=True)
 
@@ -130,3 +137,19 @@ def test_python_scalars():
     assert array_namespace(a, 1j) == xp
     assert array_namespace(a, True) == xp
     assert array_namespace(a, None) == xp
+
+def test_python_scalars_paddle():
+    a = paddle.to_tensor([1, 2])
+    xp = import_("paddle", wrapper=True)
+
+    pytest.raises(TypeError, lambda: array_namespace(1))
+    pytest.raises(TypeError, lambda: array_namespace(1.0))
+    pytest.raises(TypeError, lambda: array_namespace(1j))
+    pytest.raises(TypeError, lambda: array_namespace(True))
+    pytest.raises(TypeError, lambda: array_namespace(None))
+
+    assert array_namespace(a, 1) == xp
+    assert array_namespace(a, 1.0) == xp
+    assert array_namespace(a, 1j) == xp
+    assert array_namespace(a, True) == xp
+    assert array_namespace(a, None) == xp
diff --git a/tests/test_common.py b/tests/test_common.py
index e1cfa9eb..23ac53d1 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,8 +1,8 @@
 from array_api_compat import (  # noqa: F401
-    is_numpy_array, is_cupy_array, is_torch_array,
+    is_numpy_array, is_cupy_array, is_torch_array, is_paddle_array,
     is_dask_array, is_jax_array, is_pydata_sparse_array,
     is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
-    is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
+    is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, is_paddle_namespace,
 )
 
 from array_api_compat import is_array_api_obj, device, to_device
@@ -21,6 +21,7 @@
     'dask.array': 'is_dask_array',
     'jax.numpy': 'is_jax_array',
     'sparse': 'is_pydata_sparse_array',
+    'paddle': 'is_paddle_array',
 }
 
 is_namespace_functions = {
@@ -30,6 +31,7 @@
     'dask.array': 'is_dask_namespace',
     'jax.numpy': 'is_jax_namespace',
     'sparse': 'is_pydata_sparse_namespace',
+    'paddle': 'is_paddle_namespace',
 }
 
 
@@ -101,6 +103,13 @@ def test_asarray_cross_library(source_library, target_library, request):
     if source_library == "cupy" and target_library != "cupy":
         # cupy explicitly disallows implicit conversions to CPU
        pytest.skip(reason="cupy does not support implicit conversion to CPU")
+    if source_library == "paddle" or target_library == "paddle":
+        pytest.skip(
+            reason=(
+                "paddle does not support implicit conversion from/to other frameworks "
+                "via 'asarray'; DLPack is recommended instead."
+            )
+        )
     elif source_library == "sparse" and target_library != "sparse":
         pytest.skip(reason="`sparse` does not allow implicit densification")
     src_lib = import_(source_library, wrapper=True)
@@ -114,6 +123,8 @@ def test_asarray_cross_library(source_library, target_library, request):
 
 @pytest.mark.parametrize("library", wrapped_libraries)
 def test_asarray_copy(library):
+    if library == 'paddle':
+        pytest.skip("Paddle does not support explicit copies")
     # Note, we have this test here because the test suite currently doesn't
     # test the copy flag to asarray() very rigorously. Once
     # https://github.com/data-apis/array-api-tests/issues/241 is fixed we
diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py
index 6ad45d4c..e7b7d9c1 100644
--- a/tests/test_isdtype.py
+++ b/tests/test_isdtype.py
@@ -10,7 +10,7 @@
 
 # Check the known dtypes by their string names
 def _spec_dtypes(library):
-    if library == 'torch':
+    if library in ['torch', 'paddle']:
         # torch does not have unsigned integer dtypes
         return {
             'bool',
diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py
index a1fdf731..11a516ac 100644
--- a/tests/test_no_dependencies.py
+++ b/tests/test_no_dependencies.py
@@ -49,8 +49,12 @@ def _test_dependency(mod):
 
 # TODO: Test that wrapper for library X doesn't depend on wrappers for library
 # Y (except most array libraries actually do themselves depend on numpy).
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
-                                     "jax.numpy", "sparse", "array_api_strict"])
+@pytest.mark.parametrize("library",
+    [
+        "numpy", "cupy", "torch", "dask.array",
+        "jax.numpy", "sparse", "paddle", "array_api_strict"
+    ]
+)
 def test_numpy_dependency(library):
     # This import is here because it imports numpy
     from ._helpers import import_
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 70083b49..3c9b5d92 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -24,3 +24,9 @@ def test_vendoring_torch():
 def test_vendoring_dask():
     from vendor_test import uses_dask
     uses_dask._test_dask()
+
+
+def test_vendoring_paddle():
+    from vendor_test import uses_paddle
+
+    uses_paddle._test_paddle()
diff --git a/vendor_test/uses_paddle.py b/vendor_test/uses_paddle.py
new file mode 100644
index 00000000..e92257a4
--- /dev/null
+++ b/vendor_test/uses_paddle.py
@@ -0,0 +1,30 @@
+# Basic test that vendoring works
+
+from .vendored._compat import (
+    is_paddle_array,
+    is_paddle_namespace,
+    paddle as paddle_compat,
+)
+
+import paddle
+
+def _test_paddle():
+    a = paddle_compat.to_tensor([1., 2., 3.])
+    b = paddle_compat.arange(3, dtype=paddle_compat.float64)
+    assert a.dtype == paddle_compat.float32 == paddle.float32
+    assert b.dtype == paddle_compat.float64 == paddle.float64
+
+    # paddle.expand_dims does not exist. Update this to use something else if it is added
+    res = paddle_compat.expand_dims(a, axis=0)
+    assert res.dtype == paddle_compat.float32 == paddle.float32
+    assert res.shape == [1, 3]
+    assert isinstance(res.shape, list)
+    assert isinstance(a, paddle.Tensor)
+    assert isinstance(b, paddle.Tensor)
+    assert isinstance(res, paddle.Tensor)
+
+    assert paddle.allclose(res, paddle.to_tensor([[1., 2., 3.]]))
+
+    assert is_paddle_array(res)
+    assert is_paddle_namespace(paddle) and is_paddle_namespace(paddle_compat)
+
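
Below is a short usage sketch (not part of the patch) showing how the Paddle support added above is reached from user code. It only uses names introduced or wrapped by this diff (`array_namespace`, `is_paddle_array`, `device`, `to_device`, and the `array_api_compat.paddle` namespace) and assumes `paddlepaddle` is installed, e.g. from the nightly index added to `requirements-dev.txt`; the sample values are arbitrary.

```py
# Minimal sketch: exercising the Paddle path added in this patch.
import paddle
from array_api_compat import array_namespace, is_paddle_array, device, to_device

x = paddle.to_tensor([1.0, 2.0, 3.0])
assert is_paddle_array(x)          # new helper in common/_helpers.py

xp = array_namespace(x)            # resolves to array_api_compat.paddle
y = xp.sum(x)                      # portable Array API call, executed by Paddle

dev = device(x)                    # "cpu" or "gpu", per the new device() branch
x_cpu = to_device(x, "cpu")        # passing stream= raises NotImplementedError
```

The same `array_namespace`-based dispatch works unchanged for the other supported backends, which is what the compat layer is for.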