This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 349b261e72d3 [SPARK-56868][SQL] Extract V2 runtime-filter + partition
planning into a shared helper
349b261e72d3 is described below
commit 349b261e72d396a237dc56c57b567505055bc3c0
Author: Vitalii Li <[email protected]>
AuthorDate: Fri May 15 16:38:10 2026 -0700
[SPARK-56868][SQL] Extract V2 runtime-filter + partition planning into a
shared helper
### What changes were proposed in this pull request?
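Lift the body of `BatchScanExec.filteredPartitions` (runtime filter
pushdown, re-planning, and `KeyedPartitioning` validation + `None`-padding)
into a new `PushDownUtils.replanWithRuntimeFilters` helper so the logic can be
reused by alternative DataSourceV2 physical scan operators. Also add a
`getPartitionPredicateSchema(transforms, output)` overload for callers that
have the partition transforms but not the full `Table`.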
### Why are the changes needed?
The runtime-filter pushdown pipeline in `BatchScanExec.filteredPartitions`
contains non-trivial logic: V2 predicate translation (DPP + scalar subqueries
via SPARK-56467), iterative `PartitionPredicate` pushdown (SPARK-55596),
`KeyedPartitioning` validation, and `None`-padding to preserve SPJ key
alignment. Extracting it into a shared helper lets alternative scan operators
reuse it instead of duplicating it.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
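Existing test coverage. The only behavioral change is an explicit
`HasPartitionKey` check on the unfiltered partitions when the output
partitioning is `KeyedPartitioning`: it now raises a `SparkException` where the
previous unchecked cast would have thrown a `ClassCastException`.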
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude (Anthropic), via Claude Code
Closes #55887 from vitaliili-db/spark-unify-v2-runtime-filter-helper.
Authored-by: Vitalii Li <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 6a855fdce5d2da770646b937c498bd2523e021c8)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../execution/datasources/v2/BatchScanExec.scala | 71 ++-----------
.../execution/datasources/v2/PushDownUtils.scala | 117 ++++++++++++++++++++-
2 files changed, 123 insertions(+), 65 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index e9a18833ed9a..ea25f3b1c85f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2
import java.util.Objects
-import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning,
SinglePartition}
-import org.apache.spark.sql.catalyst.util.{truncatedString,
InternalRowComparableWrapper}
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read._
import org.apache.spark.util.ArrayImplicits._
@@ -60,64 +59,14 @@ case class BatchScanExec(
batch.planInputPartitions().toImmutableArraySeq
// Visible for testing
- @transient private[sql] lazy val filteredPartitions:
Seq[Option[InputPartition]] = {
- val originalPartitioning = outputPartitioning
-
- val filtered = PushDownUtils.pushRuntimeFilters(scan, runtimeFilters,
table, output)
- if (filtered) {
- // call toBatch again to get filtered partitions
- val newPartitions = scan.toBatch.planInputPartitions()
-
- originalPartitioning match {
- case k: KeyedPartitioning =>
- if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
- throw new SparkException("Data source must have preserved the
original partitioning " +
- "during runtime filtering: not all partitions implement
HasPartitionKey after " +
- "filtering")
- }
-
- val inputMap =
k.partitionKeys.groupBy(identity).view.mapValues(_.size)
- val comparableKeyWrapperFactory = InternalRowComparableWrapper
- .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
- val filteredMap = newPartitions.groupBy(
- p =>
comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
- )
-
- if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
- throw new SparkException("During runtime filtering, data source
must not report new " +
- "partition keys that are not present in the original
partitioning.")
- }
-
- inputMap.toSeq
- .sortBy(_._1)(k.keyOrdering)
- .flatMap { case (key, size) =>
- // We require the new number of partitions to be equal or less
than the old number of
- // partitions for a given key. In the case of less than, empty
partitions are added.
- val fps = filteredMap.getOrElse(key, Array.empty)
-
- if (fps.size > size) {
- throw new SparkException("During runtime filtering, data
source must not report " +
- s"new partitions for a given key. Before: $size partitions.
" +
- s"After: ${fps.size} partitions")
- }
-
- fps.map(Some).padTo(size, None)
- }
-
- case _ =>
- // no validation is needed as the data source did not report any
specific partitioning
- newPartitions.toSeq.map(Some)
- }
-
- } else {
- (originalPartitioning match {
- case k: KeyedPartitioning =>
-
inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
-
- case _ => inputPartitions
- }).map(Some)
- }
- }
+ @transient private[sql] lazy val filteredPartitions:
Seq[Option[InputPartition]] =
+ PushDownUtils.replanWithRuntimeFilters(
+ scan,
+ runtimeFilters,
+ table,
+ output,
+ outputPartitioning,
+ inputPartitions)
override lazy val readerFactory: PartitionReaderFactory =
batch.createReaderFactory()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index e31e81fc1fa9..dc6de6f29af9 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
+import org.apache.spark.SparkException
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
AttributeSet, DynamicPruning, DynamicPruningExpression, Expression,
ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning,
SubqueryExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.SampleMethod
+import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning,
Partitioning}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils,
InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.expressions.{IdentityTransform,
SortOrder}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform,
SortOrder, Transform}
import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{SampleMethod => SampleMethodV2,
Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit,
SupportsPushDownOffset, SupportsPushDownRequiredColumns,
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters,
SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition,
SampleMethod => SampleMethodV2, Scan, ScanBuilder, SupportsPushDownFilters,
SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns,
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters,
SupportsRuntimeV2Filtering}
import org.apache.spark.sql.execution.{ScalarSubquery => ExecScalarSubquery}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy,
DataSourceUtils}
import org.apache.spark.sql.internal.SQLConf
@@ -187,6 +189,101 @@ object PushDownUtils extends Logging {
}
}
+ /**
+ * Pushes runtime filters into `scan` and re-plans its input partitions. For
scans whose
+ * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates
that the data source
+ * preserved the original partitioning and pads with `None` to preserve key
alignment with the
+ * pre-filter partition set.
+ *
+ * Must be called at execute time: runtime filters carry
[[DynamicPruningExpression]] and
+ * scalar-subquery references whose values are only resolved after their
broadcast/subquery
+ * side completes. The mutating [[pushRuntimeFilters]] call must run at most
once per scan
+ * instance; callers are responsible for caching the result.
+ *
+ * Precondition: when `outputPartitioning` is a [[KeyedPartitioning]], every
element of
+ * `originalPartitions` (and every partition re-planned by the data source)
must implement
+ * [[HasPartitionKey]].
+ *
+ * @param scan the V2 scan to push filters into
+ * @param runtimeFilters runtime filters to translate and push
+ * @param table the table backing the scan, used to derive the
partition-predicate
+ * schema for iterative [[PartitionPredicate]]
pushdown
+ * @param output scan output attributes
+ * @param outputPartitioning Spark-side output partitioning (used for SPJ
validation)
+ * @param originalPartitions unfiltered partitions, consulted only when no
runtime filters fire
+   * @return `Some(part)` per surviving partition; for a [[KeyedPartitioning]], one entry per
+   *         element of `partitionKeys` (in key order), padding entirely pruned splits with `None`
+ */
+ def replanWithRuntimeFilters(
+ scan: Scan,
+ runtimeFilters: Seq[Expression],
+ table: Table,
+ output: Seq[AttributeReference],
+ outputPartitioning: Partitioning,
+ originalPartitions: => Seq[InputPartition]): Seq[Option[InputPartition]]
= {
+ val filtered = pushRuntimeFilters(scan, runtimeFilters, table, output)
+ if (filtered) {
+ // call toBatch again to get filtered partitions
+ val newPartitions = scan.toBatch.planInputPartitions()
+
+ outputPartitioning match {
+ case k: KeyedPartitioning =>
+ if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+ throw new SparkException("Data source must have preserved the
original partitioning " +
+ "during runtime filtering: not all partitions implement
HasPartitionKey after " +
+ "filtering")
+ }
+
+ val inputMap =
k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+ val comparableKeyWrapperFactory = InternalRowComparableWrapper
+ .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+ val filteredMap = newPartitions.groupBy(
+ p =>
comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
+ )
+
+ if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+ throw new SparkException("During runtime filtering, data source
must not report new " +
+ "partition keys that are not present in the original
partitioning.")
+ }
+
+ // Pad the post-filter partitions with `None` per original key so
SPJ key alignment with
+ // the other side of the join is preserved when splits are entirely
pruned.
+ inputMap.toSeq
+ .sortBy(_._1)(k.keyOrdering)
+ .flatMap { case (key, size) =>
+ // We require the new number of partitions to be equal or less
than the old number of
+ // partitions for a given key. In the case of less than, empty
partitions are added.
+ val fps = filteredMap.getOrElse(key, Array.empty)
+
+ if (fps.size > size) {
+ throw new SparkException("During runtime filtering, data
source must not report " +
+ s"new partitions for a given key. Before: $size partitions.
" +
+ s"After: ${fps.size} partitions")
+ }
+
+ fps.map(Some).padTo(size, None)
+ }
+
+ case _ =>
+ // no validation is needed as the data source did not report any
specific partitioning
+ newPartitions.toSeq.map(Some)
+ }
+
+ } else {
+ val parts = originalPartitions
+ (outputPartitioning match {
+ case k: KeyedPartitioning =>
+ if (parts.exists(!_.isInstanceOf[HasPartitionKey])) {
+ throw new SparkException("Original partitions must implement
HasPartitionKey when " +
+ "outputPartitioning is KeyedPartitioning.")
+ }
+
parts.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
+
+ case _ => parts
+ }).map(Some)
+ }
+ }
+
/**
* Returns a Seq of [[PartitionPredicateField]] representing partition
transform expression types,
* if schema is supported for [[PartitionPredicate]] push down. None if not
supported.
@@ -202,7 +299,19 @@ object PushDownUtils extends Logging {
*/
def getPartitionPredicateSchema(table: Table, output:
Seq[AttributeReference])
: Option[Seq[PartitionPredicateField]] = {
- val transforms = table.partitioning
+ getPartitionPredicateSchema(table.partitioning, output)
+ }
+
+ /**
+ * Returns a Seq of [[PartitionPredicateField]] representing partition
transform expression types,
+ * if schema is supported for [[PartitionPredicate]] push down. None if not
supported.
+ *
+ * Use this overload when the caller has access to the partition transforms
but not the
+ * full [[Table]].
+ */
+ def getPartitionPredicateSchema(
+ transforms: Array[Transform],
+ output: Seq[AttributeReference]): Option[Seq[PartitionPredicateField]] =
{
if (transforms.isEmpty) {
None
} else {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]