Skip to content

Commit 07a3693

Browse files
mwiewior and pkhamutou authored
feat: Nearest interval function (#59)
* Nearest interval function
* Disabling JoinSelection rule for nearest function
* CoitresNearest fix
* Fixing macOS, Windows and Linux x86 builds
* Add nearest test
* Add more intervals to nearest test
* Remove join_selection optimization rule
---------
Co-authored-by: Pavel Khamutou <[email protected]>
1 parent 72f0cd2 commit 07a3693

File tree

5 files changed

+181
-8
lines changed

5 files changed

+181
-8
lines changed

sandbox/closest.md

Whitespace-only changes.

sandbox/complement.md

Whitespace-only changes.

sequila/sequila-core/src/physical_planner/joins/interval_join.rs

+123-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::physical_planner::joins::utils::{
66
use crate::session_context::Algorithm;
77
use ahash::RandomState;
88
use bio::data_structures::interval_tree as rust_bio;
9+
use coitrees::{COITree, Interval};
910
use datafusion::arrow::array::{Array, AsArray, PrimitiveArray, PrimitiveBuilder, RecordBatch};
1011
use datafusion::arrow::compute;
1112
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef, UInt32Type};
@@ -702,6 +703,7 @@ enum IntervalJoinAlgorithm {
702703
ArrayIntervalTree(FnvHashMap<u64, rust_bio::ArrayBackedIntervalTree<i32, Position>>),
703704
AIList(FnvHashMap<u64, scailist::ScAIList<Position>>),
704705
Lapper(FnvHashMap<u64, rust_lapper::Lapper<u32, Position>>),
706+
CoitresNearest(FnvHashMap<u64, (COITree<Position, u32>, Vec<Interval<Position>>)>),
705707
}
706708

707709
impl Debug for IntervalJoinAlgorithm {
@@ -715,6 +717,8 @@ impl Debug for IntervalJoinAlgorithm {
715717
.collect::<HashMap<_, _>>();
716718
f.debug_struct("Coitrees").field("0", &q).finish()
717719
}
720+
&IntervalJoinAlgorithm::CoitresNearest(_) => todo!(),
721+
718722
IntervalJoinAlgorithm::IntervalTree(m) => {
719723
f.debug_struct("IntervalTree").field("0", m).finish()
720724
}
@@ -749,6 +753,28 @@ impl IntervalJoinAlgorithm {
749753

750754
IntervalJoinAlgorithm::Coitrees(hashmap)
751755
}
756+
Algorithm::CoitreesNearest => {
757+
use coitrees::{COITree, Interval, IntervalTree};
758+
759+
let hashmap = hash_map
760+
.into_iter()
761+
.map(|(k, v)| {
762+
let mut intervals = v
763+
.into_iter()
764+
.map(SequilaInterval::into_coitrees)
765+
.collect::<Vec<Interval<Position>>>();
766+
767+
// can hold up to u32::MAX intervals
768+
let tree: COITree<Position, u32> = COITree::new(intervals.iter());
769+
intervals.sort_by(|a, b| {
770+
a.first.cmp(&b.first).then_with(|| a.last.cmp(&b.last))
771+
});
772+
(k, (tree, intervals))
773+
})
774+
.collect::<FnvHashMap<u64, (COITree<Position, u32>, Vec<Interval<Position>>)>>(
775+
);
776+
IntervalJoinAlgorithm::CoitresNearest(hashmap)
777+
}
752778
Algorithm::IntervalTree => {
753779
let hashmap = hash_map
754780
.into_iter()
@@ -810,10 +836,22 @@ impl IntervalJoinAlgorithm {
810836
}
811837

812838
/// unoptimized on Linux x64 (without target-cpu=native)
813-
#[cfg(all(
814-
target_os = "linux",
815-
target_arch = "x86_64",
816-
not(target_feature = "avx")
839+
#[cfg(any(
840+
all(
841+
target_os = "linux",
842+
target_arch = "x86_64",
843+
not(target_feature = "avx")
844+
),
845+
all(
846+
target_os = "macos",
847+
target_arch = "x86_64",
848+
not(target_feature = "avx")
849+
),
850+
all(
851+
target_os = "windows",
852+
target_arch = "x86_64",
853+
not(target_feature = "avx")
854+
),
817855
))]
818856
fn extract_position(&self, node: &coitrees::IntervalNode<Position, u32>) -> Position {
819857
node.metadata
@@ -823,12 +861,56 @@ impl IntervalJoinAlgorithm {
823861
/// Returns the position carried as metadata by a query hit.
///
/// This variant is compiled for macOS on aarch64 and for AVX-enabled x86_64
/// builds on macOS, Linux and Windows. There the hit is received as a
/// `coitrees::Interval<&Position>`, so the metadata is copied out from behind
/// its reference.
#[cfg(any(
    all(target_os = "macos", target_arch = "aarch64"),
    all(
        any(target_os = "macos", target_os = "linux", target_os = "windows"),
        target_arch = "x86_64",
        target_feature = "avx"
    )
))]
fn extract_position(&self, node: &coitrees::Interval<&Position>) -> Position {
    let &position = node.metadata;
    position
}
831870

871+
/// Finds the interval in `ranges2` that lies closest to the query `[start, end]`
/// and returns its metadata.
///
/// # Arguments
/// * `start`, `end` - bounds of the query interval.
/// * `ranges2` - candidate intervals; must be sorted by `first` (ties by `last`),
///   because the slice is partitioned with a binary search on `first`.
///
/// # Returns
/// `None` only when `ranges2` is empty. Otherwise the metadata of an interval
/// with the smallest gap to the query (gap `0` if it touches or overlaps the
/// query). When the best candidate on the left and the best candidate on the
/// right are equally far away, the left one wins.
///
/// # Notes
/// The slice is ordered by `first`, not by `last`, so the closest interval on the
/// left is not necessarily the one directly before the split point: an earlier,
/// longer interval may end nearer to `start`. The left side is therefore scanned
/// for the greatest `last`. That scan is linear in the number of intervals
/// starting before `end`; keeping a running maximum of `last` next to the sorted
/// intervals would make it logarithmic if this ever shows up in profiles.
fn nearest(&self, start: i32, end: i32, ranges2: &[Interval<Position>]) -> Option<Position> {
    // Gap between the query and a candidate, widened to i64 so that extreme
    // coordinates cannot overflow the subtraction.
    let gap = |candidate: &Interval<Position>| -> i64 {
        if end < candidate.first {
            i64::from(candidate.first) - i64::from(end)
        } else if candidate.last < start {
            i64::from(start) - i64::from(candidate.last)
        } else {
            0
        }
    };

    // Everything before `split` starts before the query ends; everything from
    // `split` on starts at or after the query end.
    let split = ranges2.partition_point(|candidate| candidate.first < end);
    let (starting_before, starting_after) = ranges2.split_at(split);

    // Left side: the interval reaching furthest to the right is the closest one.
    let left_best = starting_before.iter().max_by_key(|candidate| candidate.last);
    // Right side: sorted by `first`, so the first element is the closest one.
    let right_best = starting_after.first();

    match (left_best, right_best) {
        (Some(left), Some(right)) => Some(if gap(right) < gap(left) {
            right.metadata
        } else {
            left.metadata
        }),
        (Some(only), None) | (None, Some(only)) => Some(only.metadata),
        (None, None) => None,
    }
}
832914
fn get<F>(&self, k: u64, start: i32, end: i32, mut f: F)
833915
where
834916
F: FnMut(Position),
@@ -843,6 +925,25 @@ impl IntervalJoinAlgorithm {
843925
});
844926
}
845927
}
928+
IntervalJoinAlgorithm::CoitresNearest(hashmap) => {
929+
use coitrees::IntervalTree;
930+
if let Some(tree) = hashmap.get(&k) {
931+
let mut i = 0;
932+
// first look for overlaps and return an arbitrary one (see: https://web.mit.edu/~r/current/arch/i386_linux26/lib/R/library/IRanges/html/nearest-methods.html)
933+
tree.0.query(start, end, |node| {
934+
let position: Position = self.extract_position(node);
935+
if i == 0 {
936+
f(position);
937+
i += 1;
938+
}
939+
});
940+
// found no overlaps in the tree - try to look for nearest intervals
941+
if i == 0 {
942+
let position = self.nearest(start, end, &tree.1);
943+
f(position.unwrap());
944+
}
945+
}
946+
}
846947
IntervalJoinAlgorithm::IntervalTree(hashmap) => {
847948
if let Some(tree) = hashmap.get(&k) {
848949
for entry in tree.find(start..end + 1) {
@@ -1086,8 +1187,23 @@ impl IntervalJoinStream {
10861187
.get(*hash_val, start.value(i), end.value(i), |pos| {
10871188
pos_vect.push(pos as u32);
10881189
});
1089-
rle_right.push(pos_vect.len() as u32);
1090-
builder_left.append_slice(&pos_vect);
1190+
match &build_side.hash_map {
1191+
IntervalJoinAlgorithm::CoitresNearest(_t) => {
1192+
// even if there is no hit we need to preserve the right side
1193+
rle_right.push(1);
1194+
if pos_vect.len() == 0 {
1195+
builder_left.append_null();
1196+
} else {
1197+
builder_left.append_slice(&pos_vect);
1198+
}
1199+
}
1200+
_ => {
1201+
rle_right.push(pos_vect.len() as u32);
1202+
builder_left.append_slice(&pos_vect);
1203+
}
1204+
}
1205+
1206+
// builder_left.append_slice(&pos_vect);
10911207
pos_vect.clear();
10921208
}
10931209
let left_indexes = builder_left.finish();

sequila/sequila-core/src/session_context.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use datafusion::common::extensions_options;
55
use datafusion::config::ConfigExtension;
66
use datafusion::execution::runtime_env::RuntimeEnv;
77
use datafusion::execution::SessionStateBuilder;
8+
use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
89
use datafusion::prelude::{SessionConfig, SessionContext};
910
use log::info;
1011
use std::str::FromStr;
@@ -26,12 +27,16 @@ impl SeQuiLaSessionExt for SessionContext {
2627
}
2728

2829
fn with_config_rt_sequila(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> SessionContext {
30+
let mut rules = PhysicalOptimizer::new().rules;
31+
rules.retain(|rule| rule.name() != "join_selection");
32+
rules.push(Arc::new(IntervalJoinPhysicalOptimizationRule));
33+
2934
let ctx: SessionContext = SessionStateBuilder::new()
3035
.with_config(config)
3136
.with_runtime_env(runtime)
3237
.with_default_features()
3338
.with_query_planner(Arc::new(SeQuiLaQueryPlanner))
34-
.with_physical_optimizer_rule(Arc::new(IntervalJoinPhysicalOptimizationRule))
39+
.with_physical_optimizer_rules(rules)
3540
.build()
3641
.into();
3742

@@ -61,6 +66,7 @@ pub enum Algorithm {
6166
ArrayIntervalTree,
6267
AIList,
6368
Lapper,
69+
CoitreesNearest,
6470
}
6571

6672
#[derive(Debug)]
@@ -85,6 +91,7 @@ impl FromStr for Algorithm {
8591
"arrayintervaltree" => Ok(Algorithm::ArrayIntervalTree),
8692
"ailist" => Ok(Algorithm::AIList),
8793
"lapper" => Ok(Algorithm::Lapper),
94+
"coitreesnearest" => Ok(Algorithm::CoitreesNearest),
8895
_ => Err(ParseAlgorithmError(format!(
8996
"Can't parse '{}' as Algorithm",
9097
s
@@ -101,6 +108,7 @@ impl std::fmt::Display for Algorithm {
101108
Algorithm::ArrayIntervalTree => "ArrayIntervalTree",
102109
Algorithm::AIList => "AIList",
103110
Algorithm::Lapper => "Lapper",
111+
Algorithm::CoitreesNearest => "CoitreesNearest",
104112
};
105113
write!(f, "{}", val)
106114
}

sequila/sequila-core/tests/integration_test.rs

+49
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,52 @@ async fn test_all_gt_lt_conditions(ctx: SessionContext) -> Result<()> {
348348

349349
Ok(())
350350
}
351+
352+
#[tokio::test(flavor = "multi_thread")]
353+
#[rstest::rstest]
354+
async fn test_nearest(ctx: SessionContext) -> Result<()> {
355+
let a = r#"
356+
CREATE TABLE a (contig TEXT, strand TEXT, start INTEGER, end INTEGER) AS VALUES
357+
('a', 's', 5, 10)
358+
"#;
359+
360+
let b = r#"
361+
CREATE TABLE b (contig TEXT, strand TEXT, start INTEGER, end INTEGER) AS VALUES
362+
('a', 's', 11, 13),
363+
('a', 's', 20, 21),
364+
('a', 'x', 0, 1),
365+
('b', 's', 1, 2)
366+
"#;
367+
368+
ctx.sql("SET sequila.interval_join_algorithm TO CoitreesNearest")
369+
.await?;
370+
371+
ctx.sql(a).await?;
372+
ctx.sql(b).await?;
373+
374+
let q = r#"
375+
SELECT * FROM a JOIN b
376+
ON a.contig = b.contig AND a.strand = b.strand
377+
AND a.start < b.end AND a.end > b.start
378+
"#;
379+
380+
let result = ctx.sql(q).await?;
381+
result.clone().show().await?;
382+
383+
let results = result.collect().await?;
384+
385+
let expected = [
386+
"+--------+--------+-------+-----+--------+--------+-------+-----+",
387+
"| contig | strand | start | end | contig | strand | start | end |",
388+
"+--------+--------+-------+-----+--------+--------+-------+-----+",
389+
"| | | | | a | x | 0 | 1 |",
390+
"| | | | | b | s | 1 | 2 |",
391+
"| a | s | 5 | 10 | a | s | 11 | 13 |",
392+
"| a | s | 5 | 10 | a | s | 20 | 21 |",
393+
"+--------+--------+-------+-----+--------+--------+-------+-----+",
394+
];
395+
396+
assert_batches_sorted_eq!(expected, &results);
397+
398+
Ok(())
399+
}

0 commit comments

Comments
 (0)