This is an automated email from the ASF dual-hosted git repository.

peter-toth pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 907fd75441ea [SPARK-56877][SQL] Enforce `KeyedPartitioning` invariant in `PartitioningCollection`
907fd75441ea is described below

commit 907fd75441eaed1904e8a6a954fe47dc7935b2b5
Author: Peter Toth <[email protected]>
AuthorDate: Mon May 18 11:08:46 2026 +0200

    [SPARK-56877][SQL] Enforce `KeyedPartitioning` invariant in `PartitioningCollection`
    
    ### What changes were proposed in this pull request?
    
    - Add a `require` in `PartitioningCollection` that all `KeyedPartitioning`s 
reachable from the collection share the same `partitionKeys` reference (`eq`) 
and have matching expression arity. The check walks the partitioning tree via 
`foreach` so nested collections are covered.
    - Add a smart factory `PartitioningCollection.fromPartitionings` that 
interns `partitionKeys` references across `KeyedPartitioning`s. Use this at 
sites that combine independently-computed partitionings (joins) where keys are 
structurally equal but not reference-equal. The factory uses manual recursion 
rather than `transformWithPruning` because `KeyedPartitioning.equals` compares 
`partitionKeys` element-wise, which would make `transformWithPruning` discard 
the rule's replacement as str [...]
    - In `GroupPartitionsExec.outputPartitioning`, hoist `val partitionKeys = 
groupedPartitions.map(_._1)` above the `transform` so every rebuilt 
`KeyedPartitioning` shares the same `partitionKeys` reference. Drop the ad-hoc 
consistency assert (now enforced by `PartitioningCollection`).
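    - Switch `ShuffledJoin` and `StreamingSymmetricHashJoinExec` to 
`PartitioningCollection.fromPartitionings` for their inner-join 
`outputPartitioning`. `BroadcastHashJoinExec` and 
`PartitioningPreservingUnaryExecNode` likewise build their multi-partitioning 
results through the factory.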
    - Update affected tests to construct collections via `fromPartitionings`. 
Rewrite the `SPARK-46367` arity-mismatch test in 
`ProjectedOrderingAndPartitioningSuite` since the scenario is now rejected at 
`PartitioningCollection` construction rather than inside 
`AliasAwareOutputExpression`.
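    
    The interning idea from the factory bullet can be sketched outside Spark.
    The snippet below is a minimal illustration, not the actual implementation:
    `Part`, `Keyed`, `Coll`, `pruningTransform`, `intern` and `keysOf` are
    hypothetical stand-ins, and `pruningTransform` only imitates the
    "keep the original node when the rewritten node is `==` to it" behaviour
    that the bullet attributes to `transformWithPruning`.
    
    ```scala
    object InternSketch {
      // Toy stand-ins for illustration only; these are NOT Spark's classes.
      sealed trait Part
      final case class Keyed(exprs: Seq[String], keys: Seq[Int]) extends Part
      final case class Coll(parts: Seq[Part]) extends Part
    
      // Equality-pruned rewrite: if the rule's result is `==` to the input node,
      // the input node is kept, so a reference-only change to `keys` is lost.
      def pruningTransform(p: Part)(rule: PartialFunction[Part, Part]): Part = {
        val after = rule.applyOrElse(p, (x: Part) => x)
        val kept = if (after == p) p else after
        kept match {
          case Coll(ps) => Coll(ps.map(pruningTransform(_)(rule)))
          case other => other
        }
      }
    
      // Manual recursion: the first Keyed seen supplies the canonical `keys`
      // reference; later ones with equal-but-distinct keys are copied onto it.
      def intern(parts: Seq[Part]): Coll = {
        var canonical: Seq[Int] = null
        def go(p: Part): Part = p match {
          case k: Keyed if canonical == null =>
            canonical = k.keys
            k
          case k: Keyed if k.keys ne canonical =>
            require(k.keys == canonical, "keys must be structurally equal")
            k.copy(keys = canonical)
          case c: Coll => Coll(c.parts.map(go))
          case other => other
        }
        Coll(parts.map(go))
      }
    
      // Collects the `keys` of every Keyed reachable from `p`, nested or not.
      def keysOf(p: Part): Seq[Seq[Int]] = p match {
        case k: Keyed => Seq(k.keys)
        case Coll(ps) => ps.flatMap(keysOf)
      }
    }
    ```
    
    With two `Keyed` values built from separate but equal non-empty lists,
    `keysOf(intern(...))` should yield references that are `eq`, whereas the
    equality-pruned rewrite should leave the second value's original `keys`
    object in place.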
    
    ### Why are the changes needed?
    
    The "all `KeyedPartitioning`s in a collection must agree on 
`partitionKeys`" invariant already existed informally -- 
`GroupPartitionsExec.outputPartitioning` had a runtime assert checking `==`, 
`AliasAwareOutputExpression.projectKeyedPartitionings` asserted matching arity, 
and various consumers relied on the invariant being upheld. Consolidating the 
check into the `PartitioningCollection` constructor makes it load-bearing: any 
future construction site that violates it fails immediatel [...]
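    
    As a rough illustration of the fail-at-construction behaviour described
    above, the sketch below puts a `require`-based check in the constructor of
    a toy collection type. `Part`, `Keyed`, `Coll` and the messages are
    hypothetical names chosen for the example, not Spark's classes, and the
    traversal is a hand-written recursion rather than `TreeNode.foreach`.
    
    ```scala
    object InvariantSketch {
      // Toy stand-ins for illustration only; these are NOT Spark's classes.
      sealed trait Part
      final case class Keyed(exprs: Seq[String], keys: Seq[Int]) extends Part
      final case class Coll(parts: Seq[Part]) extends Part {
        // Runs on every construction, so a producer that violates the invariant
        // fails here rather than later inside some consumer.
        checkInvariant()
    
        // Gathers every Keyed reachable from `p`, including nested collections.
        private def keyed(p: Part): Seq[Keyed] = p match {
          case k: Keyed => Seq(k)
          case Coll(ps) => ps.flatMap(keyed)
        }
    
        private def checkInvariant(): Unit = {
          val all = parts.flatMap(keyed)
          all.drop(1).foreach { k =>
            require(k.exprs.length == all.head.exprs.length,
              "matching expression arity required")
            // Reference equality on purpose: structurally equal is not enough.
            require(k.keys eq all.head.keys, "shared keys reference required")
          }
        }
      }
    }
    ```
    
    In this model, `Coll(Seq(Keyed(Seq("a"), ks), Keyed(Seq("b"), ks)))`
    constructs normally when both entries reuse the same `ks` object, while
    passing two separately built lists with equal contents throws
    `IllegalArgumentException` from `require`.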
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test suites (`EnsureRequirementsSuite`, 
`GroupPartitionsExecSuite`, `ProjectedOrderingAndPartitioningSuite`) updated to 
use `PartitioningCollection.fromPartitionings` where they previously 
constructed collections from independently-built `KeyedPartitioning`s. The 
`SPARK-46367` test was rewritten to assert that the invalid mixed-arity 
scenario is rejected at `PartitioningCollection` construction.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code 4.7
    
    Closes #55901 from 
peter-toth/SPARK-56877-enforce-keyedpartitioning-invariant-in-collection.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
    (cherry picked from commit c8528a73d4b7a205e44b1530e6423243be963af4)
    Signed-off-by: Peter Toth <[email protected]>
---
 .../sql/catalyst/plans/physical/partitioning.scala | 56 ++++++++++++++++++++++
 .../sql/execution/AliasAwareOutputExpression.scala | 17 ++-----
 .../datasources/v2/GroupPartitionsExec.scala       | 20 ++------
 .../execution/joins/BroadcastHashJoinExec.scala    |  2 +-
 .../spark/sql/execution/joins/ShuffledJoin.scala   |  3 +-
 .../join/StreamingSymmetricHashJoinExec.scala      |  3 +-
 .../ProjectedOrderingAndPartitioningSuite.scala    | 25 ++++------
 .../datasources/v2/GroupPartitionsExecSuite.scala  |  2 +-
 .../exchange/EnsureRequirementsSuite.scala         | 12 ++---
 9 files changed, 87 insertions(+), 53 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index cc50da1f17fd..f331cd124759 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -691,6 +691,12 @@ case class RangePartitioning(ordering: Seq[SortOrder], 
numPartitions: Int)
  * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
  * in this collection do not need to be equivalent, which is useful for
  * Outer Join operators.
+ *
+ * [[KeyedPartitioning]]s within a `PartitioningCollection` describe the same 
physical partitioning
+ * and so must share the same `partitionKeys` reference, differing only in 
their `expressions` (with
+ * matching arity). Use [[PartitioningCollection.fromPartitionings]] to build 
a collection from
+ * independently-computed partitionings (e.g. join `outputPartitioning`); it 
interns `partitionKeys`
+ * references (including across nested collections) so the invariant holds.
  */
 case class PartitioningCollection(partitionings: Seq[Partitioning])
   extends Expression with Partitioning with Unevaluable {
@@ -699,6 +705,26 @@ case class PartitioningCollection(partitionings: 
Seq[Partitioning])
     partitionings.map(_.numPartitions).distinct.length == 1,
     s"PartitioningCollection requires all of its partitionings have the same 
numPartitions.")
 
+  checkKeyedPartitioningInvariant()
+
+  private def checkKeyedPartitioningInvariant(): Unit = {
+    var first: KeyedPartitioning = null
+    foreach {
+      case k: KeyedPartitioning =>
+        if (first == null) {
+          first = k
+        } else {
+          require(k.expressions.length == first.expressions.length,
+            "All KeyedPartitionings in a PartitioningCollection must have 
matching expression " +
+              "arity")
+          require(k.partitionKeys eq first.partitionKeys,
+            "All KeyedPartitionings in a PartitioningCollection must share the 
same " +
+              "partitionKeys reference")
+        }
+      case _ =>
+    }
+  }
+
   override def children: Seq[Expression] = partitionings.collect {
     case expr: Expression => expr
   }
@@ -730,6 +756,36 @@ case class PartitioningCollection(partitionings: 
Seq[Partitioning])
     
super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection]
 }
 
+object PartitioningCollection {
+  /**
+   * Builds a [[PartitioningCollection]], unifying the `partitionKeys` 
reference across all
+   * [[KeyedPartitioning]]s (including those in nested collections). Use this 
when combining
+   * independently-computed partitionings (e.g. join `outputPartitioning`) 
where
+   * `KeyedPartitioning.partitionKeys` are structurally equal but may not be 
reference-equal.
+   *
+   * Note: this can't be implemented with `TreeNode.transform`.
+   */
+  def fromPartitionings(partitionings: Seq[Partitioning]): 
PartitioningCollection = {
+    var canonicalKeys: Seq[InternalRowComparableWrapper] = null
+    def intern(p: Partitioning): Partitioning = p match {
+      case k: KeyedPartitioning =>
+        if (canonicalKeys == null) {
+          canonicalKeys = k.partitionKeys
+          k
+        } else if (k.partitionKeys ne canonicalKeys) {
+          require(k.partitionKeys == canonicalKeys,
+            "All KeyedPartitionings in a PartitioningCollection must have 
equal partitionKeys")
+          k.copy(partitionKeys = canonicalKeys)
+        } else {
+          k
+        }
+      case pc: PartitioningCollection => new 
PartitioningCollection(pc.partitionings.map(intern))
+      case other => other
+    }
+    new PartitioningCollection(partitionings.map(intern))
+  }
+}
+
 /**
  * Represents a partitioning where rows are collected, transformed and 
broadcasted to each
  * node in the cluster.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
index 1f2b1d0a585d..b37e1b258e9b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
@@ -40,7 +40,7 @@ trait PartitioningPreservingUnaryExecNode extends 
UnaryExecNode
     (projectedKPs ++ projectedOthers).take(aliasCandidateLimit) match {
       case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
       case Seq(p) => p
-      case ps => PartitioningCollection(ps)
+      case ps => PartitioningCollection.fromPartitionings(ps)
     }
   }
 
@@ -88,22 +88,15 @@ trait PartitioningPreservingUnaryExecNode extends 
UnaryExecNode
    *
    * The resulting [[KeyedPartitioning]]s are the cross-product of the 
per-position alternatives
    * restricted to the projectable positions. All share the same 
`partitionKeys` object (projected
-   * to the same subset of positions), preserving the invariant required by 
[[GroupPartitionsExec]].
+   * to the same subset of positions), preserving the invariant required by
+   * [[PartitioningCollection]].
    */
   private def projectKeyedPartitionings(
       kps: Seq[KeyedPartitioning]): LazyList[KeyedPartitioning] = {
     if (kps.isEmpty) return LazyList.empty
+    // All input KPs share the same `partitionKeys` reference and matching 
arity by the
+    // [[PartitioningCollection]] invariant (the only producer of multi-KP 
inputs here).
     val numPositions = kps.head.expressions.length
-    // The function assumes all input KPs share the same `partitionKeys`, 
which implies matching
-    // expression arity. This invariant is asserted by [[GroupPartitionsExec]] 
and is established
-    // by the constructors of [[PartitioningCollection]] feeding this method 
(a join's
-    // `PartitioningCollection(left.outputPartitioning, 
right.outputPartitioning)` combines KPs
-    // that have been aligned by [[EnsureRequirements]] to the same join 
keys). If the invariant
-    // is ever violated upstream, fail early with a clear message instead of 
throwing an opaque
-    // `IndexOutOfBoundsException` from `kp.expressions(i)` below.
-    assert(kps.tail.forall(_.expressions.length == numPositions),
-      s"All input KeyedPartitionings must share the same expression arity, " +
-        s"but got: ${kps.map(_.expressions.length).mkString(", ")}.")
 
     val alternativesPerPosition: IndexedSeq[LazyList[Expression]] =
       if (hasAlias) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
index 264a0e954936..4d87be662293 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
@@ -67,24 +67,14 @@ case class GroupPartitionsExec(
   override def outputPartitioning: Partitioning = {
     child.outputPartitioning match {
       case p: Partitioning with Expression =>
-        // There can be multiple `KeyedPartitioning` in an output partitioning 
of a join, but they
-        // can only differ in `expressions`. `partitionKeys` must match so we 
can calculate it only
-        // once via `groupedPartitions`.
-
-        val keyedPartitionings = p.collect { case k: KeyedPartitioning => k }
-        if (keyedPartitionings.size > 1) {
-          val first = keyedPartitionings.head
-          keyedPartitionings.tail.foreach { k =>
-            assert(k.partitionKeys == first.partitionKeys,
-              "All KeyedPartitioning nodes must have identical partition keys")
-          }
-        }
-
+        // There can be multiple `KeyedPartitioning`s in an output 
partitioning of a join, but they
+        // can only differ in `expressions`; their `partitionKeys` reference 
is shared (enforced by
+        // `PartitioningCollection`), so `groupedPartitions` is computed only 
once.
+        val partitionKeys = groupedPartitions.map(_._1)
         p.transform {
           case k: KeyedPartitioning =>
             val projectedExpressions = 
joinKeyPositions.fold(k.expressions)(_.map(k.expressions))
-            KeyedPartitioning(projectedExpressions, 
groupedPartitions.map(_._1),
-              isGrouped = isGrouped)
+            KeyedPartitioning(projectedExpressions, partitionKeys, isGrouped = 
isGrouped)
         }.asInstanceOf[Partitioning]
       case o => o
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index e4f18c9144dd..2881aeac55d8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -84,7 +84,7 @@ case class BroadcastHashJoinExec private(
           // constructor prevents that.
 
           case p :: Nil => p
-          case ps => PartitioningCollection(ps)
+          case ps => PartitioningCollection.fromPartitionings(ps)
         }
       case _ => streamedPlan.outputPartitioning
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index f363156c81e5..3fb968bfea7a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -46,7 +46,8 @@ trait ShuffledJoin extends JoinCodegenSupport {
 
   override def outputPartitioning: Partitioning = joinType match {
     case _: InnerLike =>
-      PartitioningCollection(Seq(left.outputPartitioning, 
right.outputPartitioning))
+      PartitioningCollection.fromPartitionings(
+        Seq(left.outputPartitioning, right.outputPartitioning))
     case LeftOuter | LeftSingle => left.outputPartitioning
     case RightOuter => right.outputPartitioning
     case FullOuter => 
UnknownPartitioning(left.outputPartitioning.numPartitions)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index 71a7d4cf56e1..9eca04c98591 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -242,7 +242,8 @@ case class StreamingSymmetricHashJoinExec(
 
   override def outputPartitioning: Partitioning = joinType match {
     case _: InnerLike =>
-      PartitioningCollection(Seq(left.outputPartitioning, 
right.outputPartitioning))
+      PartitioningCollection.fromPartitionings(
+        Seq(left.outputPartitioning, right.outputPartitioning))
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
     case FullOuter => 
UnknownPartitioning(left.outputPartitioning.numPartitions)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
index a38570924620..a70baece7784 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
@@ -387,7 +387,7 @@ class ProjectedOrderingAndPartitioningSuite
     val y = AttributeReference("y", IntegerType)()
     val yAlias = AttributeReference("y_alias", IntegerType)()
     val keys2d = Seq(InternalRow(1, 1), InternalRow(1, 2), InternalRow(2, 1), 
InternalRow(2, 2))
-    val childPartitioning = PartitioningCollection(Seq(
+    val childPartitioning = PartitioningCollection.fromPartitionings(Seq(
       KeyedPartitioning(Seq(x, y), keys2d),
       KeyedPartitioning(Seq(x, yAlias), keys2d)))
     val child = DummyLeafExecWithPartitioning(
@@ -587,27 +587,20 @@ class ProjectedOrderingAndPartitioningSuite
     }
   }
 
-  test("SPARK-46367: mixed-arity KeyedPartitionings in input fail with a clear 
assertion") {
-    // The function assumes all input KPs share the same arity (the invariant 
asserted by
-    // `GroupPartitionsExec`). Without the assert below, indexing 
`kp.expressions(i)` for
-    // `i >= kp.expressions.length` would throw an opaque 
`IndexOutOfBoundsException`. The assert
-    // surfaces the real cause -- an upstream node violated the invariant -- 
so the bug can be
-    // fixed at the producer.
+  test("SPARK-46367: mixed-arity KeyedPartitionings rejected by 
PartitioningCollection") {
+    // PartitioningCollection enforces matching expression arity (and shared 
partitionKeys
+    // references) across all its KeyedPartitionings, so the invariant 
required by
+    // `AliasAwareOutputExpression` cannot be violated by the input.
     val x = AttributeReference("x", IntegerType)()
     val y = AttributeReference("y", IntegerType)()
     val keys2d = Seq(InternalRow(1, 1), InternalRow(2, 2))
     val keys1d = Seq(InternalRow(1), InternalRow(2))
-    val child = DummyLeafExecWithPartitioning(
-      output = Seq(x, y),
-      partitioning = PartitioningCollection(Seq(
+    val e = intercept[IllegalArgumentException] {
+      PartitioningCollection.fromPartitionings(Seq(
         KeyedPartitioning(Seq(x, y), keys2d),
-        KeyedPartitioning(Seq(x), keys1d))))
-    val project = ProjectExec(Seq(x), child)
-    val e = intercept[AssertionError] {
-      project.outputPartitioning
+        KeyedPartitioning(Seq(x), keys1d)))
     }
-    assert(e.getMessage.contains("All input KeyedPartitionings must share the 
same expression " +
-      "arity"))
+    assert(e.getMessage.contains("partitionKeys"))
   }
 }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
index 5d2adeb0c00a..51951d68cc60 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
@@ -97,7 +97,7 @@ class GroupPartitionsExecSuite extends SharedSparkSession {
     val leftKP = KeyedPartitioning(Seq(exprA), partitionKeys)
     val rightKP = KeyedPartitioning(Seq(exprB), partitionKeys)
     val child = DummySparkPlan(
-      outputPartitioning = PartitioningCollection(Seq(leftKP, rightKP)),
+      outputPartitioning = 
PartitioningCollection.fromPartitionings(Seq(leftKP, rightKP)),
       outputOrdering = Seq(SortOrder(exprA, Ascending, sameOrderExpressions = 
Seq(exprB))))
     val gpe = GroupPartitionsExec(child)
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 1e35985f5049..74b706bce34f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -821,7 +821,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
         KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, 
Seq.empty)
     )
     plan2 = new DummySparkPlanWithBatchScanChild(
-      outputPartitioning = PartitioningCollection(Seq(
+      outputPartitioning = PartitioningCollection.fromPartitionings(Seq(
         KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 
Seq.empty),
         KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 
Seq.empty))
       )
@@ -1050,7 +1050,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
 
       // With partition collections
       plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
-        PartitioningCollection(
+        PartitioningCollection.fromPartitionings(
           Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, 
leftPartValues),
             KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, 
leftPartValues))
         )
@@ -1077,13 +1077,13 @@ class EnsureRequirementsSuite extends 
SharedSparkSession {
 
       // Nested partition collections
       plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
-        PartitioningCollection(
+        PartitioningCollection.fromPartitionings(
           Seq(
-            PartitioningCollection(
+            PartitioningCollection.fromPartitionings(
               Seq(
                 KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, 
rightPartValues),
                 KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, 
rightPartValues))),
-              PartitioningCollection(
+              PartitioningCollection.fromPartitionings(
                 Seq(
                   KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: 
Nil, rightPartValues),
                   KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: 
Nil, rightPartValues)))
@@ -1539,7 +1539,7 @@ private case class DummyBothKPBinaryExec(left: SparkPlan, 
right: SparkPlan)
   override def output: Seq[Attribute] = left.output ++ right.output
   override def outputOrdering: Seq[SortOrder] = left.outputOrdering
   override def outputPartitioning: Partitioning =
-    PartitioningCollection(Seq(left.outputPartitioning, 
right.outputPartitioning))
+    PartitioningCollection.fromPartitionings(Seq(left.outputPartitioning, 
right.outputPartitioning))
   override protected def doExecute(): RDD[InternalRow] = null
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
