diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala b/spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala
index d359d8faac7..6b41f18e38b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala
@@ -1222,14 +1222,16 @@ trait DataSkippingReaderBase
 
     import DeltaTableUtils._
     val partitionColumns = metadata.partitionColumns
 
-    // For data skipping, avoid using the filters that involve subqueries.
-
-    val (subqueryFilters, flatFilters) = filters.partition {
-      case f => containsSubquery(f)
+    // For data skipping, avoid using the filters that either:
+    // 1. involve subqueries.
+    // 2. are non-deterministic.
+    val (ineligibleFilters, eligibleFilters) = filters.partition {
+      case f => containsSubquery(f) || !f.deterministic
     }
-    val (partitionFilters, dataFilters) = flatFilters
-      .partition(isPredicatePartitionColumnsOnly(_, partitionColumns, spark))
+
+    val (partitionFilters, dataFilters) = eligibleFilters
+      .partition(isPredicatePartitionColumnsOnly(_, partitionColumns, spark))
 
     if (dataFilters.isEmpty) recordDeltaOperation(deltaLog, "delta.skipping.partition") {
       // When there are only partition filters we can scan allFiles
@@ -1246,7 +1248,7 @@ trait DataSkippingReaderBase
         dataFilters = ExpressionSet(Nil),
         partitionLikeDataFilters = ExpressionSet(Nil),
         rewrittenPartitionLikeDataFilters = Set.empty,
-        unusedFilters = ExpressionSet(subqueryFilters),
+        unusedFilters = ExpressionSet(ineligibleFilters),
         scanDurationMs = System.currentTimeMillis() - startTime,
         dataSkippingType =
           getCorrectDataSkippingType(DeltaDataSkippingType.partitionFilteringOnlyV1)
@@ -1323,7 +1325,7 @@ trait DataSkippingReaderBase
       dataFilters = ExpressionSet(skippingFilters.map(_._1)),
       partitionLikeDataFilters = ExpressionSet(partitionLikeFilters.map(_._1)),
      rewrittenPartitionLikeDataFilters = partitionLikeFilters.map(_._2.expr.expr).toSet,
-      unusedFilters = ExpressionSet(unusedFilters.map(_._1) ++ subqueryFilters),
+      unusedFilters = ExpressionSet(unusedFilters.map(_._1) ++ ineligibleFilters),
       scanDurationMs = System.currentTimeMillis() - startTime,
       dataSkippingType = getCorrectDataSkippingType(dataSkippingType)
     )
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala b/spark/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala
index 76552115bfe..341cecc3ec9 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala
@@ -114,22 +114,20 @@ trait PrepareDeltaScanBase extends Rule[LogicalPlan]
       limitOpt: Option[Int],
       filters: Seq[Expression],
       delta: LogicalRelation): DeltaScan = {
-    // Remove non-deterministic filters (e.g., rand() < 0.25) to prevent incorrect file pruning.
-    val deterministicFilters = filters.filter(_.deterministic)
     withStatusCode("DELTA", "Filtering files for query") {
       if (limitOpt.nonEmpty) {
         // If we trigger limit push down, the filters must be partition filters. Since
         // there are no data filters, we don't need to apply Generated Columns
         // optimization. See `DeltaTableScan` for more details.
-        return scanGenerator.filesForScan(limitOpt.get, deterministicFilters)
+        return scanGenerator.filesForScan(limitOpt.get, filters)
       }
       val filtersForScan =
         if (!GeneratedColumn.partitionFilterOptimizationEnabled(spark)) {
-          deterministicFilters
+          filters
         } else {
           val generatedPartitionFilters = GeneratedColumn.generatePartitionFilters(
-            spark, scanGenerator.snapshotToScan, deterministicFilters, delta)
-          deterministicFilters ++ generatedPartitionFilters
+            spark, scanGenerator.snapshotToScan, filters, delta)
+          filters ++ generatedPartitionFilters
         }
       scanGenerator.filesForScan(filtersForScan)
     }
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala b/spark/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala
index 6ec60f2b303..43517728d00 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -1929,6 +1930,71 @@ trait DataSkippingDeltaTestsBase extends DeltaExcludedBySparkVersionTestMixinShi
     }
   }
 
+  test("File skipping with non-deterministic filters") {
+    withTable("tbl") {
+      // Create the table.
+      val df = spark.range(100).toDF()
+      df.write.mode("overwrite").format("delta").saveAsTable("tbl")
+
+      // Append 9 times to the table.
+      for (i <- 1 to 9) {
+        val df = spark.range(i * 100, (i + 1) * 100).toDF()
+        df.write.mode("append").format("delta").insertInto("tbl")
+      }
+
+      val query = "SELECT count(*) FROM tbl WHERE rand(0) < 0.25"
+      val result = sql(query).collect().head.getLong(0)
+      assert(result > 150, s"Expected around 250 rows (~0.25 * 1000), got: $result")
+
+      val predicates = sql(query).queryExecution.optimizedPlan.collect {
+        case Filter(condition, _) => condition
+      }.flatMap(splitConjunctivePredicates)
+      val scanResult = DeltaLog.forTable(spark, TableIdentifier("tbl"))
+        .update().filesForScan(predicates)
+      assert(scanResult.unusedFilters.nonEmpty)
+    }
+  }
+
+  test("File skipping with non-deterministic filters on partitioned tables") {
+    withTable("tbl_partitioned") {
+      import org.apache.spark.sql.functions.col
+
+      // Create initial DataFrame and add a partition column.
+      val df = spark.range(100).toDF().withColumn("p", col("id") % 10)
+      df.write
+        .mode("overwrite")
+        .format("delta")
+        .partitionBy("p")
+        .saveAsTable("tbl_partitioned")
+
+      // Append 9 more times to the table.
+      for (i <- 1 to 9) {
+        val newDF = spark.range(i * 100, (i + 1) * 100).toDF().withColumn("p", col("id") % 10)
+        newDF.write.mode("append").format("delta").insertInto("tbl_partitioned")
+      }
+
+      // Run query with a nondeterministic filter.
+      val query = "SELECT count(*) FROM tbl_partitioned WHERE rand(0) < 0.25"
+      val result = sql(query).collect().head.getLong(0)
+      // Assert that the row count is as expected (e.g., roughly 25% of rows).
+      assert(result > 150, s"Expected a reasonable number of rows, got: $result")
+
+      val predicates = sql(query).queryExecution.optimizedPlan.collect {
+        case Filter(condition, _) => condition
+      }.flatMap(splitConjunctivePredicates)
+      val scanResult = DeltaLog.forTable(spark, TableIdentifier("tbl_partitioned"))
+        .update().filesForScan(predicates)
+      assert(scanResult.unusedFilters.nonEmpty)
+
+      // Assert that entries are fetched from all 10 partitions
+      val distinctPartitions =
+        sql("SELECT DISTINCT p FROM tbl_partitioned WHERE rand(0) < 0.25")
+          .collect()
+          .length
+      assert(distinctPartitions == 10)
+    }
+  }
+
   protected def parse(deltaLog: DeltaLog, predicate: String): Seq[Expression] = {
     // We produce a wrong filter in this case otherwise
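
Illustration (not part of the patch): a minimal, self-contained Scala sketch of the eligibility rule the patch adds to `DataSkippingReaderBase.filesForScan`. Predicates that contain a subquery or a non-deterministic expression, such as `rand(0) < 0.25`, are set aside and surfaced through `unusedFilters` instead of being used to prune files. The object name and the `containsSubquery` stub below are placeholders for this sketch; the real helper lives in `DeltaTableUtils`.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, GreaterThan, LessThan, Literal, Rand, SubqueryExpression}

object FilterEligibilitySketch extends App {
  // Stand-in for DeltaTableUtils.containsSubquery (hypothetical helper for this sketch).
  def containsSubquery(e: Expression): Boolean =
    e.find(_.isInstanceOf[SubqueryExpression]).isDefined

  // Two candidate predicates: `rand(0) < 0.25` (non-deterministic) and `id > 100` (deterministic).
  val filters: Seq[Expression] = Seq(
    LessThan(Rand(Literal(0L)), Literal(0.25)),
    GreaterThan(UnresolvedAttribute("id"), Literal(100L)))

  // The eligibility rule the patch introduces: anything containing a subquery or a
  // non-deterministic expression is kept out of data skipping and reported as unused.
  val (ineligibleFilters, eligibleFilters) = filters.partition { f =>
    containsSubquery(f) || !f.deterministic
  }

  println(s"ineligible (surfaced as unusedFilters): $ineligibleFilters")
  println(s"eligible (used for data skipping): $eligibleFilters")
}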