@@ -153,6 +153,41 @@ def test_setitem(shape, dtypes, data):
153
153
)
154
154
155
155
156
+ class AwkwardIndexable :
157
+ def __init__ (self , value : int ):
158
+ self ._value = value
159
+
160
+ def __int__ (self ):
161
+ raise TypeError ("__int__() should not be called" )
162
+
163
+ def __index__ (self ):
164
+ return self ._value
165
+
166
+
167
+ @pytest .mark .parametrize (
168
+ "x, key" ,
169
+ [
170
+ (xp .asarray ([0 , 1 ]), AwkwardIndexable (1 )),
171
+ (xp .asarray ([[0 , 1 ], [2 , 3 ]]), (0 , AwkwardIndexable (1 ))),
172
+ ]
173
+ )
174
+ def test_getitem_supports_index (x , key ):
175
+ out = x [key ]
176
+ assert out == xp .asarray (1 )
177
+
178
+
179
+ @pytest .mark .parametrize (
180
+ "x, key, expected" ,
181
+ [
182
+ (xp .asarray ([0 , 1 ]), AwkwardIndexable (1 ), xp .asarray ([0 , 42 ])),
183
+ (xp .asarray ([[0 , 1 ], [2 , 3 ]]), (0 , AwkwardIndexable (1 )), xp .asarray ([[0 , 42 ], [2 , 3 ]])),
184
+ ]
185
+ )
186
+ def test_setitem_supports_index (x , key , expected ):
187
+ x [key ] = xp .asarray (42 )
188
+ ph .assert_array_elements ("__setitem__" , out = x , expected = expected , out_repr = "x" )
189
+
190
+
156
191
@pytest .mark .unvectorized
157
192
@pytest .mark .data_dependent_shapes
158
193
@given (hh .shapes (), st .data ())
0 commit comments