@@ -6,6 +6,7 @@ use crate::physical_planner::joins::utils::{
6
6
use crate :: session_context:: Algorithm ;
7
7
use ahash:: RandomState ;
8
8
use bio:: data_structures:: interval_tree as rust_bio;
9
+ use coitrees:: { COITree , Interval } ;
9
10
use datafusion:: arrow:: array:: { Array , AsArray , PrimitiveArray , PrimitiveBuilder , RecordBatch } ;
10
11
use datafusion:: arrow:: compute;
11
12
use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef , UInt32Type } ;
@@ -702,6 +703,7 @@ enum IntervalJoinAlgorithm {
702
703
ArrayIntervalTree ( FnvHashMap < u64 , rust_bio:: ArrayBackedIntervalTree < i32 , Position > > ) ,
703
704
AIList ( FnvHashMap < u64 , scailist:: ScAIList < Position > > ) ,
704
705
Lapper ( FnvHashMap < u64 , rust_lapper:: Lapper < u32 , Position > > ) ,
706
+ CoitresNearest ( FnvHashMap < u64 , ( COITree < Position , u32 > , Vec < Interval < Position > > ) > ) ,
705
707
}
706
708
707
709
impl Debug for IntervalJoinAlgorithm {
@@ -715,6 +717,8 @@ impl Debug for IntervalJoinAlgorithm {
715
717
. collect :: < HashMap < _ , _ > > ( ) ;
716
718
f. debug_struct ( "Coitrees" ) . field ( "0" , & q) . finish ( )
717
719
}
720
+ & IntervalJoinAlgorithm :: CoitresNearest ( _) => todo ! ( ) ,
721
+
718
722
IntervalJoinAlgorithm :: IntervalTree ( m) => {
719
723
f. debug_struct ( "IntervalTree" ) . field ( "0" , m) . finish ( )
720
724
}
@@ -749,6 +753,28 @@ impl IntervalJoinAlgorithm {
749
753
750
754
IntervalJoinAlgorithm :: Coitrees ( hashmap)
751
755
}
756
+ Algorithm :: CoitreesNearest => {
757
+ use coitrees:: { COITree , Interval , IntervalTree } ;
758
+
759
+ let hashmap = hash_map
760
+ . into_iter ( )
761
+ . map ( |( k, v) | {
762
+ let mut intervals = v
763
+ . into_iter ( )
764
+ . map ( SequilaInterval :: into_coitrees)
765
+ . collect :: < Vec < Interval < Position > > > ( ) ;
766
+
767
+ // can hold up to u32::MAX intervals
768
+ let tree: COITree < Position , u32 > = COITree :: new ( intervals. iter ( ) ) ;
769
+ intervals. sort_by ( |a, b| {
770
+ a. first . cmp ( & b. first ) . then_with ( || a. last . cmp ( & b. last ) )
771
+ } ) ;
772
+ ( k, ( tree, intervals) )
773
+ } )
774
+ . collect :: < FnvHashMap < u64 , ( COITree < Position , u32 > , Vec < Interval < Position > > ) > > (
775
+ ) ;
776
+ IntervalJoinAlgorithm :: CoitresNearest ( hashmap)
777
+ }
752
778
Algorithm :: IntervalTree => {
753
779
let hashmap = hash_map
754
780
. into_iter ( )
@@ -810,10 +836,22 @@ impl IntervalJoinAlgorithm {
810
836
}
811
837
812
838
/// unoptimized on Linux x64 (without target-cpu=native)
813
- #[ cfg( all(
814
- target_os = "linux" ,
815
- target_arch = "x86_64" ,
816
- not( target_feature = "avx" )
839
+ #[ cfg( any(
840
+ all(
841
+ target_os = "linux" ,
842
+ target_arch = "x86_64" ,
843
+ not( target_feature = "avx" )
844
+ ) ,
845
+ all(
846
+ target_os = "macos" ,
847
+ target_arch = "x86_64" ,
848
+ not( target_feature = "avx" )
849
+ ) ,
850
+ all(
851
+ target_os = "windows" ,
852
+ target_arch = "x86_64" ,
853
+ not( target_feature = "avx" )
854
+ ) ,
817
855
) ) ]
818
856
fn extract_position ( & self , node : & coitrees:: IntervalNode < Position , u32 > ) -> Position {
819
857
node. metadata
@@ -823,12 +861,56 @@ impl IntervalJoinAlgorithm {
823
861
#[ cfg( any(
824
862
all( target_os = "macos" , target_arch = "aarch64" ) ,
825
863
all( target_os = "macos" , target_arch = "x86_64" , target_feature = "avx" ) ,
826
- all( target_os = "linux" , target_arch = "x86_64" , target_feature = "avx" )
864
+ all( target_os = "linux" , target_arch = "x86_64" , target_feature = "avx" ) ,
865
+ all( target_os = "windows" , target_arch = "x86_64" , target_feature = "avx" )
827
866
) ) ]
828
867
fn extract_position ( & self , node : & coitrees:: Interval < & Position > ) -> Position {
829
868
* node. metadata
830
869
}
831
870
871
+ fn nearest ( & self , start : i32 , end : i32 , ranges2 : & [ Interval < Position > ] ) -> Option < Position > {
872
+ if ranges2. is_empty ( ) {
873
+ return None ;
874
+ }
875
+
876
+ let sorted_ranges2 = ranges2;
877
+
878
+ let mut closest_idx = None ;
879
+ let mut min_distance = i32:: MAX ;
880
+
881
+ let mut left = 0 ;
882
+ let mut right = sorted_ranges2. len ( ) ;
883
+
884
+ // Binary search to narrow down candidates in ranges2
885
+ while left < right {
886
+ let mid = ( left + right) / 2 ;
887
+ if sorted_ranges2[ mid] . first < end {
888
+ left = mid + 1 ;
889
+ } else {
890
+ right = mid;
891
+ }
892
+ }
893
+
894
+ // Check ranges around the binary search result for nearest distance
895
+ for & i in [ left. saturating_sub ( 1 ) , left] . iter ( ) {
896
+ if let Some ( r2) = sorted_ranges2. get ( i) {
897
+ let distance = if end < r2. first {
898
+ r2. first - end
899
+ } else if r2. last < start {
900
+ start - r2. last
901
+ } else {
902
+ 0
903
+ } ;
904
+
905
+ if distance < min_distance {
906
+ min_distance = distance;
907
+ closest_idx = Some ( r2. metadata ) ;
908
+ }
909
+ }
910
+ }
911
+
912
+ closest_idx
913
+ }
832
914
fn get < F > ( & self , k : u64 , start : i32 , end : i32 , mut f : F )
833
915
where
834
916
F : FnMut ( Position ) ,
@@ -843,6 +925,25 @@ impl IntervalJoinAlgorithm {
843
925
} ) ;
844
926
}
845
927
}
928
+ IntervalJoinAlgorithm :: CoitresNearest ( hashmap) => {
929
+ use coitrees:: IntervalTree ;
930
+ if let Some ( tree) = hashmap. get ( & k) {
931
+ let mut i = 0 ;
932
+ // first look for overlaps and return an arbitrary one (see: https://web.mit.edu/~r/current/arch/i386_linux26/lib/R/library/IRanges/html/nearest-methods.html)
933
+ tree. 0 . query ( start, end, |node| {
934
+ let position: Position = self . extract_position ( node) ;
935
+ if i == 0 {
936
+ f ( position) ;
937
+ i += 1 ;
938
+ }
939
+ } ) ;
940
+ // found no overlaps in the tree - try to look for nearest intervals
941
+ if i == 0 {
942
+ let position = self . nearest ( start, end, & tree. 1 ) ;
943
+ f ( position. unwrap ( ) ) ;
944
+ }
945
+ }
946
+ }
846
947
IntervalJoinAlgorithm :: IntervalTree ( hashmap) => {
847
948
if let Some ( tree) = hashmap. get ( & k) {
848
949
for entry in tree. find ( start..end + 1 ) {
@@ -1086,8 +1187,23 @@ impl IntervalJoinStream {
1086
1187
. get ( * hash_val, start. value ( i) , end. value ( i) , |pos| {
1087
1188
pos_vect. push ( pos as u32 ) ;
1088
1189
} ) ;
1089
- rle_right. push ( pos_vect. len ( ) as u32 ) ;
1090
- builder_left. append_slice ( & pos_vect) ;
1190
+ match & build_side. hash_map {
1191
+ IntervalJoinAlgorithm :: CoitresNearest ( _t) => {
1192
+ // even if there is no hit we need to preserve the right side
1193
+ rle_right. push ( 1 ) ;
1194
+ if pos_vect. len ( ) == 0 {
1195
+ builder_left. append_null ( ) ;
1196
+ } else {
1197
+ builder_left. append_slice ( & pos_vect) ;
1198
+ }
1199
+ }
1200
+ _ => {
1201
+ rle_right. push ( pos_vect. len ( ) as u32 ) ;
1202
+ builder_left. append_slice ( & pos_vect) ;
1203
+ }
1204
+ }
1205
+
1206
+ // builder_left.append_slice(&pos_vect);
1091
1207
pos_vect. clear ( ) ;
1092
1208
}
1093
1209
let left_indexes = builder_left. finish ( ) ;
0 commit comments