Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP ENH: setdiff1d for Dask and jax.jit #124

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 42 additions & 44 deletions pixi.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["array-api-compat>=1.10.0,<2"]
# dependencies = ["array-api-compat>=1.11.0,<2"] # DNM

[project.urls]
Homepage = "https://github.com/data-apis/array-api-extra"
Expand All @@ -48,10 +48,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10,<3.14"
array-api-compat = ">=1.10.0,<2"
# array-api-compat = ">=1.11.0,<2" # DNM

[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }
array-api-compat = { git = "https://github.com/data-apis/array-api-compat" } # DNM

[tool.pixi.feature.lint.dependencies]
typing-extensions = "*"
Expand Down
110 changes: 102 additions & 8 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
)
from ._utils._helpers import asarrays
from ._utils._typing import Array

Expand Down Expand Up @@ -547,6 +552,8 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
size: int | None = None,
fill_value: object | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Expand All @@ -563,6 +570,16 @@ def setdiff1d(
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
size : int, optional
The size of the output array. This is exclusively used inside the JAX JIT, and
only for as long as JAX does not support arrays of unknown size inside it. In
all other cases, it is disregarded.
Returned elements will be clipped if they are more than size, and padded with
`fill_value` if they are less. Default: raise if inside ``jax.jit``.

fill_value : object, optional
Pad the output array with this value. This is exclusively used for JAX arrays
when running inside ``jax.jit``. Default: 0.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Expand All @@ -587,13 +604,90 @@ def setdiff1d(
xp = array_namespace(x1, x2)
x1, x2 = asarrays(x1, x2, xp=xp)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
x2 = xp.reshape(x2, (-1,))
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
x1 = xp.reshape(x1, (-1,))
x2 = xp.reshape(x2, (-1,))
if x1.shape == (0,) or x2.shape == (0,):
return x1

def _x1_not_in_x2(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
"""For each element of x1, return True if it is not also in x2."""
# Even when assume_unique=True, there is no provision for x to be sorted
x2 = xp.sort(x2)
idx = xp.searchsorted(x2, x1)

# FIXME at() is faster but needs JAX jit support for bool mask
# idx = at(idx, idx == x2.shape[0]).set(0)
idx = xp.where(idx == x2.shape[0], xp.zeros_like(idx), idx)

return xp.take(x2, idx, axis=0) != x1

def _generic_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
"""Generic implementation (including eager JAX)."""
# Note: there is no provision in the Array API for xp.unique_values to sort
if not assume_unique:
# Call unique_values early to speed up the algorithm
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
mask = _x1_not_in_x2(x1, x2)
x1 = x1[mask]
return x1 if assume_unique else xp.sort(x1)

def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
"""
Dask implementation.

Works around unique_values returning unknown shapes.
"""
# Do not call unique_values yet, as it would make array shapes unknown
mask = _x1_not_in_x2(x1, x2)
x1 = x1[mask]
# Note: da.unique_values sorts
return x1 if assume_unique else xp.unique_values(x1)

def _jax_jit_impl(
x1: Array, x2: Array, size: int | None, fill_value: object | None
) -> Array: # numpydoc ignore=PR01,RT01
"""
JAX implementation inside jax.jit.

Works around unique_values requiring a size= parameter
and not being able to filter by a boolean mask.
Returns array the same size as x1, padded with fill_value.
"""
if size is None:
msg = "`size` is mandatory when running inside `jax.jit`."
raise ValueError(msg)
if fill_value is None:
fill_value = xp.zeros((), dtype=x1.dtype)
else:
fill_value = xp.asarray(fill_value, dtype=x1.dtype)
if cast(Array, fill_value).ndim != 0:
msg = "`fill_value` must be a scalar."
raise ValueError(msg)

# unique_values inside jax.jit is not supported unless it's got a fixed size
mask = _x1_not_in_x2(x1, x2)
x1 = xp.where(mask, x1, fill_value)
# Move fill_value to the right
x1 = xp.take(x1, xp.argsort(~mask, stable=True))
x1 = x1[:size]
x1 = xp.unique_values(x1, size=size, fill_value=fill_value)

if is_dask_namespace(xp):
return _dask_impl(x1, x2)

if is_jax_namespace(xp):
import jax

try:
return _generic_impl(x1, x2) # eager mode
except (
jax.errors.ConcretizationTypeError,
jax.errors.NonConcreteBooleanIndexError,
):
return _jax_jit_impl(x1, x2, size, fill_value) # inside jax.jit

return _generic_impl(x1, x2)


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down
61 changes: 1 addition & 60 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,66 +10,7 @@
from ._compat import is_array_api_obj, is_numpy_array
from ._typing import Array

__all__ = ["in1d", "mean"]


def in1d(
x1: Array,
x2: Array,
/,
*,
assume_unique: bool = False,
invert: bool = False,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""
Check whether each element of an array is also present in a second array.

Returns a boolean array the same length as `x1` that is True
where an element of `x1` is in `x2` and False otherwise.

This function has been adapted using the original implementation
present in numpy:
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
"""
if xp is None:
xp = _compat.array_namespace(x1, x2)

# This code is run to make the code significantly faster
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
if invert:
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
for a in x2:
mask &= x1 != a
else:
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
for a in x2:
mask |= x1 == a
return mask

rev_idx = xp.empty(0) # placeholder
if not assume_unique:
x1, rev_idx = xp.unique_inverse(x1)
x2 = xp.unique_values(x2)

ar = xp.concat((x1, x2))
device_ = _compat.device(ar)
# We need this to be a stable sort.
order = xp.argsort(ar, stable=True)
reverse_order = xp.argsort(order, stable=True)
sar = xp.take(ar, order, axis=0)
ar_size = _compat.size(sar)
assert ar_size is not None, "xp.unique*() on lazy backends raises"
if ar_size >= 1:
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
else:
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
ret = xp.take(flag, reverse_order, axis=0)

if assume_unique:
return ret[: x1.shape[0]]
return xp.take(ret, rev_idx, axis=0)
__all__ = ["mean"]


def mean(
Expand Down
6 changes: 2 additions & 4 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@
lazy_xp_function(kron, static_argnames="xp")
lazy_xp_function(nunique, static_argnames="xp")
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
lazy_xp_function(setdiff1d, static_argnames=("assume_unique", "xp"))
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")

Expand Down Expand Up @@ -576,8 +575,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
assert padded.shape == (4, 4)


@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sort")
class TestSetDiff1D:
@pytest.mark.skip_xp_backend(
Backend.TORCH, reason="index_select not implemented for uint32"
Expand Down
41 changes: 1 addition & 40 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,10 @@
import pytest

from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import asarrays, in1d
from array_api_extra._lib._utils._typing import Device
from array_api_extra.testing import lazy_xp_function
from array_api_extra._lib._utils._helpers import asarrays

# mypy: disable-error-code=no-untyped-usage

# FIXME calls xp.unique_values without size
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))


class TestIn1D:
@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="no unique_inverse, no device kwarg in asarray"
)
# cover both code paths
@pytest.mark.parametrize("n", [9, 15])
def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
x1 = xp.asarray([3, 8, 20])
x2 = xp.arange(n)
expected = xp.asarray([True, True, False])
actual = in1d(x1, x2)
xp_assert_equal(actual, expected)

@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
def test_device(self, xp: ModuleType, device: Device):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
assert get_device(in1d(x1, x2)) == device

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp")
@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="no arange, no device kwarg in asarray"
)
def test_xp(self, xp: ModuleType):
x1 = xp.asarray([1, 6])
x2 = xp.arange(5)
expected = xp.asarray([True, False])
actual = in1d(x1, x2, xp=xp)
xp_assert_equal(actual, expected)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.parametrize(
Expand Down
Loading