Skip to content

Commit 866cedb

Browse files
committed
ENH: allow python scalars in binary elementwise functions
Allow func(array, scalar) and func(scalar, array), raise on func(scalar, scalar) cross-ref data-apis/array-api#807
1 parent d086c61 commit 866cedb

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

array_api_strict/_elementwise_functions.py

+57
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ._flags import requires_api_version
1616
from ._creation_functions import asarray
1717
from ._data_type_functions import broadcast_to, iinfo
18+
from ._helpers import _maybe_normalize_py_scalars
1819

1920
from typing import Optional, Union
2021

@@ -62,6 +63,8 @@ def add(x1: Array, x2: Array, /) -> Array:
6263
6364
See its docstring for more information.
6465
"""
66+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
67+
6568
if x1.device != x2.device:
6669
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
6770

@@ -116,6 +119,8 @@ def atan2(x1: Array, x2: Array, /) -> Array:
116119
117120
See its docstring for more information.
118121
"""
122+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
123+
119124
if x1.device != x2.device:
120125
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
121126
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
@@ -144,6 +149,8 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array:
144149
145150
See its docstring for more information.
146151
"""
152+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
153+
147154
if x1.device != x2.device:
148155
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
149156

@@ -165,6 +172,8 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
165172
166173
See its docstring for more information.
167174
"""
175+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
176+
168177
if x1.device != x2.device:
169178
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
170179

@@ -197,6 +206,8 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array:
197206
198207
See its docstring for more information.
199208
"""
209+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
210+
200211
if x1.device != x2.device:
201212
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
202213

@@ -218,6 +229,8 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
218229
219230
See its docstring for more information.
220231
"""
232+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
233+
221234
if x1.device != x2.device:
222235
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
223236

@@ -238,6 +251,8 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array:
238251
239252
See its docstring for more information.
240253
"""
254+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
255+
241256
if x1.device != x2.device:
242257
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
243258

@@ -389,6 +404,8 @@ def copysign(x1: Array, x2: Array, /) -> Array:
389404
390405
See its docstring for more information.
391406
"""
407+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
408+
392409
if x1.device != x2.device:
393410
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
394411

@@ -427,6 +444,8 @@ def divide(x1: Array, x2: Array, /) -> Array:
427444
428445
See its docstring for more information.
429446
"""
447+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
448+
430449
if x1.device != x2.device:
431450
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
432451
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
@@ -443,6 +462,8 @@ def equal(x1: Array, x2: Array, /) -> Array:
443462
444463
See its docstring for more information.
445464
"""
465+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
466+
446467
if x1.device != x2.device:
447468
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
448469
# Call result type here just to raise on disallowed type combinations
@@ -493,6 +514,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array:
493514
494515
See its docstring for more information.
495516
"""
517+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
518+
496519
if x1.device != x2.device:
497520
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
498521
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -509,6 +532,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
509532
510533
See its docstring for more information.
511534
"""
535+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
536+
512537
if x1.device != x2.device:
513538
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
514539
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -525,6 +550,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
525550
526551
See its docstring for more information.
527552
"""
553+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
554+
528555
if x1.device != x2.device:
529556
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
530557
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -541,6 +568,8 @@ def hypot(x1: Array, x2: Array, /) -> Array:
541568
542569
See its docstring for more information.
543570
"""
571+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
572+
544573
if x1.device != x2.device:
545574
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
546575
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
@@ -600,6 +629,8 @@ def less(x1: Array, x2: Array, /) -> Array:
600629
601630
See its docstring for more information.
602631
"""
632+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
633+
603634
if x1.device != x2.device:
604635
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
605636
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -616,6 +647,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
616647
617648
See its docstring for more information.
618649
"""
650+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
651+
619652
if x1.device != x2.device:
620653
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
621654
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -676,6 +709,8 @@ def logaddexp(x1: Array, x2: Array, /) -> Array:
676709
677710
See its docstring for more information.
678711
"""
712+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
713+
679714
if x1.device != x2.device:
680715
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
681716
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
@@ -692,6 +727,8 @@ def logical_and(x1: Array, x2: Array, /) -> Array:
692727
693728
See its docstring for more information.
694729
"""
730+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
731+
695732
if x1.device != x2.device:
696733
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
697734
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
@@ -719,6 +756,8 @@ def logical_or(x1: Array, x2: Array, /) -> Array:
719756
720757
See its docstring for more information.
721758
"""
759+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
760+
722761
if x1.device != x2.device:
723762
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
724763
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
@@ -735,6 +774,8 @@ def logical_xor(x1: Array, x2: Array, /) -> Array:
735774
736775
See its docstring for more information.
737776
"""
777+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
778+
738779
if x1.device != x2.device:
739780
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
740781
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
@@ -751,6 +792,8 @@ def maximum(x1: Array, x2: Array, /) -> Array:
751792
752793
See its docstring for more information.
753794
"""
795+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
796+
754797
if x1.device != x2.device:
755798
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
756799
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -769,6 +812,8 @@ def minimum(x1: Array, x2: Array, /) -> Array:
769812
770813
See its docstring for more information.
771814
"""
815+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
816+
772817
if x1.device != x2.device:
773818
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
774819
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -784,6 +829,8 @@ def multiply(x1: Array, x2: Array, /) -> Array:
784829
785830
See its docstring for more information.
786831
"""
832+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
833+
787834
if x1.device != x2.device:
788835
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
789836
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
@@ -812,6 +859,8 @@ def nextafter(x1: Array, x2: Array, /) -> Array:
812859
813860
See its docstring for more information.
814861
"""
862+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
863+
815864
if x1.device != x2.device:
816865
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
817866
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
@@ -825,6 +874,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
825874
826875
See its docstring for more information.
827876
"""
877+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
878+
828879
if x1.device != x2.device:
829880
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
830881
# Call result type here just to raise on disallowed type combinations
@@ -851,6 +902,8 @@ def pow(x1: Array, x2: Array, /) -> Array:
851902
852903
See its docstring for more information.
853904
"""
905+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
906+
854907
if x1.device != x2.device:
855908
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
856909
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
@@ -889,6 +942,8 @@ def remainder(x1: Array, x2: Array, /) -> Array:
889942
890943
See its docstring for more information.
891944
"""
945+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
946+
892947
if x1.device != x2.device:
893948
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
894949
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
@@ -985,6 +1040,8 @@ def subtract(x1: Array, x2: Array, /) -> Array:
9851040
9861041
See its docstring for more information.
9871042
"""
1043+
x1, x2 = _maybe_normalize_py_scalars(x1, x2)
1044+
9881045
if x1.device != x2.device:
9891046
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
9901047
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:

array_api_strict/_helpers.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Private helper routines.
2+
"""
3+
import numpy as np
4+
5+
_py_scalars = (bool, int, float, complex)
6+
7+
def _maybe_normalize_py_scalars(x1, x2):
8+
from ._array_object import Array
9+
10+
if isinstance(x1, _py_scalars):
11+
if isinstance(x2, _py_scalars):
12+
raise TypeError(f"Two scalars not allowed, {type(x1) = } and {type(x2) =}")
13+
x1 = Array._new(np.asarray(x1, dtype=x2.dtype._np_dtype), device=x2.device)
14+
elif isinstance(x2, _py_scalars):
15+
x2 = Array._new(np.asarray(x2, dtype=x1.dtype._np_dtype), device=x1.device)
16+
else:
17+
# nothing to do
18+
pass
19+
return x1, x2

array_api_strict/tests/test_elementwise_functions.py

+30
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,33 @@ def test_bitwise_shift_error():
202202
assert_raises(
203203
ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
204204
)
205+
206+
207+
def test_scalars():
208+
# Test that binary functions accept (array, scalar) and (scalar, array) arguments
209+
# and reject (scalar, scalar) arguments
210+
211+
def _sample_scalar(category):
212+
if 'boolean' in category:
213+
return True
214+
elif 'floating-point' in category:
215+
return 1.0
216+
elif 'numeric' in category or 'integer' in category or 'all' in category:
217+
return 1
218+
else:
219+
raise ValueError(f'Unknown {category = }')
220+
221+
for func_name, types in elementwise_function_input_types.items():
222+
dtypes = _dtype_categories[types]
223+
func = getattr(_elementwise_functions, func_name)
224+
if nargs(func) == 2:
225+
scalar = _sample_scalar(types)
226+
for dt in dtypes:
227+
array = asarray(scalar, dtype=dt)
228+
conv_scalar = asarray(scalar, dtype=array.dtype)
229+
assert func(scalar, array) == func(conv_scalar, array)
230+
assert func(array, scalar) == func(array, conv_scalar)
231+
232+
with assert_raises(TypeError):
233+
func(scalar, scalar)
234+

0 commit comments

Comments
 (0)