Skip to content

Commit 6e3c824

Browse files
committed
WIP ENH: setdiff1d for Dask and jax.jit
1 parent 8fa3fd2 commit 6e3c824

File tree

6 files changed

+140
-164
lines changed

6 files changed

+140
-164
lines changed

pixi.lock

+42-44
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.10.0,<2"]
29+
# dependencies = ["array-api-compat>=1.10.0,<2"] # DNM
3030

3131
[project.urls]
3232
Homepage = "https://github.com/data-apis/array-api-extra"
@@ -48,10 +48,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
4848

4949
[tool.pixi.dependencies]
5050
python = ">=3.10,<3.14"
51-
array-api-compat = ">=1.10.0,<2"
51+
# array-api-compat = ">=1.10.0,<2" # DNM
5252

5353
[tool.pixi.pypi-dependencies]
5454
array-api-extra = { path = ".", editable = true }
55+
array-api-compat = { git = "git+https://github.com/crusaderky/array-api-compat.git", branch = "dask_sort" }
5556

5657
[tool.pixi.feature.lint.dependencies]
5758
typing-extensions = "*"

src/array_api_extra/_lib/_funcs.py

+92-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
14+
from ._utils._compat import (
15+
array_namespace,
16+
is_dask_namespace,
17+
is_jax_array,
18+
is_jax_namespace,
19+
)
1520
from ._utils._typing import Array
1621

1722
__all__ = [
@@ -539,6 +544,7 @@ def setdiff1d(
539544
/,
540545
*,
541546
assume_unique: bool = False,
547+
fill_value: object | None = None,
542548
xp: ModuleType | None = None,
543549
) -> Array:
544550
"""
@@ -555,6 +561,11 @@ def setdiff1d(
555561
assume_unique : bool
556562
If ``True``, the input arrays are both assumed to be unique, which
557563
can speed up the calculation. Default is ``False``.
564+
fill_value : object, optional
565+
Pad the output array with this value.
566+
567+
This is exclusively used for JAX arrays when running inside ``jax.jit``,
568+
where all array shapes need to be known in advance.
558569
xp : array_namespace, optional
559570
The standard-compatible namespace for `x1` and `x2`. Default: infer.
560571
@@ -578,12 +589,86 @@ def setdiff1d(
578589
if xp is None:
579590
xp = array_namespace(x1, x2)
580591

581-
if assume_unique:
582-
x1 = xp.reshape(x1, (-1,))
583-
else:
584-
x1 = xp.unique_values(x1)
585-
x2 = xp.unique_values(x2)
586-
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
592+
x1 = xp.reshape(x1, (-1,))
593+
x2 = xp.reshape(x2, (-1,))
594+
if x1.shape == (0,) or x2.shape == (0,):
595+
return x1
596+
597+
def _x1_not_in_x2(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
598+
"""For each element of x1, return True if it is not also in x2."""
599+
# Even when assume_unique=True, there is no provision for x to be sorted
600+
x2 = xp.sort(x2)
601+
idx = xp.searchsorted(x2, x1)
602+
603+
# FIXME at() is faster but needs JAX jit support for bool mask
604+
# idx = at(idx, idx == x2.shape[0]).set(0)
605+
idx = xp.where(idx == x2.shape[0], xp.zeros_like(idx), idx)
606+
607+
return xp.take(x2, idx, axis=0) != x1
608+
609+
def _generic_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
610+
"""Generic implementation (including eager JAX)."""
611+
# Note: there is no provision in the Array API for xp.unique_values to sort
612+
if not assume_unique:
613+
# Call unique_values early to speed up the algorithm
614+
x1 = xp.unique_values(x1)
615+
x2 = xp.unique_values(x2)
616+
mask = _x1_not_in_x2(x1, x2)
617+
x1 = x1[mask]
618+
return x1 if assume_unique else xp.sort(x1)
619+
620+
def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
621+
"""
622+
Dask implementation.
623+
624+
Works around unique_values returning unknown shapes.
625+
"""
626+
# Do not call unique_values yet, as it would make array shapes unknown
627+
mask = _x1_not_in_x2(x1, x2)
628+
x1 = x1[mask]
629+
# Note: da.unique_values sorts
630+
return x1 if assume_unique else xp.unique_values(x1)
631+
632+
def _jax_jit_impl(
633+
x1: Array, x2: Array, fill_value: object | None
634+
) -> Array: # numpydoc ignore=PR01,RT01
635+
"""
636+
JAX implementation inside jax.jit.
637+
638+
Works around unique_values requiring a size= parameter
639+
and not being able to filter by a boolean mask.
640+
Returns array the same size as x1, padded with fill_value.
641+
"""
642+
# unique_values inside jax.jit is not supported unless it's got a fixed size
643+
mask = _x1_not_in_x2(x1, x2)
644+
645+
if fill_value is None:
646+
fill_value = xp.zeros((), dtype=x1.dtype)
647+
else:
648+
fill_value = xp.asarray(fill_value, dtype=x1.dtype)
649+
if cast(Array, fill_value).ndim != 0:
650+
msg = "`fill_value` must be a scalar."
651+
raise ValueError(msg)
652+
653+
x1 = xp.where(mask, x1, fill_value)
654+
# Note: jnp.unique_values sorts
655+
return xp.unique_values(x1, size=x1.size, fill_value=fill_value)
656+
657+
if is_dask_namespace(xp):
658+
return _dask_impl(x1, x2)
659+
660+
if is_jax_namespace(xp):
661+
import jax
662+
663+
try:
664+
return _generic_impl(x1, x2) # eager mode
665+
except (
666+
jax.errors.ConcretizationTypeError,
667+
jax.errors.NonConcreteBooleanIndexError,
668+
):
669+
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit
670+
671+
return _generic_impl(x1, x2)
587672

588673

589674
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

src/array_api_extra/_lib/_utils/_helpers.py

+1-60
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,7 @@
88
from . import _compat
99
from ._typing import Array
1010

11-
__all__ = ["in1d", "mean"]
12-
13-
14-
def in1d(
15-
x1: Array,
16-
x2: Array,
17-
/,
18-
*,
19-
assume_unique: bool = False,
20-
invert: bool = False,
21-
xp: ModuleType | None = None,
22-
) -> Array: # numpydoc ignore=PR01,RT01
23-
"""
24-
Check whether each element of an array is also present in a second array.
25-
26-
Returns a boolean array the same length as `x1` that is True
27-
where an element of `x1` is in `x2` and False otherwise.
28-
29-
This function has been adapted using the original implementation
30-
present in numpy:
31-
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
32-
"""
33-
if xp is None:
34-
xp = _compat.array_namespace(x1, x2)
35-
36-
# This code is run to make the code significantly faster
37-
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
38-
if invert:
39-
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
40-
for a in x2:
41-
mask &= x1 != a
42-
else:
43-
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
44-
for a in x2:
45-
mask |= x1 == a
46-
return mask
47-
48-
rev_idx = xp.empty(0) # placeholder
49-
if not assume_unique:
50-
x1, rev_idx = xp.unique_inverse(x1)
51-
x2 = xp.unique_values(x2)
52-
53-
ar = xp.concat((x1, x2))
54-
device_ = _compat.device(ar)
55-
# We need this to be a stable sort.
56-
order = xp.argsort(ar, stable=True)
57-
reverse_order = xp.argsort(order, stable=True)
58-
sar = xp.take(ar, order, axis=0)
59-
ar_size = _compat.size(sar)
60-
assert ar_size is not None, "xp.unique*() on lazy backends raises"
61-
if ar_size >= 1:
62-
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
63-
else:
64-
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
65-
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
66-
ret = xp.take(flag, reverse_order, axis=0)
67-
68-
if assume_unique:
69-
return ret[: x1.shape[0]]
70-
return xp.take(ret, rev_idx, axis=0)
11+
__all__ = ["mean"]
7112

7213

7314
def mean(

tests/test_funcs.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535
lazy_xp_function(kron, static_argnames="xp")
3636
lazy_xp_function(nunique, static_argnames="xp")
3737
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
38-
# FIXME calls in1d which calls xp.unique_values without size
39-
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
38+
lazy_xp_function(setdiff1d, static_argnames=("assume_unique", "xp"))
4039
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
4140
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
4241

@@ -547,8 +546,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
547546
assert padded.shape == (4, 4)
548547

549548

550-
@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
551-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
549+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sort")
552550
class TestSetDiff1D:
553551
@pytest.mark.skip_xp_backend(
554552
Backend.TORCH, reason="index_select not implemented for uint32"

tests/test_utils.py

-47
This file was deleted.

0 commit comments

Comments
 (0)