Skip to content

Commit 8994765

Browse files
committed
WIP: top_k tests
The purpose of this PR is to continue several threads of discussion regarding `top_k`.

This roughly follows the specification of `top_k` in data-apis/array-api#722, with slight modifications to the API:

```py
def top_k(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:

- `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`.
- `axis` is no longer a keyword-only argument. This makes the signature slightly more compatible with `torch.topk`.

The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.
1 parent dbdca7b commit 8994765

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

array_api_tests/test_searching_functions.py

+100
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from hypothesis import given, note
55
from hypothesis import strategies as st
6+
from hypothesis.control import assume
67

78
from . import _array_module as xp
89
from . import dtype_helpers as dh
@@ -203,3 +204,102 @@ def test_searchsorted(data):
203204
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
204205
)
205206
# TODO: shapes and values testing
207+
208+
209+
@pytest.mark.unvectorized
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@given(
    x=hh.arrays(
        dtype=hh.real_dtypes,
        shape=hh.shapes(min_dims=1, min_side=1),
        elements={"allow_nan": False},
    ),
    data=st.data(),
)
def test_top_k(x, data):
    """Check the dtypes, shapes and values returned by ``xp.top_k``.

    For a drawn real-valued array ``x`` (no NaNs, and no zeros for float
    dtypes), a drawn ``axis``, ``largest`` flag and ``k``, this calls
    ``xp.top_k(x, k, axis, largest=largest)`` and asserts that:

    * ``out_values`` has the dtype of ``x`` and ``out_indices`` has the
      namespace's default indexing dtype;
    * both outputs have the shape of ``x`` with the reduced axis replaced
      by ``k``;
    * along every 1-D slice, the values of ``x`` selected by
      ``out_indices`` equal the ``k`` largest (or smallest) values of that
      slice, irrespective of the order in which they are returned;
    * ``out_values`` matches ``x`` at the positions given by
      ``out_indices``.
    """
    if dh.is_float_dtype(x.dtype):
        # +0.0 and -0.0 compare equal, so a single comparison against zero
        # rejects both signed zeros (see the TODO above the decorator).
        assume(not xp.any(x == 0.0))

    axis = data.draw(st.integers(-x.ndim, x.ndim - 1), label="axis")
    largest = data.draw(st.booleans(), label="largest")
    # NOTE(review): the strategy above only draws integers, so the
    # ``axis is None`` paths below are currently never taken. They are kept
    # so that ``None`` can be added to the axis strategy without further
    # changes to the checks.
    if axis is None:
        k = data.draw(st.integers(1, math.prod(x.shape)), label="k")
    else:
        k = data.draw(st.integers(1, x.shape[axis]), label="k")

    kw = dict(x=x, k=k, axis=axis, largest=largest)

    out_values, out_indices = xp.top_k(x, k, axis, largest=largest)
    if axis is None:
        # Mirror the flattening the function is expected to perform.
        x = xp.reshape(x, (-1,))
        axis = 0

    ph.assert_dtype("top_k", in_dtype=x.dtype, out_dtype=out_values.dtype)
    # Indices are checked against the default *indexing* dtype, consistent
    # with the index-dtype check in test_searchsorted in this module.
    ph.assert_dtype(
        "top_k",
        in_dtype=x.dtype,
        out_dtype=out_indices.dtype,
        expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
    )

    axes, = sh.normalise_axis(axis, x.ndim)
    # Same shape as x, except that the reduced axis has length k.
    expected_shape = x.shape[:axes] + (k,) + x.shape[axes + 1:]
    for arr in [out_values, out_indices]:
        ph.assert_shape(
            "top_k",
            out_shape=arr.shape,
            expected=expected_shape,
            kw=kw,
        )

    scalar_type = dh.get_scalar_type(x.dtype)

    # Presumably each ``indices`` is the sequence of full index tuples of x
    # along one 1-D slice of the reduced axis; it is indexed and sliced as
    # such below.
    for indices in sh.axes_ndindex(x.shape, (axes,)):

        # Test if the values indexed by out_indices correspond to the
        # correct top_k values.
        elements = [scalar_type(x[idx]) for idx in indices]
        correct_order = sorted(
            range(len(elements)),
            key=elements.__getitem__,
            reverse=largest,
        )[:k]

        # The outputs have length k along the reduced axis, so the first k
        # index tuples of the slice address every output element.
        out_positions = indices[:k]
        # Convert the 0-d index arrays to Python ints once, so they can be
        # used directly as sequence indices and sort keys.
        test_order = [int(out_indices[idx]) for idx in out_positions]
        # Sort because top_k does not necessarily return the values in
        # sorted order.
        test_sorted_order = sorted(
            test_order,
            key=elements.__getitem__,
            reverse=largest,
        )

        # Values (not positions) are compared, so ties between equal
        # elements do not cause spurious failures.
        for y_o, x_o in zip(correct_order, test_sorted_order):
            y_idx = indices[y_o]
            x_idx = indices[x_o]
            ph.assert_0d_equals(
                "top_k",
                x_repr=f"x[{x_idx}]",
                x_val=x[x_idx],
                out_repr=f"x[{y_idx}]",
                out_val=x[y_idx],
                kw=kw,
            )

        # Test if the values indexed by out_indices correspond to out_values.
        for y_o, x_idx in zip(test_order, out_positions):
            y_idx = indices[y_o]
            ph.assert_0d_equals(
                "top_k",
                x_repr=f"out_values[{x_idx}]",
                x_val=scalar_type(out_values[x_idx]),
                out_repr=f"x[{y_idx}]",
                out_val=x[y_idx],
                kw=kw,
            )

0 commit comments

Comments
 (0)