11
11
12
12
from ._at import at
13
13
from ._utils import _compat , _helpers
14
- from ._utils ._compat import array_namespace , is_jax_array
14
+ from ._utils ._compat import (
15
+ array_namespace ,
16
+ is_dask_namespace ,
17
+ is_jax_array ,
18
+ is_jax_namespace ,
19
+ )
15
20
from ._utils ._typing import Array
16
21
17
22
__all__ = [
@@ -539,6 +544,7 @@ def setdiff1d(
539
544
/ ,
540
545
* ,
541
546
assume_unique : bool = False ,
547
+ fill_value : object | None = None ,
542
548
xp : ModuleType | None = None ,
543
549
) -> Array :
544
550
"""
@@ -555,6 +561,11 @@ def setdiff1d(
555
561
assume_unique : bool
556
562
If ``True``, the input arrays are both assumed to be unique, which
557
563
can speed up the calculation. Default is ``False``.
564
+ fill_value : object, optional
565
+ Pad the output array with this value.
566
+
567
+ This is exclusively used for JAX arrays when running inside ``jax.jit``,
568
+ where all array shapes need to be known in advance.
558
569
xp : array_namespace, optional
559
570
The standard-compatible namespace for `x1` and `x2`. Default: infer.
560
571
@@ -578,12 +589,86 @@ def setdiff1d(
578
589
if xp is None :
579
590
xp = array_namespace (x1 , x2 )
580
591
581
- if assume_unique :
582
- x1 = xp .reshape (x1 , (- 1 ,))
583
- else :
584
- x1 = xp .unique_values (x1 )
585
- x2 = xp .unique_values (x2 )
586
- return x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
592
+ x1 = xp .reshape (x1 , (- 1 ,))
593
+ x2 = xp .reshape (x2 , (- 1 ,))
594
+ if x1 .shape == (0 ,) or x2 .shape == (0 ,):
595
+ return x1
596
+
597
+ def _x1_not_in_x2 (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
598
+ """For each element of x1, return True if it is not also in x2."""
599
+ # Even when assume_unique=True, there is no provision for x2 to be sorted
600
+ x2 = xp .sort (x2 )
601
+ idx = xp .searchsorted (x2 , x1 )
602
+
603
+ # FIXME at() is faster but needs JAX jit support for bool mask
604
+ # idx = at(idx, idx == x2.shape[0]).set(0)
605
+ idx = xp .where (idx == x2 .shape [0 ], xp .zeros_like (idx ), idx )
606
+
607
+ return xp .take (x2 , idx , axis = 0 ) != x1
608
+
609
+ def _generic_impl (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
610
+ """Generic implementation (including eager JAX)."""
611
+ # Note: there is no provision in the Array API for xp.unique_values to sort
612
+ if not assume_unique :
613
+ # Call unique_values early to speed up the algorithm
614
+ x1 = xp .unique_values (x1 )
615
+ x2 = xp .unique_values (x2 )
616
+ mask = _x1_not_in_x2 (x1 , x2 )
617
+ x1 = x1 [mask ]
618
+ return x1 if assume_unique else xp .sort (x1 )
619
+
620
+ def _dask_impl (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
621
+ """
622
+ Dask implementation.
623
+
624
+ Works around unique_values returning unknown shapes.
625
+ """
626
+ # Do not call unique_values yet, as it would make array shapes unknown
627
+ mask = _x1_not_in_x2 (x1 , x2 )
628
+ x1 = x1 [mask ]
629
+ # Note: da.unique_values sorts
630
+ return x1 if assume_unique else xp .unique_values (x1 )
631
+
632
+ def _jax_jit_impl (
633
+ x1 : Array , x2 : Array , fill_value : object | None
634
+ ) -> Array : # numpydoc ignore=PR01,RT01
635
+ """
636
+ JAX implementation inside jax.jit.
637
+
638
+ Works around unique_values requiring a size= parameter
639
+ and not being able to filter by a boolean mask.
640
+ Returns an array the same size as x1, padded with fill_value.
641
+ """
642
+ # unique_values inside jax.jit is not supported unless it's got a fixed size
643
+ mask = _x1_not_in_x2 (x1 , x2 )
644
+
645
+ if fill_value is None :
646
+ fill_value = xp .zeros ((), dtype = x1 .dtype )
647
+ else :
648
+ fill_value = xp .asarray (fill_value , dtype = x1 .dtype )
649
+ if cast (Array , fill_value ).ndim != 0 :
650
+ msg = "`fill_value` must be a scalar."
651
+ raise ValueError (msg )
652
+
653
+ x1 = xp .where (mask , x1 , fill_value )
654
+ # Note: jnp.unique_values sorts
655
+ return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
656
+
657
+ if is_dask_namespace (xp ):
658
+ return _dask_impl (x1 , x2 )
659
+
660
+ if is_jax_namespace (xp ):
661
+ import jax
662
+
663
+ try :
664
+ return _generic_impl (x1 , x2 ) # eager mode
665
+ except (
666
+ jax .errors .ConcretizationTypeError ,
667
+ jax .errors .NonConcreteBooleanIndexError ,
668
+ ):
669
+ return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
670
+
671
+ return _generic_impl (x1 , x2 )
587
672
588
673
589
674
def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
0 commit comments