Skip to content

Commit 0be1074

Browse files
committed
Rudimentary tests for SupportsIndex in indexing methods
1 parent e38ce34 commit 0be1074

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

array_api_tests/test_array_object.py

+35
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,41 @@ def test_setitem(shape, dtypes, data):
153153
)
154154

155155

156+
class AwkwardIndexable:
157+
def __init__(self, value: int):
158+
self._value = value
159+
160+
def __int__(self):
161+
raise TypeError("__int__() should not be called")
162+
163+
def __index__(self):
164+
return self._value
165+
166+
167+
@pytest.mark.parametrize(
168+
"x, key",
169+
[
170+
(xp.asarray([0, 1]), AwkwardIndexable(1)),
171+
(xp.asarray([[0, 1], [2, 3]]), (0, AwkwardIndexable(1))),
172+
]
173+
)
174+
def test_getitem_supports_index(x, key):
175+
out = x[key]
176+
assert out == xp.asarray(1)
177+
178+
179+
@pytest.mark.parametrize(
180+
"x, key, expected",
181+
[
182+
(xp.asarray([0, 1]), AwkwardIndexable(1), xp.asarray([0, 42])),
183+
(xp.asarray([[0, 1], [2, 3]]), (0, AwkwardIndexable(1)), xp.asarray([[0, 42], [2, 3]])),
184+
]
185+
)
186+
def test_setitem_supports_index(x, key, expected):
187+
x[key] = xp.asarray(42)
188+
ph.assert_array_elements("__setitem__", out=x, expected=expected, out_repr="x")
189+
190+
156191
@pytest.mark.unvectorized
157192
@pytest.mark.data_dependent_shapes
158193
@given(hh.shapes(), st.data())

0 commit comments

Comments
 (0)