 from typing import NamedTuple
 import inspect
 
-from ._helpers import array_namespace, _check_device
+from ._helpers import array_namespace, _check_device, device, is_torch_array
 
 
 # These functions are modified from the NumPy versions.
 
@@ -264,6 +264,38 @@ def var(
 ) -> ndarray:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
+# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
+# argument
+
+def cumulative_sum(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    axis: Optional[int] = None,
+    dtype: Optional[Dtype] = None,
+    include_initial: bool = False,
+    **kwargs
+) -> ndarray:
+    wrapped_xp = array_namespace(x)
+
+    # TODO: The standard is not clear about what should happen when x.ndim == 0.
+    if axis is None:
+        if x.ndim > 1:
+            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+        axis = 0
+
+    res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
+
+    # np.cumsum does not support include_initial
+    if include_initial:
+        initial_shape = list(x.shape)
+        initial_shape[axis] = 1
+        res = xp.concatenate(
+            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            axis=axis,
+        )
+    return res
+
 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
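A minimal usage sketch of the new keyword, assuming the wrapper is surfaced through the compat namespaces the same way as the existing aliases (with `xp` bound internally):

```python
import numpy as np
import array_api_compat

x = np.asarray([1, 2, 3])
xp = array_api_compat.array_namespace(x)

xp.cumulative_sum(x)                        # array([1, 3, 6])
# include_initial=True prepends a zero, so the result has length n + 1
xp.cumulative_sum(x, include_initial=True)  # array([0, 1, 3, 6])
```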
@@ -281,10 +313,11 @@ def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
-    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
 
     wrapped_xp = array_namespace(x)
 
+    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
+
     # np.clip does type promotion but the array API clip requires that the
     # output have the same dtype as x. We do this instead of just downcasting
     # the result of xp.clip() to handle some corner cases better (e.g.,
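Moving `result_shape` below `wrapped_xp` is a pure reorder; it is still the broadcast of `x` against any array-valued bounds, with scalar bounds contributing an empty shape. A quick NumPy illustration of that computation:

```python
import numpy as np

# x of shape (3, 1), a scalar min (shape ()), an array max of shape (4,)
np.broadcast_shapes((3, 1), (), (4,))  # -> (3, 4)
```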
@@ -305,20 +338,26 @@ def _isscalar(a):
 
     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= xp.iinfo(x.dtype).min:
+    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
         min = None
-    if type(max) is int and max >= xp.iinfo(x.dtype).max:
+    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None
 
     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
+        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
+                                 copy=True, device=device(x))
     if min is not None:
-        a = xp.broadcast_to(xp.asarray(min), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
+            # Avoid loss of precision due to torch defaulting to float32
+            min = wrapped_xp.asarray(min, dtype=xp.float64)
+        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
         ia = (out < a) | xp.isnan(a)
         # torch requires an explicit cast here
         out[ia] = wrapped_xp.astype(a[ia], out.dtype)
     if max is not None:
-        b = xp.broadcast_to(xp.asarray(max), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
+            max = wrapped_xp.asarray(max, dtype=xp.float64)
+        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
         ib = (out > b) | xp.isnan(b)
         out[ib] = wrapped_xp.astype(b[ib], out.dtype)
     # Return a scalar for 0-D
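A small illustration of the precision hazard the new torch branch guards against: `torch.asarray` on a Python float produces the default float32 dtype, which can silently round a bound intended for a float64 array. A sketch, assuming torch's default dtype is unchanged:

```python
import torch

bound = 1.0000000000000002  # representable in float64, not in float32
torch.asarray(bound).dtype                        # torch.float32 (default)
float(torch.asarray(bound))                       # 1.0 -- precision lost
float(torch.asarray(bound, dtype=torch.float64))  # 1.0000000000000002
```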
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)
 
-# sum() and prod() should always upcast when dtype=None
-def sum(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    # `xp.sum` already upcasts integers, but not floats or complexes
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
-
-def prod(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
-
 # ceil, floor, and trunc return integers for integer inputs
 
 def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
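The deleted `sum`/`prod` wrappers forced a float32 -> float64 (and complex64 -> complex128) upcast when `dtype=None`; with them removed, the underlying namespace's defaults apply directly, presumably because the 2023.12 revision of the standard no longer mandates that upcast. The resulting behavior in plain NumPy:

```python
import numpy as np

x = np.ones(3, dtype=np.float32)
np.sum(x).dtype   # float32 -- no longer forced up to float64
np.prod(x).dtype  # float32
```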
@@ -521,10 +524,17 @@ def isdtype(
     # array_api_strict implementation will be very strict.
     return dtype == kind
 
+# unstack is a new function in the 2023.12 array API standard
+def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
+    if x.ndim == 0:
+        raise ValueError("Input array must be at least 1-d.")
+    return tuple(xp.moveaxis(x, axis, 0))
+
 __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
            'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
-           'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
-           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
+           'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
+           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
+           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
+           'unstack']
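What the `moveaxis`-based `unstack` returns, mirrored here with plain NumPy for illustration:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)

# axis=0: iterating the moved array yields the two rows
tuple(np.moveaxis(x, 0, 0))
# (array([0, 1, 2]), array([3, 4, 5]))

# axis=1: three length-2 slices, one per column
tuple(np.moveaxis(x, 1, 0))
# (array([0, 3]), array([1, 4]), array([2, 5]))
```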