|
3 | 3 | import pytest
|
4 | 4 | from hypothesis import given, note
|
5 | 5 | from hypothesis import strategies as st
|
| 6 | +from hypothesis.control import assume |
6 | 7 |
|
7 | 8 | from . import _array_module as xp
|
8 | 9 | from . import dtype_helpers as dh
|
@@ -203,3 +204,102 @@ def test_searchsorted(data):
|
203 | 204 | expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
|
204 | 205 | )
|
205 | 206 | # TODO: shapes and values testing
|
| 207 | + |
| 208 | + |
| 209 | +@pytest.mark.unvectorized |
| 210 | +# TODO: Test with signed zeros and NaNs (and ignore them somehow) |
| 211 | +@given( |
| 212 | + x=hh.arrays( |
| 213 | + dtype=hh.real_dtypes, |
| 214 | + shape=hh.shapes(min_dims=1, min_side=1), |
| 215 | + elements={"allow_nan": False}, |
| 216 | + ), |
| 217 | + data=st.data() |
| 218 | +) |
| 219 | +def test_top_k(x, data): |
| 220 | + |
| 221 | + if dh.is_float_dtype(x.dtype): |
| 222 | + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) |
| 223 | + |
| 224 | + axis = data.draw( |
| 225 | + st.integers(-x.ndim, x.ndim - 1), label='axis') |
| 226 | + largest = data.draw(st.booleans(), label='largest') |
| 227 | + if axis is None: |
| 228 | + k = data.draw(st.integers(1, math.prod(x.shape))) |
| 229 | + else: |
| 230 | + k = data.draw(st.integers(1, x.shape[axis])) |
| 231 | + |
| 232 | + kw = dict( |
| 233 | + x=x, |
| 234 | + k=k, |
| 235 | + axis=axis, |
| 236 | + largest=largest, |
| 237 | + ) |
| 238 | + |
| 239 | + (out_values, out_indices) = xp.top_k(x, k, axis, largest=largest) |
| 240 | + if axis is None: |
| 241 | + x = xp.reshape(x, (-1,)) |
| 242 | + axis = 0 |
| 243 | + |
| 244 | + ph.assert_dtype("top_k", in_dtype=x.dtype, out_dtype=out_values.dtype) |
| 245 | + ph.assert_dtype( |
| 246 | + "top_k", |
| 247 | + in_dtype=x.dtype, |
| 248 | + out_dtype=out_indices.dtype, |
| 249 | + expected=dh.default_int |
| 250 | + ) |
| 251 | + axes, = sh.normalise_axis(axis, x.ndim) |
| 252 | + for arr in [out_values, out_indices]: |
| 253 | + ph.assert_shape( |
| 254 | + "top_k", |
| 255 | + out_shape=arr.shape, |
| 256 | + expected=x.shape[:axes] + (k,) + x.shape[axes + 1:], |
| 257 | + kw=kw |
| 258 | + ) |
| 259 | + |
| 260 | + scalar_type = dh.get_scalar_type(x.dtype) |
| 261 | + |
| 262 | + for indices in sh.axes_ndindex(x.shape, (axes,)): |
| 263 | + |
| 264 | + # Test if the values indexed by out_indices corresponds to |
| 265 | + # the correct top_k values. |
| 266 | + elements = [scalar_type(x[idx]) for idx in indices] |
| 267 | + size = len(elements) |
| 268 | + correct_order = sorted( |
| 269 | + range(size), |
| 270 | + key=elements.__getitem__, |
| 271 | + reverse=largest |
| 272 | + ) |
| 273 | + correct_order = correct_order[:k] |
| 274 | + test_order = [out_indices[idx] for idx in indices[:k]] |
| 275 | + # Sort because top_k does not necessarily return the values in |
| 276 | + # sorted order. |
| 277 | + test_sorted_order = sorted( |
| 278 | + test_order, |
| 279 | + key=elements.__getitem__, |
| 280 | + reverse=largest |
| 281 | + ) |
| 282 | + |
| 283 | + for y_o, x_o in zip(correct_order, test_sorted_order): |
| 284 | + y_idx = indices[y_o] |
| 285 | + x_idx = indices[x_o] |
| 286 | + ph.assert_0d_equals( |
| 287 | + "top_k", |
| 288 | + x_repr=f"x[{x_idx}]", |
| 289 | + x_val=x[x_idx], |
| 290 | + out_repr=f"x[{y_idx}]", |
| 291 | + out_val=x[y_idx], |
| 292 | + kw=kw, |
| 293 | + ) |
| 294 | + |
| 295 | + # Test if the values indexed by out_indices corresponds to out_values. |
| 296 | + for y_o, x_idx in zip(test_order, indices[:k]): |
| 297 | + y_idx = indices[y_o] |
| 298 | + ph.assert_0d_equals( |
| 299 | + "top_k", |
| 300 | + x_repr=f"out_values[{x_idx}]", |
| 301 | + x_val=scalar_type(out_values[x_idx]), |
| 302 | + out_repr=f"x[{y_idx}]", |
| 303 | + out_val=x[y_idx], |
| 304 | + kw=kw |
| 305 | + ) |
0 commit comments