diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index a2ed1449..5a69d27e 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False
 
 _array_api_dtypes = {
     torch.bool,
@@ -35,47 +35,23 @@
     torch.complex128,
 }
 
-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
@@ -83,6 +59,31 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
 
 
 def _two_arg(f):
     @_wraps(f)
@@ -301,6 +302,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out
 
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
 def prod(x: Array,
          /,
          *,
@@ -308,20 +334,9 @@ def prod(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    ndim = x.ndim
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
@@ -330,7 +345,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -343,25 +358,14 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
 
@@ -372,7 +376,7 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.any doesn't support multiple axes
@@ -384,7 +388,7 @@ def any(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.any(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.any doesn't return bool for uint8
@@ -396,7 +400,7 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.all doesn't support multiple axes
@@ -408,7 +412,7 @@ def all(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.all(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.all doesn't return bool for uint8
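For reference, the rewritten `_promotion_table` only spells out one ordering of each mixed-dtype pair; the two `update` calls at the end fill in the mirrored and the trivial same-dtype entries. Below is a minimal standalone sketch of that pattern, using placeholder strings instead of torch dtypes so it runs without torch; it is an illustration, not part of the patch.

```python
# Illustrative sketch of the "upper-triangular" promotion-table trick.
# The real table in _aliases.py maps pairs of torch dtypes, not strings.
_table = {
    ("int8", "int16"): "int16",
    ("int8", "int32"): "int32",
    ("int16", "int32"): "int32",
}
_dtypes = {"int8", "int16", "int32"}

# Mirror every entry so lookups work in either argument order...
_table.update({(b, a): c for (a, b), c in _table.items()})
# ...and add the (x, x) -> x entries instead of listing them by hand.
_table.update({(a, a): a for a in _dtypes})

assert _table[("int32", "int8")] == "int32"
assert _table[("int16", "int16")] == "int16"
```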