@@ -65,7 +65,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
65
65
66
66
67
67
@wraps (xps .arrays )
68
- def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
68
+ def arrays_no_scalars (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
69
69
"""xps.arrays() without the crazy large numbers."""
70
70
if isinstance (dtype , SearchStrategy ):
71
71
return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
@@ -78,6 +78,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
78
78
return xps .arrays (dtype , * args , elements = elements , ** kwargs )
79
79
80
80
81
+ def _f (a , flag ):
82
+ return a [()] if a .ndim == 0 and flag else a
83
+
84
+
85
+ @wraps (xps .arrays )
86
+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
87
+ """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
88
+
89
+ Is only relevant for numpy: on all other libraries, array[()] is no-op.
90
+ """
91
+ return builds (_f , arrays_no_scalars (dtype , * args , elements = elements , ** kwargs ), booleans ())
92
+
93
+
81
94
_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
82
95
_sorted_dtypes = [d for category in _dtype_categories for d in category ]
83
96
0 commit comments