Parallel calls
sumeet-db committed Feb 11, 2025
1 parent cd67546 commit 6c2ea22
Showing 3 changed files with 80 additions and 14 deletions.
@@ -1222,14 +1222,16 @@ trait DataSkippingReaderBase
import DeltaTableUtils._
val partitionColumns = metadata.partitionColumns

// For data skipping, avoid using the filters that involve subqueries.

val (subqueryFilters, flatFilters) = filters.partition {
case f => containsSubquery(f)
// For data skipping, avoid using the filters that either:
// 1. involve subqueries.
// 2. are non-deterministic.
var (ineligibleFilters, eligibleFilters) = filters.partition {
case f => containsSubquery(f) || !f.deterministic
}

val (partitionFilters, dataFilters) = flatFilters
.partition(isPredicatePartitionColumnsOnly(_, partitionColumns, spark))

val (partitionFilters, dataFilters) = eligibleFilters
.partition(isPredicatePartitionColumnsOnly(_, partitionColumns, spark))

if (dataFilters.isEmpty) recordDeltaOperation(deltaLog, "delta.skipping.partition") {
// When there are only partition filters we can scan allFiles
@@ -1246,7 +1248,7 @@ trait DataSkippingReaderBase
dataFilters = ExpressionSet(Nil),
partitionLikeDataFilters = ExpressionSet(Nil),
rewrittenPartitionLikeDataFilters = Set.empty,
unusedFilters = ExpressionSet(subqueryFilters),
unusedFilters = ExpressionSet(ineligibleFilters),
scanDurationMs = System.currentTimeMillis() - startTime,
dataSkippingType =
getCorrectDataSkippingType(DeltaDataSkippingType.partitionFilteringOnlyV1)
@@ -1323,7 +1325,7 @@ trait DataSkippingReaderBase
dataFilters = ExpressionSet(skippingFilters.map(_._1)),
partitionLikeDataFilters = ExpressionSet(partitionLikeFilters.map(_._1)),
rewrittenPartitionLikeDataFilters = partitionLikeFilters.map(_._2.expr.expr).toSet,
unusedFilters = ExpressionSet(unusedFilters.map(_._1) ++ subqueryFilters),
unusedFilters = ExpressionSet(unusedFilters.map(_._1) ++ ineligibleFilters),
scanDurationMs = System.currentTimeMillis() - startTime,
dataSkippingType = getCorrectDataSkippingType(dataSkippingType)
)
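For context on the new check, here is a minimal standalone sketch (not part of this commit) of the Catalyst property the partition above keys on: Expression.deterministic is false for expressions such as rand(0) < 0.25, so such filters land in ineligibleFilters and are only reported back as unusedFilters instead of driving file skipping. The object and variable names below are illustrative assumptions.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, LessThan, Literal, Rand}
import org.apache.spark.sql.types.LongType

object DeterminismSketch {
  def main(args: Array[String]): Unit = {
    val id = AttributeReference("id", LongType)()
    val filters: Seq[Expression] = Seq(
      LessThan(Rand(Literal(0L)), Literal(0.25)), // non-deterministic: rand(0) < 0.25
      LessThan(id, Literal(100L))                 // deterministic: id < 100
    )
    // Mirrors the new partition above, minus the subquery check.
    val (ineligible, eligible) = filters.partition(f => !f.deterministic)
    println(s"ineligible = $ineligible") // the rand() predicate
    println(s"eligible   = $eligible")   // the id < 100 predicate
  }
}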
@@ -114,22 +114,20 @@ trait PrepareDeltaScanBase extends Rule[LogicalPlan]
limitOpt: Option[Int],
filters: Seq[Expression],
delta: LogicalRelation): DeltaScan = {
// Remove non-deterministic filters (e.g., rand() < 0.25) to prevent incorrect file pruning.
val deterministicFilters = filters.filter(_.deterministic)
withStatusCode("DELTA", "Filtering files for query") {
if (limitOpt.nonEmpty) {
// If we trigger limit push down, the filters must be partition filters. Since
// there are no data filters, we don't need to apply Generated Columns
// optimization. See `DeltaTableScan` for more details.
return scanGenerator.filesForScan(limitOpt.get, deterministicFilters)
return scanGenerator.filesForScan(limitOpt.get, filters)
}
val filtersForScan =
if (!GeneratedColumn.partitionFilterOptimizationEnabled(spark)) {
deterministicFilters
filters
} else {
val generatedPartitionFilters = GeneratedColumn.generatePartitionFilters(
spark, scanGenerator.snapshotToScan, deterministicFilters, delta)
deterministicFilters ++ generatedPartitionFilters
spark, scanGenerator.snapshotToScan, filters, delta)
filters ++ generatedPartitionFilters
}
scanGenerator.filesForScan(filtersForScan)
}
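At the query level, the intended behavior is unchanged: with the determinism check now living in DataSkippingReaderBase, PrepareDeltaScanBase simply forwards all filters to filesForScan, and a predicate like rand(0) < 0.25 ends up unused rather than pruning files. A hedged, end-to-end illustration under assumptions (local session, an arbitrary temporary path, standard Delta Lake session configuration):

import org.apache.spark.sql.SparkSession

object RandFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("rand-filter-sketch")
      .master("local[*]")
      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
      .config("spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog")
      .getOrCreate()

    val path = "/tmp/rand_filter_sketch" // illustrative location
    spark.range(1000).toDF("id").write.format("delta").mode("overwrite").save(path)

    // Roughly a quarter of the rows match, and no file may be skipped based on
    // the predicate, because its value is unrelated to the data in each file.
    val matched = spark.read.format("delta").load(path).where("rand(0) < 0.25").count()
    println(s"matched $matched of 1000 rows")
    spark.stop()
  }
}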
@@ -36,6 +36,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -1929,6 +1930,71 @@ trait DataSkippingDeltaTestsBase extends DeltaExcludedBySparkVersionTestMixinShi
}
}

test("File skipping with non-deterministic filters") {
withTable("tbl") {
// Create the table.
val df = spark.range(100).toDF()
df.write.mode("overwrite").format("delta").saveAsTable("tbl")

// Append 9 times to the table.
for (i <- 1 to 9) {
val df = spark.range(i * 100, (i + 1) * 100).toDF()
df.write.mode("append").format("delta").insertInto("tbl")
}

val query = "SELECT count(*) FROM tbl WHERE rand(0) < 0.25"
val result = sql(query).collect().head.getLong(0)
assert(result > 150, s"Expected around 250 rows (~0.25 * 1000), got: $result")

val predicates = sql(query).queryExecution.optimizedPlan.collect {
case Filter(condition, _) => condition
}.flatMap(splitConjunctivePredicates)
val scanResult = DeltaLog.forTable(spark, TableIdentifier("tbl"))
.update().filesForScan(predicates)
assert(scanResult.unusedFilters.nonEmpty)
}
}
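A side note on how the test assembles its predicates: splitConjunctivePredicates comes from the PredicateHelper mixin imported above and flattens a top-level AND into the individual conjuncts that filesForScan receives. A minimal standalone sketch with illustrative names:

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, LessThan, Literal, PredicateHelper}
import org.apache.spark.sql.types.LongType

object SplitConjunctsSketch extends PredicateHelper {
  def main(args: Array[String]): Unit = {
    val id = AttributeReference("id", LongType)()
    val condition = And(GreaterThan(id, Literal(10L)), LessThan(id, Literal(100L)))
    // Prints the two conjuncts: (id > 10) and (id < 100).
    println(splitConjunctivePredicates(condition))
  }
}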

test("File skipping with non-deterministic filters on partitioned tables") {
withTable("tbl_partitioned") {
import org.apache.spark.sql.functions.col

// Create initial DataFrame and add a partition column.
val df = spark.range(100).toDF().withColumn("p", col("id") % 10)
df.write
.mode("overwrite")
.format("delta")
.partitionBy("p")
.saveAsTable("tbl_partitioned")

// Append 9 more times to the table.
for (i <- 1 to 9) {
val newDF = spark.range(i * 100, (i + 1) * 100).toDF().withColumn("p", col("id") % 10)
newDF.write.mode("append").format("delta").insertInto("tbl_partitioned")
}

// Run query with a nondeterministic filter.
val query = "SELECT count(*) FROM tbl_partitioned WHERE rand(0) < 0.25"
val result = sql(query).collect().head.getLong(0)
// Assert that the row count is as expected (e.g., roughly 25% of rows).
assert(result > 150, s"Expected a reasonable number of rows, got: $result")

val predicates = sql(query).queryExecution.optimizedPlan.collect {
case Filter(condition, _) => condition
}.flatMap(splitConjunctivePredicates)
val scanResult = DeltaLog.forTable(spark, TableIdentifier("tbl_partitioned"))
.update().filesForScan(predicates)
assert(scanResult.unusedFilters.nonEmpty)

// Assert that entries are fetched from all 10 partitions
val distinctPartitions =
sql("SELECT DISTINCT p FROM tbl_partitioned WHERE rand(0) < 0.25")
.collect()
.length
assert(distinctPartitions == 10)
}
}

protected def parse(deltaLog: DeltaLog, predicate: String): Seq[Expression] = {

// We produce a wrong filter in this case otherwise
