diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cb2dd11..483952e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -722,7 +722,7 @@ def __getitem__( devices = {self.device} if isinstance(key, tuple): devices.update( - [subkey.device for subkey in key if hasattr(subkey, "device")] + [subkey.device for subkey in key if isinstance(subkey, Array)] ) if len(devices) > 1: raise ValueError( diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 51f4f31..c7330d8 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -100,6 +100,36 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[:]) assert_raises(IndexError, lambda: a[idx]) +class DummyIndex: + def __init__(self, x): + self.x = x + def __index__(self): + return self.x + + +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) +@pytest.mark.parametrize( + "integer_index", + [ + 0, + np.int8(0), + np.uint8(0), + np.int16(0), + np.uint16(0), + np.int32(0), + np.uint32(0), + np.int64(0), + np.uint64(0), + DummyIndex(0), + ], +) +def test_indexing_ints(integer_index, device): + # Ensure indexing with different integer types works on all Devices. + device = None if device is None else Device(device) + + a = arange(5, device=device) + assert a[(integer_index,)] == a[integer_index] == a[0] + @pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) def test_indexing_arrays(device):