Commit b30a59e

Merge pull request #166 from asmeurer/more-2023
More fixes for 2023.12 support
2 parents: 6f9edc7 + 9d2e283

20 files changed (+1591, -94 lines)

.github/workflows/array-api-tests.yml (+1)

@@ -74,6 +74,7 @@ jobs:
         if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
         env:
           ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
+          ARRAY_API_TESTS_VERSION: 2023.12
           # This enables the NEP 50 type promotion behavior (without it a lot of
           # tests fail on bad scalar type promotion behavior)
           NPY_PROMOTION_STATE: weak
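
The same pinning can be reproduced outside CI. A minimal sketch, assuming a local checkout of the array-api-tests suite (the checkout path and the pytest invocation are assumptions, not part of this diff):

import os
import subprocess

# Mirror the workflow's environment: which namespace to test, and which
# version of the array API standard to test it against.
env = os.environ | {
    "ARRAY_API_TESTS_MODULE": "array_api_compat.numpy",
    "ARRAY_API_TESTS_VERSION": "2023.12",
}
subprocess.run(["pytest", "array_api_tests/"], cwd="array-api-tests", env=env, check=True)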

array_api_compat/common/_aliases.py (+56, -46)
@@ -12,7 +12,7 @@
 from typing import NamedTuple
 import inspect
 
-from ._helpers import array_namespace, _check_device
+from ._helpers import array_namespace, _check_device, device, is_torch_array
 
 # These functions are modified from the NumPy versions.
 
@@ -264,6 +264,38 @@ def var(
 ) -> ndarray:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
+# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
+# argument
+
+def cumulative_sum(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    axis: Optional[int] = None,
+    dtype: Optional[Dtype] = None,
+    include_initial: bool = False,
+    **kwargs
+) -> ndarray:
+    wrapped_xp = array_namespace(x)
+
+    # TODO: The standard is not clear about what should happen when x.ndim == 0.
+    if axis is None:
+        if x.ndim > 1:
+            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+        axis = 0
+
+    res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
+
+    # np.cumsum does not support include_initial
+    if include_initial:
+        initial_shape = list(x.shape)
+        initial_shape[axis] = 1
+        res = xp.concatenate(
+            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            axis=axis,
+        )
+    return res
 
 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
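
As a usage sketch (not part of the diff): get_xp injects the xp argument when this helper is re-exported through a wrapped namespace, so callers never pass it. Expected values assume NumPy's cumsum semantics:

import array_api_compat.numpy as xp

x = xp.asarray([1, 2, 3])
xp.cumulative_sum(x)                        # array([1, 3, 6])
xp.cumulative_sum(x, include_initial=True)  # array([0, 1, 3, 6]): a zero is prepended
# For inputs with more than one dimension, axis must be given explicitly;
# otherwise the wrapper raises ValueError, per the code above.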
@@ -281,10 +313,11 @@ def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
-    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
 
     wrapped_xp = array_namespace(x)
 
+    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
+
     # np.clip does type promotion but the array API clip requires that the
     # output have the same dtype as x. We do this instead of just downcasting
     # the result of xp.clip() to handle some corner cases better (e.g.,
@@ -305,20 +338,26 @@ def _isscalar(a):
 
     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= xp.iinfo(x.dtype).min:
+    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
         min = None
-    if type(max) is int and max >= xp.iinfo(x.dtype).max:
+    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None
 
     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
+        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
+                                 copy=True, device=device(x))
     if min is not None:
-        a = xp.broadcast_to(xp.asarray(min), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
+            # Avoid loss of precision due to torch defaulting to float32
+            min = wrapped_xp.asarray(min, dtype=xp.float64)
+        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
         ia = (out < a) | xp.isnan(a)
         # torch requires an explicit cast here
         out[ia] = wrapped_xp.astype(a[ia], out.dtype)
     if max is not None:
-        b = xp.broadcast_to(xp.asarray(max), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
+            max = wrapped_xp.asarray(max, dtype=xp.float64)
+        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
         ib = (out > b) | xp.isnan(b)
         out[ib] = wrapped_xp.astype(b[ib], out.dtype)
     # Return a scalar for 0-D
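
A sketch of the corner case behind the iinfo guard (not part of the diff). A Python int bound that lies outside the range of x's dtype can never exclude any element, so the wrapper drops it rather than passing it to the backend, where an out-of-range Python scalar may be rejected or cause promotion depending on the NumPy version (see the numpy#26892 link above):

import array_api_compat.numpy as xp

x = xp.asarray([0, 100], dtype=xp.int8)
# -300 is below iinfo(int8).min, so the wrapper treats it as "no lower bound";
# the result keeps dtype int8, as the standard requires.
xp.clip(x, -300, 10)  # array([ 0, 10], dtype=int8)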
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)
 
-# sum() and prod() should always upcast when dtype=None
-def sum(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    # `xp.sum` already upcasts integers, but not floats or complexes
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
-
-def prod(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
-
 # ceil, floor, and trunc return integers for integer inputs
 
 def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
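
Context for this deletion (an inference from the standard, not stated in the diff): the forced float32 -> float64 upcast implemented the 2022.12 dtype rules, and 2023.12 relaxed them so that sum and prod may keep the input's floating-point dtype. With the wrappers gone, the backend's native behavior applies; the trace change in _linalg.py below follows the same reasoning.

import array_api_compat.numpy as xp

x = xp.asarray([1.0, 2.0], dtype=xp.float32)
xp.sum(x).dtype  # float32: no forced upcast to float64 anymore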
@@ -521,10 +524,17 @@ def isdtype(
     # array_api_strict implementation will be very strict.
     return dtype == kind
 
+# unstack is a new function in the 2023.12 array API standard
+def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
+    if x.ndim == 0:
+        raise ValueError("Input array must be at least 1-d.")
+    return tuple(xp.moveaxis(x, axis, 0))
+
 __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
            'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
-           'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
-           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
+           'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
+           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
+           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
+           'unstack']
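
A usage sketch for the new function (expected output under NumPy semantics):

import array_api_compat.numpy as xp

x = xp.asarray([[0, 1, 2], [3, 4, 5]])
xp.unstack(x)          # (array([0, 1, 2]), array([3, 4, 5]))
xp.unstack(x, axis=1)  # three arrays of shape (2,), one per column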

array_api_compat/common/_linalg.py (-5)
@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
 
 def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
     return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
 
 __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',

array_api_compat/cupy/_aliases.py (+14, -6)
@@ -5,6 +5,8 @@
 from ..common import _aliases
 from .._internal import get_xp
 
+from ._info import __array_namespace_info__
+
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from typing import Optional, Union
@@ -47,14 +49,13 @@
 astype = _aliases.astype
 std = get_xp(cp)(_aliases.std)
 var = get_xp(cp)(_aliases.var)
+cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
 clip = get_xp(cp)(_aliases.clip)
 permute_dims = get_xp(cp)(_aliases.permute_dims)
 reshape = get_xp(cp)(_aliases.reshape)
 argsort = get_xp(cp)(_aliases.argsort)
 sort = get_xp(cp)(_aliases.sort)
 nonzero = get_xp(cp)(_aliases.nonzero)
-sum = get_xp(cp)(_aliases.sum)
-prod = get_xp(cp)(_aliases.prod)
 ceil = get_xp(cp)(_aliases.ceil)
 floor = get_xp(cp)(_aliases.floor)
 trunc = get_xp(cp)(_aliases.trunc)
@@ -121,14 +122,21 @@ def sign(x: ndarray, /) -> ndarray:
     vecdot = cp.vecdot
 else:
     vecdot = get_xp(cp)(_aliases.vecdot)
+
 if hasattr(cp, 'isdtype'):
     isdtype = cp.isdtype
 else:
     isdtype = get_xp(cp)(_aliases.isdtype)
 
-__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
-                              'acosh', 'asin', 'asinh', 'atan', 'atan2',
-                              'atanh', 'bitwise_left_shift', 'bitwise_invert',
-                              'bitwise_right_shift', 'concat', 'pow', 'sign']
+if hasattr(cp, 'unstack'):
+    unstack = cp.unstack
+else:
+    unstack = get_xp(cp)(_aliases.unstack)
+
+__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
+                              'acos', 'acosh', 'asin', 'asinh', 'atan',
+                              'atan2', 'atanh', 'bitwise_left_shift',
+                              'bitwise_invert', 'bitwise_right_shift',
+                              'concat', 'pow', 'sign']
 
 _all_ignore = ['cp', 'get_xp']
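
The hasattr fallbacks keep this module forward-compatible: if a future CuPy ships unstack (or isdtype) natively, the native version is preferred over the generic wrapper. A usage sketch, assuming a working CUDA + CuPy installation (outputs are illustrative):

import array_api_compat.cupy as xp

info = xp.__array_namespace_info__()  # inspection API newly re-exported above
x = xp.asarray([[0, 1], [2, 3]])
xp.unstack(x)                         # (array([0, 1]), array([2, 3])), on the GPU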
