Commit 1a0c9a8
[Spark] Skip non-deterministic filters in Data Skipping to prevent incorrect file pruning in Delta queries (#4141)
#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This is a follow-up to the previous attempt to handle double-filtering of non-deterministic conditions (e.g. `rand() < 0.25`) in #4095. That approach prevented non-deterministic filters from appearing in `unusedFilters` in the `ScanReport` unless special pipelining was added for them from `PrepareDeltaScan` to `filesForScan`, and it was inconsistent with how `subqueryFilters` are skipped. We now treat `filesForScan` as the narrow waist for skipping any ineligible filters.

## How was this patch tested?

Unit tests.

## Does this PR introduce _any_ user-facing changes?

No.
1 parent cd67546 commit 1a0c9a8
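
For context, a minimal sketch (not part of this commit) of the Catalyst property the fix relies on: a predicate such as `rand() < 0.25` reports `deterministic = false`, which is what `filesForScan` now checks before using a filter for data skipping. The column name `id` below is made up for illustration.

```scala
import org.apache.spark.sql.functions.{col, rand}

// Illustrative only: a non-deterministic predicate is flagged by Catalyst,
// while an ordinary column comparison is not.
val nonDeterministic = (rand(0) < 0.25).expr
val deterministic = (col("id") > 5).expr

assert(!nonDeterministic.deterministic) // rand() is non-deterministic even with a seed
assert(deterministic.deterministic)     // plain column comparisons are deterministic
```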

3 files changed (+80, -14 lines)

spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala

+10 -8

@@ -1222,14 +1222,16 @@ trait DataSkippingReaderBase
     import DeltaTableUtils._
     val partitionColumns = metadata.partitionColumns

-    // For data skipping, avoid using the filters that involve subqueries.
-    val (subqueryFilters, flatFilters) = filters.partition {
-      case f => containsSubquery(f)
+    // For data skipping, avoid using the filters that either:
+    // 1. involve subqueries.
+    // 2. are non-deterministic.
+    var (ineligibleFilters, eligibleFilters) = filters.partition {
+      case f => containsSubquery(f) || !f.deterministic
     }

-    val (partitionFilters, dataFilters) = flatFilters
-      .partition(isPredicatePartitionColumnsOnly(_, partitionColumns, spark))
+    val (partitionFilters, dataFilters) = eligibleFilters
+      .partition(isPredicatePartitionColumnsOnly(_, partitionColumns, spark))

     if (dataFilters.isEmpty) recordDeltaOperation(deltaLog, "delta.skipping.partition") {
       // When there are only partition filters we can scan allFiles

@@ -1246,7 +1248,7 @@ trait DataSkippingReaderBase
       dataFilters = ExpressionSet(Nil),
       partitionLikeDataFilters = ExpressionSet(Nil),
       rewrittenPartitionLikeDataFilters = Set.empty,
-      unusedFilters = ExpressionSet(subqueryFilters),
+      unusedFilters = ExpressionSet(ineligibleFilters),
       scanDurationMs = System.currentTimeMillis() - startTime,
       dataSkippingType =
         getCorrectDataSkippingType(DeltaDataSkippingType.partitionFilteringOnlyV1)

@@ -1323,7 +1325,7 @@ trait DataSkippingReaderBase
       dataFilters = ExpressionSet(skippingFilters.map(_._1)),
       partitionLikeDataFilters = ExpressionSet(partitionLikeFilters.map(_._1)),
       rewrittenPartitionLikeDataFilters = partitionLikeFilters.map(_._2.expr.expr).toSet,
-      unusedFilters = ExpressionSet(unusedFilters.map(_._1) ++ subqueryFilters),
+      unusedFilters = ExpressionSet(unusedFilters.map(_._1) ++ ineligibleFilters),
       scanDurationMs = System.currentTimeMillis() - startTime,
       dataSkippingType = getCorrectDataSkippingType(dataSkippingType)
     )
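
To make the new partitioning step easier to follow in isolation, here is a small standalone sketch; it mirrors the `filters.partition` call above but is not the code of this commit, and `containsSubquery` is assumed to stand in for `DeltaTableUtils.containsSubquery`.

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// Sketch: split pushed-down filters the way the patched filesForScan does.
// Filters that reference subqueries or are non-deterministic (e.g. rand() < 0.25)
// become "ineligible" and are only reported back as unusedFilters; the rest
// are used for partition pruning and data skipping.
def splitForSkipping(
    filters: Seq[Expression],
    containsSubquery: Expression => Boolean): (Seq[Expression], Seq[Expression]) =
  filters.partition(f => containsSubquery(f) || !f.deterministic)
```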

spark/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala

+4 -6

@@ -114,22 +114,20 @@ trait PrepareDeltaScanBase extends Rule[LogicalPlan]
       limitOpt: Option[Int],
       filters: Seq[Expression],
       delta: LogicalRelation): DeltaScan = {
-    // Remove non-deterministic filters (e.g., rand() < 0.25) to prevent incorrect file pruning.
-    val deterministicFilters = filters.filter(_.deterministic)
     withStatusCode("DELTA", "Filtering files for query") {
       if (limitOpt.nonEmpty) {
         // If we trigger limit push down, the filters must be partition filters. Since
         // there are no data filters, we don't need to apply Generated Columns
         // optimization. See `DeltaTableScan` for more details.
-        return scanGenerator.filesForScan(limitOpt.get, deterministicFilters)
+        return scanGenerator.filesForScan(limitOpt.get, filters)
       }
       val filtersForScan =
         if (!GeneratedColumn.partitionFilterOptimizationEnabled(spark)) {
-          deterministicFilters
+          filters
         } else {
           val generatedPartitionFilters = GeneratedColumn.generatePartitionFilters(
-            spark, scanGenerator.snapshotToScan, deterministicFilters, delta)
-          deterministicFilters ++ generatedPartitionFilters
+            spark, scanGenerator.snapshotToScan, filters, delta)
+          filters ++ generatedPartitionFilters
         }
       scanGenerator.filesForScan(filtersForScan)
     }
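
The effect of this file's change is that `PrepareDeltaScan` no longer drops non-deterministic filters itself; every pushed-down filter now flows to `filesForScan`, which is the single place that decides eligibility and reports the rest as `unusedFilters`. As a rough, assumed-numbers illustration of why double-filtering is harmful (the bias the new tests below guard against with their `result > 150` assertions):

```scala
// Rough model only: if a ~25%-selective non-deterministic predicate were applied
// once while pruning files and again while scanning rows, far fewer rows would
// survive than the ~250 a single application over 1,000 rows should return.
val rows = 1000
val singleFiltering = rows * 0.25        // ~250 rows: what the query should yield
val doubleFiltering = rows * 0.25 * 0.25 // ~62 rows: the kind of bias double-filtering can cause
```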

spark/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala

+66

@@ -36,6 +36,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession

@@ -1929,6 +1930,71 @@ trait DataSkippingDeltaTestsBase extends DeltaExcludedBySparkVersionTestMixinShi
     }
   }

+  test("File skipping with non-deterministic filters") {
+    withTable("tbl") {
+      // Create the table.
+      val df = spark.range(100).toDF()
+      df.write.mode("overwrite").format("delta").saveAsTable("tbl")
+
+      // Append 9 times to the table.
+      for (i <- 1 to 9) {
+        val df = spark.range(i * 100, (i + 1) * 100).toDF()
+        df.write.mode("append").format("delta").insertInto("tbl")
+      }
+
+      val query = "SELECT count(*) FROM tbl WHERE rand(0) < 0.25"
+      val result = sql(query).collect().head.getLong(0)
+      assert(result > 150, s"Expected around 250 rows (~0.25 * 1000), got: $result")
+
+      val predicates = sql(query).queryExecution.optimizedPlan.collect {
+        case Filter(condition, _) => condition
+      }.flatMap(splitConjunctivePredicates)
+      val scanResult = DeltaLog.forTable(spark, TableIdentifier("tbl"))
+        .update().filesForScan(predicates)
+      assert(scanResult.unusedFilters.nonEmpty)
+    }
+  }
+
+  test("File skipping with non-deterministic filters on partitioned tables") {
+    withTable("tbl_partitioned") {
+      import org.apache.spark.sql.functions.col
+
+      // Create initial DataFrame and add a partition column.
+      val df = spark.range(100).toDF().withColumn("p", col("id") % 10)
+      df.write
+        .mode("overwrite")
+        .format("delta")
+        .partitionBy("p")
+        .saveAsTable("tbl_partitioned")
+
+      // Append 9 more times to the table.
+      for (i <- 1 to 9) {
+        val newDF = spark.range(i * 100, (i + 1) * 100).toDF().withColumn("p", col("id") % 10)
+        newDF.write.mode("append").format("delta").insertInto("tbl_partitioned")
+      }
+
+      // Run query with a nondeterministic filter.
+      val query = "SELECT count(*) FROM tbl_partitioned WHERE rand(0) < 0.25"
+      val result = sql(query).collect().head.getLong(0)
+      // Assert that the row count is as expected (e.g., roughly 25% of rows).
+      assert(result > 150, s"Expected a reasonable number of rows, got: $result")
+
+      val predicates = sql(query).queryExecution.optimizedPlan.collect {
+        case Filter(condition, _) => condition
+      }.flatMap(splitConjunctivePredicates)
+      val scanResult = DeltaLog.forTable(spark, TableIdentifier("tbl_partitioned"))
+        .update().filesForScan(predicates)
+      assert(scanResult.unusedFilters.nonEmpty)
+
+      // Assert that entries are fetched from all 10 partitions
+      val distinctPartitions =
+        sql("SELECT DISTINCT p FROM tbl_partitioned WHERE rand(0) < 0.25")
+          .collect()
+          .length
+      assert(distinctPartitions == 10)
+    }
+  }
+
   protected def parse(deltaLog: DeltaLog, predicate: String): Seq[Expression] = {

     // We produce a wrong filter in this case otherwise
