Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: torch: fix result_type with python scalars #277

Merged
merged 2 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from functools import wraps as _wraps
from functools import reduce as _reduce, wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any

from ..common import _aliases
Expand Down Expand Up @@ -124,25 +124,43 @@ def _fix_promotion(x1, x2, only_scalar=True):


def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
if len(arrays_and_dtypes) == 0:
raise TypeError("At least one array or dtype must be provided")
if len(arrays_and_dtypes) == 1:
num = len(arrays_and_dtypes)

if num == 0:
raise ValueError("At least one array or dtype must be provided")

elif num == 1:
x = arrays_and_dtypes[0]
if isinstance(x, torch.dtype):
return x
return x.dtype
if len(arrays_and_dtypes) > 2:
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))

x, y = arrays_and_dtypes
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
return torch.result_type(x, y)
if num == 2:
x, y = arrays_and_dtypes
return _result_type(x, y)

else:
# sort scalars so that they are treated last
scalars, others = [], []
for x in arrays_and_dtypes:
if isinstance(x, _py_scalars):
scalars.append(x)
else:
others.append(x)
if not others:
raise ValueError("At least one array or dtype must be provided")

# combine left-to-right
return _reduce(_result_type, others + scalars)

xdt = x.dtype if not isinstance(x, torch.dtype) else x
ydt = y.dtype if not isinstance(y, torch.dtype) else y

if (xdt, ydt) in _promotion_table:
return _promotion_table[xdt, ydt]
def _result_type(x, y):
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
xdt = x.dtype if not isinstance(x, torch.dtype) else x
ydt = y.dtype if not isinstance(y, torch.dtype) else y

if (xdt, ydt) in _promotion_table:
return _promotion_table[xdt, ydt]

# This doesn't result_type(dtype, dtype) for non-array API dtypes
# because torch.result_type only accepts tensors. This does however, allow
Expand All @@ -151,6 +169,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
return torch.result_type(x, y)


def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
if not isinstance(from_, torch.dtype):
from_ = from_.dtype
Expand Down
1 change: 1 addition & 0 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest

@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277")
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
def test_all(library):
if library == "common":
Expand Down
98 changes: 98 additions & 0 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import itertools

import pytest
import torch

from array_api_compat import torch as xp


class TestResultType:
def test_empty(self):
with pytest.raises(ValueError):
xp.result_type()

def test_one_arg(self):
for x in [1, 1.0, 1j, '...', None]:
with pytest.raises((ValueError, AttributeError)):
xp.result_type(x)

for x in [xp.float32, xp.int64, torch.complex64]:
assert xp.result_type(x) == x

for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
assert xp.result_type(x) == x.dtype

def test_two_args(self):
# Only include here things "unspecified" in the spec

# scalar, tensor or tensor,tensor
for x, y in [
(1., 1j),
(1j, xp.arange(3)),
(True, xp.asarray(3.)),
(xp.ones(3) == 1, 1j*xp.ones(3)),
]:
assert xp.result_type(x, y) == torch.result_type(x, y)

# dtype, scalar
for x, y in [
(1j, xp.int64),
(True, xp.float64),
]:
assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

# dtype, dtype
for x, y in [
(xp.bool, xp.complex64)
]:
xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
assert xp.result_type(x, y) == torch.result_type(xt, yt)

def test_multi_arg(self):
torch.set_default_dtype(torch.float32)

args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
assert xp.result_type(*args) == torch.float16

args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
assert xp.result_type(*args) == xp.complex64

args = [1, 2, 3j, xp.float64, 4, 5, 6]
assert xp.result_type(*args) == xp.complex128

args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
assert xp.result_type(*args) == xp.complex128

i64 = xp.ones(1, dtype=xp.int64)
f16 = xp.ones(1, dtype=xp.float16)
for i in itertools.permutations([i64, f16, 1.0, 1.0]):
assert xp.result_type(*i) == xp.float16, f"{i}"

with pytest.raises(ValueError):
xp.result_type(1, 2, 3, 4)


@pytest.mark.parametrize("default_dt", ['float32', 'float64'])
@pytest.mark.parametrize("dtype_a",
(xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
)
@pytest.mark.parametrize("dtype_b",
(xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
)
def test_gh_273(self, default_dt, dtype_a, dtype_b):
# Regression test for https://github.com/data-apis/array-api-compat/issues/273

try:
prev_default = torch.get_default_dtype()
default_dtype = getattr(torch, default_dt)
torch.set_default_dtype(default_dtype)

a = xp.asarray([2, 1], dtype=dtype_a)
b = xp.asarray([1, -1], dtype=dtype_b)
dtype_1 = xp.result_type(a, b, 1.0)
dtype_2 = xp.result_type(b, a, 1.0)
assert dtype_1 == dtype_2
finally:
torch.set_default_dtype(prev_default)
Loading