Skip to content

Commit 0954fcf

Browse files
committed
ENH: generate numpy scalars or 0D arrays
1 parent a3f3f37 commit 0954fcf

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

array_api_tests/hypothesis_helpers.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
6565

6666

6767
@wraps(xps.arrays)
68-
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
68+
def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
6969
"""xps.arrays() without the crazy large numbers."""
7070
if isinstance(dtype, SearchStrategy):
7171
return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
@@ -78,6 +78,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
7878
return xps.arrays(dtype, *args, elements=elements, **kwargs)
7979

8080

81+
def _f(a, flag):
82+
return a[()] if a.ndim==0 and flag else a
83+
84+
85+
@wraps(xps.arrays)
86+
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
87+
"""xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
88+
89+
Is only relevant for numpy: on all other libraries, array[()] is no-op.
90+
"""
91+
return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans())
92+
93+
8194
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
8295
_sorted_dtypes = [d for category in _dtype_categories for d in category]
8396

0 commit comments

Comments
 (0)