Skip to content

Commit 1f30288

Browse files
committed
BUG: torch: fix result_type with python scalars
1. Allow inputs to be arrays or dtypes or python scalars. 2. Keep the pytorch-specific additions, e.g. `result_type(int, float) -> float`, `result_type(scalar, scalar) -> dtype`, which are unspecified in the standard. 3. Since pytorch only defines a binary `result_type` function, add a version with multiple inputs. The latter is a bit tricky because we want to: keep allowing "unspecified" behaviors; keep standard-allowed promotions compliant; and (preferably) make `result_type` independent of the argument order. The last point is important because `int,float->float` promotions break associativity. So what we do is always promote all scalars after all array/dtype arguments.
1 parent e14754b commit 1f30288

File tree

2 files changed

+106
-13
lines changed

2 files changed

+106
-13
lines changed

array_api_compat/torch/_aliases.py

+32-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import wraps as _wraps
3+
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common import _aliases
@@ -124,25 +124,43 @@ def _fix_promotion(x1, x2, only_scalar=True):
124124

125125

126126
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
127-
if len(arrays_and_dtypes) == 0:
128-
raise TypeError("At least one array or dtype must be provided")
129-
if len(arrays_and_dtypes) == 1:
127+
num = len(arrays_and_dtypes)
128+
129+
if num == 0:
130+
raise ValueError("At least one array or dtype must be provided")
131+
132+
elif num == 1:
130133
x = arrays_and_dtypes[0]
131134
if isinstance(x, torch.dtype):
132135
return x
133136
return x.dtype
134-
if len(arrays_and_dtypes) > 2:
135-
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
136137

137-
x, y = arrays_and_dtypes
138-
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
139-
return torch.result_type(x, y)
138+
if num == 2:
139+
x, y = arrays_and_dtypes
140+
return _result_type(x, y)
141+
142+
else:
143+
# sort scalars so that they are treated last
144+
scalars, others = [], []
145+
for x in arrays_and_dtypes:
146+
if isinstance(x, _py_scalars):
147+
scalars.append(x)
148+
else:
149+
others.append(x)
150+
if not others:
151+
raise ValueError("At least one array or dtype must be provided")
152+
153+
# combine left-to-right
154+
return _reduce(_result_type, others + scalars)
140155

141-
xdt = x.dtype if not isinstance(x, torch.dtype) else x
142-
ydt = y.dtype if not isinstance(y, torch.dtype) else y
143156

144-
if (xdt, ydt) in _promotion_table:
145-
return _promotion_table[xdt, ydt]
157+
def _result_type(x, y):
158+
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
159+
xdt = x.dtype if not isinstance(x, torch.dtype) else x
160+
ydt = y.dtype if not isinstance(y, torch.dtype) else y
161+
162+
if (xdt, ydt) in _promotion_table:
163+
return _promotion_table[xdt, ydt]
146164

147165
# This doesn't handle result_type(dtype, dtype) for non-array API dtypes
148166
# because torch.result_type only accepts tensors. This does, however, allow
@@ -151,6 +169,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
151169
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
152170
return torch.result_type(x, y)
153171

172+
154173
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
155174
if not isinstance(from_, torch.dtype):
156175
from_ = from_.dtype

tests/test_torch.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
2+
"""
3+
import itertools
4+
5+
import pytest
6+
import torch
7+
8+
from array_api_compat import torch as xp
9+
10+
11+
class TestResultType:
12+
def test_empty(self):
13+
with pytest.raises(ValueError):
14+
xp.result_type()
15+
16+
def test_one_arg(self):
17+
for x in [1, 1.0, 1j, '...', None]:
18+
with pytest.raises((ValueError, AttributeError)):
19+
xp.result_type(x)
20+
21+
for x in [xp.float32, xp.int64, torch.complex64]:
22+
assert xp.result_type(x) == x
23+
24+
for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
25+
assert xp.result_type(x) == x.dtype
26+
27+
def test_two_args(self):
28+
# Only include here things "unspecified" in the spec
29+
30+
# scalar, tensor or tensor,tensor
31+
for x, y in [
32+
(1., 1j),
33+
(1j, xp.arange(3)),
34+
(True, xp.asarray(3.)),
35+
(xp.ones(3) == 1, 1j*xp.ones(3)),
36+
]:
37+
assert xp.result_type(x, y) == torch.result_type(x, y)
38+
39+
# dtype, scalar
40+
for x, y in [
41+
(1j, xp.int64),
42+
(True, xp.float64),
43+
]:
44+
assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))
45+
46+
# dtype, dtype
47+
for x, y in [
48+
(xp.bool, xp.complex64)
49+
]:
50+
xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
51+
assert xp.result_type(x, y) == torch.result_type(xt, yt)
52+
53+
def test_multi_arg(self):
54+
torch.set_default_dtype(torch.float32)
55+
56+
args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
57+
assert xp.result_type(*args) == torch.float16
58+
59+
args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
60+
assert xp.result_type(*args) == xp.complex64
61+
62+
args = [1, 2, 3j, xp.float64, 4, 5, 6]
63+
assert xp.result_type(*args) == xp.complex128
64+
65+
args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
66+
assert xp.result_type(*args) == xp.complex128
67+
68+
i64 = xp.ones(1, dtype=xp.int64)
69+
f16 = xp.ones(1, dtype=xp.float16)
70+
for i in itertools.permutations([i64, f16, 1.0, 1.0]):
71+
assert xp.result_type(*i) == xp.float16, f"{i}"
72+
73+
with pytest.raises(ValueError):
74+
xp.result_type(1, 2, 3, 4)

0 commit comments

Comments
 (0)