This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1824dd3821bf [SPARK-53406][SQL] Avoid unnecessary shuffle in join for direct passthrough shuffle id
1824dd3821bf is described below
commit 1824dd3821bf5d9218317eccac6d54e571f1afce
Author: Shujing Yang <[email protected]>
AuthorDate: Sun Oct 26 22:18:57 2025 +0800
[SPARK-53406][SQL] Avoid unnecessary shuffle in join for direct passthrough shuffle id
### What changes were proposed in this pull request?
This PR implements compatibility checking for ShufflePartitionIdPassThrough
partitioning to avoid unnecessary shuffle operations when both sides of a join
use compatible direct partition ID pass-through.
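For illustration, here is a minimal sketch of the new compatibility check, condensed from the `ShuffleSpecSuite` cases in the diff below (`$"a"` etc. come from the catalyst expression DSL; this is a sketch, not the verbatim test code):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
import org.apache.spark.sql.catalyst.plans.physical._

// Both sides pass the partition id through directly, with the same numPartitions.
val left = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
val right = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10)

// Each partition key sits at position 0 of its side's clustering keys, so the
// key positions overlap and the specs are compatible: no extra shuffle is needed.
val leftSpec = left.createShuffleSpec(ClusteredDistribution(Seq($"a", $"b")))
val rightSpec = right.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d")))
assert(leftSpec.isCompatibleWith(rightSpec))
```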
### Why are the changes needed?
Improves performance by skipping redundant shuffle exchanges when both join children already use compatible direct partition ID pass-through partitioning.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests
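For example, the new `EnsureRequirementsSuite` cases assert that no `ShuffleExchangeExec` is inserted when both join children already carry a compatible pass-through partitioning. A condensed sketch, not the verbatim test code (`exprA` and `DummySparkPlan` are the suite's existing test helpers):

```scala
val p = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None,
  DummySparkPlan(outputPartitioning = p), DummySparkPlan(outputPartitioning = p))

EnsureRequirements.apply(join) match {
  // Both children keep their original leaves; no ShuffleExchangeExec is added.
  case SortMergeJoinExec(_, _, _, _,
      SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
      SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), _) =>
  case other => fail(s"unexpected plan: $other")
}
```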
### Was this patch authored or co-authored using generative AI tooling?
Yes
Closes #52443 from shujingyang-db/shuffle-spec-direct-partition.
Lead-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Shujing Yang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/plans/physical/partitioning.scala | 50 +++++-
.../spark/sql/catalyst/ShuffleSpecSuite.scala | 63 +++++++
.../execution/exchange/EnsureRequirements.scala | 33 +++-
.../exchange/EnsureRequirementsSuite.scala | 186 +++++++++++++++++++++
4 files changed, 327 insertions(+), 5 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f855483ea3c3..1cbb49c7a1f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -638,7 +638,10 @@ case class ShufflePartitionIdPassThrough(
expr: DirectShufflePartitionID,
numPartitions: Int) extends Expression with Partitioning with Unevaluable {
- // TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
+ override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
+ ShufflePartitionIdPassThroughSpec(this, distribution)
+ }
+
def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))
def expressions: Seq[Expression] = expr :: Nil
@@ -966,6 +969,51 @@ object KeyGroupedShuffleSpec {
}
}
+case class ShufflePartitionIdPassThroughSpec(
+ partitioning: ShufflePartitionIdPassThrough,
+ distribution: ClusteredDistribution) extends ShuffleSpec {
+
+ /**
+  * The set of positions in the distribution's clustering keys that match the single
+  * partitioning expression. Similar to HashShuffleSpec, this is used to check whether
+  * both sides keep their partition key at an overlapping position.
+  */
+ lazy val keyPositions: mutable.BitSet = {
+ val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+ distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+ distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
+ }
+ distKeyToPos.getOrElse(partitioning.expr.child.canonicalized, mutable.BitSet.empty)
+ }
+
+ override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+ case SinglePartitionShuffleSpec =>
+ partitioning.numPartitions == 1
+ case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec(
+ otherPartitioning, otherDistribution) =>
+ // As ShufflePartitionIdPassThrough only allows a single expression
+ // as the partitioning expression, we check compatibility as follows:
+ // 1. Same number of clustering expressions
+ // 2. Same number of partitions
+ // 3. Each partitioning expression from both sides has overlapping positions in their
+ //    corresponding distributions.
+ distribution.clustering.length == otherDistribution.clustering.length &&
+   partitioning.numPartitions == otherPartitioning.numPartitions && {
+ val otherKeyPositions = otherPassThroughSpec.keyPositions
+ keyPositions.intersect(otherKeyPositions).nonEmpty
+ }
+ case ShuffleSpecCollection(specs) =>
+ specs.exists(isCompatibleWith)
+ case _ =>
+ false
+ }
+
+ // We don't support creating partitioning for ShufflePartitionIdPassThrough.
+ override def canCreatePartitioning: Boolean = false
+
+ override def numPartitions: Int = partitioning.numPartitions
+}
+
case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
override def isCompatibleWith(other: ShuffleSpec): Boolean = {
specs.exists(_.isCompatibleWith(other))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index fc5d39fd9c2b..85d285aa76c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -19,11 +19,15 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.{SparkFunSuite, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.internal.SQLConf
class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
+ private val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
+ private val passThrough_b_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10)
+ private val passThrough_c_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10)
protected def checkCompatible(
left: ShuffleSpec,
right: ShuffleSpec,
@@ -479,4 +483,63 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
"methodName" -> "createPartitioning$",
"className" -> "org.apache.spark.sql.catalyst.plans.physical.ShuffleSpec"))
}
+
+ test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") {
+ val ab = ClusteredDistribution(Seq($"a", $"b"))
+ val cd = ClusteredDistribution(Seq($"c", $"d"))
+
+ // Matching key positions and partition counts should be compatible
+ checkCompatible(
+ passThrough_a_10.createShuffleSpec(ab),
+ passThrough_c_10.createShuffleSpec(cd),
+ expected = true
+ )
+
+ // Different number of partitions should be incompatible
+ checkCompatible(
+ passThrough_a_10.createShuffleSpec(ab),
+ ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5).createShuffleSpec(cd),
+ expected = false
+ )
+
+ // Mismatched key positions should be incompatible
+ checkCompatible(
+ passThrough_b_10.createShuffleSpec(ab),
+ passThrough_c_10.createShuffleSpec(cd),
+ expected = false
+ )
+
+ // Partition key absent from the clustering keys should be incompatible
+ checkCompatible(
+ passThrough_a_10.createShuffleSpec(ClusteredDistribution(Seq($"e", $"b"))),
+ passThrough_c_10.createShuffleSpec(ab),
+ expected = false
+ )
+ }
+
+ test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") {
+ val ab = ClusteredDistribution(Seq($"a", $"b"))
+ val cd = ClusteredDistribution(Seq($"c", $"d"))
+
+ // Compatibility with SinglePartitionShuffleSpec when numPartitions is 1
+ checkCompatible(
+ ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1).createShuffleSpec(ab),
+ SinglePartitionShuffleSpec,
+ expected = true
+ )
+
+ // Incompatible with SinglePartitionShuffleSpec when numPartitions > 1
+ checkCompatible(
+ passThrough_a_10.createShuffleSpec(ab),
+ SinglePartitionShuffleSpec,
+ expected = false
+ )
+
+ // Incompatible with HashShuffleSpec
+ checkCompatible(
+ passThrough_a_10.createShuffleSpec(ab),
+ HashShuffleSpec(HashPartitioning(Seq($"c"), 10), cd),
+ expected = false
+ )
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index a0fc4b65fdbf..b97d765afcf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -165,22 +165,30 @@ case class EnsureRequirements(
// Check if the following conditions are satisfied:
// 1. There are exactly two children (e.g., join). Note that Spark doesn't support
//    multi-way join at the moment, so this check should be sufficient.
- // 2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other
+ // 2. All children have compatible key group partitioning or
+ //    compatible shuffle partition id pass through partitioning
// If both are true, skip shuffle.
- val isKeyGroupCompatible = parent.isDefined &&
+ val areChildrenCompatible = parent.isDefined &&
children.length == 2 && childrenIndexes.length == 2 && {
val left = children.head
val right = children(1)
+
+ // key group compatibility check
val newChildren = checkKeyGroupCompatible(
parent.get, left, right, requiredChildDistributions)
if (newChildren.isDefined) {
children = newChildren.get
+ true
+ } else {
+ // If key group check fails, check ShufflePartitionIdPassThrough compatibility
+ checkShufflePartitionIdPassThroughCompatible(
+ left, right, requiredChildDistributions)
}
- newChildren.isDefined
}
children = children.zip(requiredChildDistributions).zipWithIndex.map {
- case ((child, _), idx) if isKeyGroupCompatible || !childrenIndexes.contains(idx) =>
+ case ((child, _), idx) if areChildrenCompatible ||
+ !childrenIndexes.contains(idx) =>
child
case ((child, dist), idx) =>
if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
@@ -600,6 +608,23 @@ case class EnsureRequirements(
if (isCompatible) Some(Seq(newLeft, newRight)) else None
}
+ private def checkShufflePartitionIdPassThroughCompatible(
+ left: SparkPlan,
+ right: SparkPlan,
+ requiredChildDistribution: Seq[Distribution]): Boolean = {
+ (left.outputPartitioning, right.outputPartitioning) match {
+ case (p1: ShufflePartitionIdPassThrough, p2: ShufflePartitionIdPassThrough) =>
+ assert(requiredChildDistribution.length == 2)
+ val leftSpec = p1.createShuffleSpec(
+   requiredChildDistribution.head.asInstanceOf[ClusteredDistribution])
+ val rightSpec = p2.createShuffleSpec(
+   requiredChildDistribution(1).asInstanceOf[ClusteredDistribution])
+ leftSpec.isCompatibleWith(rightSpec)
+ case _ =>
+ false
+ }
+ }
+
// Similar to `OptimizeSkewedJoin.canSplitRightSide`
private def canReplicateLeftSide(joinType: JoinType): Boolean = {
joinType == Inner || joinType == Cross || joinType == RightOuter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 3b0bb088a107..b94ca4673641 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.exchange
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.optimizer.BuildRight
import org.apache.spark.sql.catalyst.plans.Inner
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends SharedSparkSession {
TransformExpression(BucketFunction, expr, Some(numBuckets))
}
+ test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+ val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA))
+ assert(rightKeys === Seq(exprA))
+ case other => fail(s"We don't expect shuffle on either side, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - different partitions") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Different number of partitions - should add shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+ // Both sides should be shuffled to default partitions
+ assert(p1.numPartitions == 10)
+ assert(p2.numPartitions == 10)
+ assert(p1.expressions == Seq(exprA))
+ assert(p2.expressions == Seq(exprB))
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - key position mismatch") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Key position mismatch - should add shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+ // The partition keys appear at different positions in the join keys
+ val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None,
+   leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
+ // Both sides shuffled due to key mismatch
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // ShufflePartitionIdPassThrough vs HashPartitioning - always adds shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, _: DummySparkPlan, _), _) =>
+ // Left side shuffled, right side kept as-is
+ case other => fail(s"Expected shuffle on the left side, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+ // Even when compatible (numPartitions=1), shuffles added due to canCreatePartitioning=false
+ val leftPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
+ val rightPlan = DummySparkPlan(outputPartitioning = SinglePartition)
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
+ // Both sides shuffled due to canCreatePartitioning = false
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+
+ test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+ val passThrough_b_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)
+
+ // Both partitioned by exprA, joined on (exprA, exprB)
+ // Should be compatible because exprA positions overlap
+ val leftPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val rightPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val joinA = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+   leftPlanA, rightPlanA)
+
+ EnsureRequirements.apply(joinA) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA, exprB))
+ assert(rightKeys === Seq(exprA, exprB))
+ case other => fail(s"We don't expect shuffle on either side with multiple " +
+   s"clustering keys, but got: $other")
+ }
+
+ // Both sides partitioned by exprB and join on (exprA, exprB)
+ // Should be compatible because partition key exprB matches at position 1 in join keys
+ val leftPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5)
+ val rightPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5)
+ val joinB = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+   leftPlanB, rightPlanB)
+
+ EnsureRequirements.apply(joinB) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ // No shuffles because exprB (partition key) appears at position 1 in join keys
+ assert(leftKeys === Seq(exprA, exprB))
+ assert(rightKeys === Seq(exprA, exprB))
+ case other => fail(s"Expected no shuffles due to position overlap at position 1, " +
+   s"but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough - incompatible when partition key not in join keys") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Partitioned by exprA and exprB respectively, but joining on completely different keys
+ // Should require shuffles because partition keys don't match join keys
+ val leftPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+ val join = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+ // Both sides should be shuffled because the partition keys are not among the join keys
+ assert(p1.numPartitions == 10)
+ assert(p2.numPartitions == 10)
+ assert(p1.expressions == Seq(exprC))
+ assert(p2.expressions == Seq(exprD))
+ case other => fail(s"Expected shuffles on both sides due to key mismatch, but got: $other")
+ }
+ }
+ }
+
def years(expr: Expression): TransformExpression = {
TransformExpression(YearsFunction, Seq(expr))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]