Skip to content

Commit 04d823b

Browse files
jamxia155alamb
andauthored
Substrait support for propagating TableScan.filters to Substrait ReadRel.filter (#14194)
* Propagate filter info from TableScan to ReadRel Propagate information in datafusion::logical_expr::TableScan.filters to substrait::proto::ReadRel.best_effort_filter. * Add test * cargo fmt * Fix clippy error * Use conjunction * cargo fmt * Use ReadRel.filter instead of best_effort_filter * Check filter types in TableScan.filters Use TableScan.source.supports_filters_pushdown() to determine if each filter in TableScan.filters should be included in ReadRel.filter or ReadRel.best_effort_filter * Propagate Substrait ReadRel filter to consumer * Address PR comments * Propagate TableScan filters to ReadRel filter --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent ee77d58 commit 04d823b

File tree

3 files changed

+100
-5
lines changed

3 files changed

+100
-5
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ use datafusion::logical_expr::{
6868
};
6969
use datafusion::prelude::{lit, JoinType};
7070
use datafusion::{
71-
arrow, error::Result, logical_expr::utils::split_conjunction, prelude::Column,
72-
scalar::ScalarValue,
71+
arrow, error::Result, logical_expr::utils::split_conjunction,
72+
logical_expr::utils::split_conjunction_owned, prelude::Column, scalar::ScalarValue,
7373
};
7474
use std::collections::HashSet;
7575
use std::sync::Arc;
@@ -1327,19 +1327,28 @@ pub async fn from_read_rel(
13271327
table_ref: TableReference,
13281328
schema: DFSchema,
13291329
projection: &Option<MaskExpression>,
1330+
filter: &Option<Box<Expression>>,
13301331
) -> Result<LogicalPlan> {
13311332
let schema = schema.replace_qualifier(table_ref.clone());
13321333

1334+
let filters = if let Some(f) = filter {
1335+
let filter_expr = consumer.consume_expression(f, &schema).await?;
1336+
split_conjunction_owned(filter_expr)
1337+
} else {
1338+
vec![]
1339+
};
1340+
13331341
let plan = {
13341342
let provider = match consumer.resolve_table_ref(&table_ref).await? {
13351343
Some(ref provider) => Arc::clone(provider),
13361344
_ => return plan_err!("No table named '{table_ref}'"),
13371345
};
13381346

1339-
LogicalPlanBuilder::scan(
1347+
LogicalPlanBuilder::scan_with_filters(
13401348
table_ref,
13411349
provider_as_source(Arc::clone(&provider)),
13421350
None,
1351+
filters,
13431352
)?
13441353
.build()?
13451354
};
@@ -1382,6 +1391,7 @@ pub async fn from_read_rel(
13821391
table_reference,
13831392
substrait_schema,
13841393
&read.projection,
1394+
&read.filter,
13851395
)
13861396
.await
13871397
}
@@ -1464,6 +1474,7 @@ pub async fn from_read_rel(
14641474
table_reference,
14651475
substrait_schema,
14661476
&read.projection,
1477+
&read.filter,
14671478
)
14681479
.await
14691480
}

datafusion/substrait/src/logical_plan/producer.rs

+20-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ use datafusion::logical_expr::expr::{
5555
AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList,
5656
InSubquery, WindowFunction, WindowFunctionParams,
5757
};
58+
use datafusion::logical_expr::utils::conjunction;
5859
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
5960
use datafusion::prelude::Expr;
6061
use pbjson_types::Any as ProtoAny;
@@ -540,7 +541,7 @@ pub fn to_substrait_rel(
540541
}
541542

542543
pub fn from_table_scan(
543-
_producer: &mut impl SubstraitProducer,
544+
producer: &mut impl SubstraitProducer,
544545
scan: &TableScan,
545546
) -> Result<Box<Rel>> {
546547
let projection = scan.projection.as_ref().map(|p| {
@@ -560,11 +561,28 @@ pub fn from_table_scan(
560561
let table_schema = scan.source.schema().to_dfschema_ref()?;
561562
let base_schema = to_substrait_named_struct(&table_schema)?;
562563

564+
let filter_option = if scan.filters.is_empty() {
565+
None
566+
} else {
567+
let table_schema_qualified = Arc::new(
568+
DFSchema::try_from_qualified_schema(
569+
scan.table_name.clone(),
570+
&(scan.source.schema()),
571+
)
572+
.unwrap(),
573+
);
574+
575+
let combined_expr = conjunction(scan.filters.clone()).unwrap();
576+
let filter_expr =
577+
producer.handle_expr(&combined_expr, &table_schema_qualified)?;
578+
Some(Box::new(filter_expr))
579+
};
580+
563581
Ok(Box::new(Rel {
564582
rel_type: Some(RelType::Read(Box::new(ReadRel {
565583
common: None,
566584
base_schema: Some(base_schema),
567-
filter: None,
585+
filter: filter_option,
568586
best_effort_filter: None,
569587
projection,
570588
advanced_extension: None,

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

+66
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,11 @@ async fn roundtrip_repartition_hash() -> Result<()> {
12341234
Ok(())
12351235
}
12361236

1237+
#[tokio::test]
1238+
async fn roundtrip_read_filter() -> Result<()> {
1239+
roundtrip_verify_read_filter_count("SELECT a FROM data where a < 5", 1).await
1240+
}
1241+
12371242
fn check_post_join_filters(rel: &Rel) -> Result<()> {
12381243
// search for target_rel and field value in proto
12391244
match &rel.rel_type {
@@ -1319,6 +1324,56 @@ async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> {
13191324
Ok(())
13201325
}
13211326

1327+
fn count_read_filters(rel: &Rel, filter_count: &mut u32) -> Result<()> {
1328+
// search for target_rel and field value in proto
1329+
match &rel.rel_type {
1330+
Some(RelType::Read(read)) => {
1331+
// increment counter for read filter if not None
1332+
if read.filter.is_some() {
1333+
*filter_count += 1;
1334+
}
1335+
Ok(())
1336+
}
1337+
Some(RelType::Filter(filter)) => {
1338+
count_read_filters(filter.input.as_ref().unwrap().as_ref(), filter_count)
1339+
}
1340+
_ => Ok(()),
1341+
}
1342+
}
1343+
1344+
async fn assert_read_filter_count(
1345+
proto: Box<Plan>,
1346+
expected_filter_count: u32,
1347+
) -> Result<()> {
1348+
let mut filter_count: u32 = 0;
1349+
for relation in &proto.relations {
1350+
match relation.rel_type.as_ref() {
1351+
Some(rt) => match rt {
1352+
plan_rel::RelType::Rel(rel) => {
1353+
match count_read_filters(rel, &mut filter_count) {
1354+
Err(e) => return Err(e),
1355+
Ok(_) => continue,
1356+
}
1357+
}
1358+
plan_rel::RelType::Root(root) => {
1359+
match count_read_filters(
1360+
root.input.as_ref().unwrap(),
1361+
&mut filter_count,
1362+
) {
1363+
Err(e) => return Err(e),
1364+
Ok(_) => continue,
1365+
}
1366+
}
1367+
},
1368+
None => return plan_err!("Cannot parse plan relation: None"),
1369+
}
1370+
}
1371+
1372+
assert_eq!(expected_filter_count, filter_count);
1373+
1374+
Ok(())
1375+
}
1376+
13221377
async fn assert_expected_plan_unoptimized(
13231378
sql: &str,
13241379
expected_plan_str: &str,
@@ -1489,6 +1544,17 @@ async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
14891544
verify_post_join_filter_value(proto).await
14901545
}
14911546

1547+
async fn roundtrip_verify_read_filter_count(
1548+
sql: &str,
1549+
expected_filter_count: u32,
1550+
) -> Result<()> {
1551+
let ctx = create_context().await?;
1552+
let proto = roundtrip_with_ctx(sql, ctx).await?;
1553+
1554+
// verify that filter counts in read relations are as expected
1555+
assert_read_filter_count(proto, expected_filter_count).await
1556+
}
1557+
14921558
async fn roundtrip_all_types(sql: &str) -> Result<()> {
14931559
roundtrip_with_ctx(sql, create_all_type_context().await?).await?;
14941560
Ok(())

0 commit comments

Comments
 (0)