Skip to content

Commit a6de255

Browse files
committed
ENH: allow python scalars in binary elementwise functions
1 parent d086c61 commit a6de255

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

array_api_strict/_elementwise_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ._flags import requires_api_version
1616
from ._creation_functions import asarray
1717
from ._data_type_functions import broadcast_to, iinfo
18+
from ._helpers import _maybe_normalize_py_scalars
1819

1920
from typing import Optional, Union
2021

@@ -62,6 +63,8 @@ def add(x1: Array, x2: Array, /) -> Array:
6263
6364
See its docstring for more information.
6465
"""
66+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
67+
6568
if x1.device != x2.device:
6669
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
6770

array_api_strict/_helpers.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Private helper routines.
2+
"""
3+
import numpy as np
4+
5+
_py_scalars = (bool, int, float, complex)
6+
7+
def _maybe_normalize_py_scalars(x1, x2):
8+
from ._array_object import Array
9+
10+
if isinstance(x1, _py_scalars):
11+
if isinstance(x2, _py_scalars):
12+
raise TypeError(f"Two scalars not allowed, {type(x1) = } and {type(x2) =}")
13+
x1 = Array._new(np.asarray(x1, dtype=x2.dtype._np_dtype), device=x2.device)
14+
elif isinstance(x2, _py_scalars):
15+
x2 = Array._new(np.asarray(x2, dtype=x1.dtype._np_dtype), device=x1.device)
16+
else:
17+
# nothing to do
18+
pass
19+
return x1, x2

array_api_strict/tests/test_elementwise_functions.py

+31
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,34 @@ def test_bitwise_shift_error():
202202
assert_raises(
203203
ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
204204
)
205+
206+
207+
def test_scalars():
208+
# Test that binary functions accept (array, scalar) and (scalar, array) arguments
209+
# and reject (scalar, scalar) arguments
210+
211+
def _sample_scalar(category):
212+
if 'boolean' in category:
213+
return True
214+
elif 'floating-point' in category:
215+
return 1.0
216+
elif 'numeric' in category or 'integer' in category or 'all' in category:
217+
return 1
218+
else:
219+
raise ValueError(f'Unknown {category = }')
220+
221+
for func_name, types in elementwise_function_input_types.items():
222+
dtypes = _dtype_categories[types]
223+
func = getattr(_elementwise_functions, func_name)
224+
if nargs(func) == 2:
225+
print(func_name, types, _sample_scalar(types))
226+
scalar = _sample_scalar(types)
227+
for dt in dtypes:
228+
array = asarray(scalar, dtype=dt)
229+
conv_scalar = asarray(scalar, dtype=array.dtype)
230+
assert func(scalar, array) == func(conv_scalar, array)
231+
assert func(array, scalar) == func(array, conv_scalar)
232+
233+
with assert_raises(TypeError):
234+
func(scalar, scalar)
235+

0 commit comments

Comments
 (0)