This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 1b36e3c4db41 [SPARK-46170][SQL][3.5] Support inject adaptive query post planner strategy rules in SparkSessionExtensions
1b36e3c4db41 is described below

commit 1b36e3c4db41df9c6ec476db31f9cbec0d60ab6f
Author: ulysses-you <[email protected]>
AuthorDate: Tue Feb 6 14:54:45 2024 +0800

    [SPARK-46170][SQL][3.5] Support inject adaptive query post planner strategy rules in SparkSessionExtensions
    
    This PR backports https://github.com/apache/spark/pull/44074 to branch-3.5 since 3.5 is an LTS version.
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new extension entry point, `queryPostPlannerStrategyRules`, to `SparkSessionExtensions`. The injected rules are applied between `plannerStrategy` and `queryStagePrepRules` in AQE, so they see the whole physical plan before exchanges are inserted.
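
    A minimal usage sketch (illustration only, not part of this patch; `MyPostPlannerRule` and `MyExtensions` are hypothetical names):

        import org.apache.spark.sql.SparkSessionExtensions
        import org.apache.spark.sql.catalyst.rules.Rule
        import org.apache.spark.sql.execution.SparkPlan

        // A rule that sees the whole physical plan before AQE inserts exchanges.
        // No-op placeholder here; a real rule would rewrite the plan.
        object MyPostPlannerRule extends Rule[SparkPlan] {
          override def apply(plan: SparkPlan): SparkPlan = plan
        }

        // Register the rule through the new entry point added by this PR.
        class MyExtensions extends (SparkSessionExtensions => Unit) {
          override def apply(extensions: SparkSessionExtensions): Unit = {
            extensions.injectQueryPostPlannerStrategyRule(_ => MyPostPlannerRule)
          }
        }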
    
    ### Why are the changes needed?
    
    3.5 is an LTS version.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it only affects developers of Spark extensions.
    
    ### How was this patch tested?
    
    Added a test in `SparkSessionExtensionSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44074 from ulysses-you/post-planner.
    
    Authored-by: ulysses-you <ulyssesyou18gmail.com>
    
    Closes #45037 from ulysses-you/SPARK-46170.
    
    Authored-by: ulysses-you <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
---
 .../apache/spark/sql/SparkSessionExtensions.scala  | 21 ++++++++++++
 .../execution/adaptive/AdaptiveRulesHolder.scala   |  5 ++-
 .../execution/adaptive/AdaptiveSparkPlanExec.scala | 14 ++++++--
 .../sql/internal/BaseSessionStateBuilder.scala     |  3 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 40 +++++++++++++++++++++-
 5 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index b7c86ab7de6b..677dba008257 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
  * <li>Customized Parser.</li>
  * <li>(External) Catalog listeners.</li>
  * <li>Columnar Rules.</li>
+ * <li>Adaptive Query Post Planner Strategy Rules.</li>
  * <li>Adaptive Query Stage Preparation Rules.</li>
  * <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
  * <li>Adaptive Query Stage Optimizer Rules.</li>
@@ -112,10 +113,13 @@ class SparkSessionExtensions {
   type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
   type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
   type ColumnarRuleBuilder = SparkSession => ColumnarRule
+  type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
   type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
   type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]
 
   private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+  private[this] val queryPostPlannerStrategyRuleBuilders =
+    mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
   private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
   private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
   private[this] val queryStageOptimizerRuleBuilders =
@@ -128,6 +132,14 @@ class SparkSessionExtensions {
     columnarRuleBuilders.map(_.apply(session)).toSeq
   }
 
+  /**
+   * Build the override rules for the query post planner strategy phase of adaptive query execution.
+   */
+  private[sql] def buildQueryPostPlannerStrategyRules(
+      session: SparkSession): Seq[Rule[SparkPlan]] = {
+    queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
+  }
+
   /**
    * Build the override rules for the query stage preparation phase of adaptive query execution.
    */
@@ -156,6 +168,15 @@ class SparkSessionExtensions {
     columnarRuleBuilders += builder
   }
 
+  /**
+   * Inject a rule that applied between `plannerStrategy` and `queryStagePrepRules`, so
+   * it can get the whole plan before injecting exchanges.
+   * Note, these rules can only be applied within AQE.
+   */
+  def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
+    queryPostPlannerStrategyRuleBuilders += builder
+  }
+
   /**
    * Inject a rule that can override the query stage preparation phase of adaptive query
    * execution.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
index 8391fe44f559..ee2cd8a4953b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
  *                              query stage
  * @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
  *                                 all children query stages are materialized
+ * @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
+ *                                      so it can get the whole plan before injecting exchanges.
  */
 class AdaptiveRulesHolder(
     val queryStagePrepRules: Seq[Rule[SparkPlan]],
     val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
-    val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
+    val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
+    val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index fa671c8faf8b..96b83a91cc73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
     optimized
   }
 
+  private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
+    applyPhysicalRules(
+      plan,
+      context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
+      Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
+    )
+  }
+
   @transient val initialPlan = context.session.withActive {
     applyPhysicalRules(
-      inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
+      applyQueryPostPlannerStrategyRules(inputPlan),
+      queryStagePreparationRules,
+      Some((planChangeLogger, "AQE Preparations")))
   }
 
   @volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
       val optimized = optimizer.execute(logicalPlan)
       val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
       val newPlan = applyPhysicalRules(
-        sparkPlan,
+        applyQueryPostPlannerStrategyRules(sparkPlan),
         preprocessingRules ++ queryStagePreparationRules,
         Some((planChangeLogger, "AQE Replanning")))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 5543b409d170..3a07dbf5480d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -318,7 +318,8 @@ abstract class BaseSessionStateBuilder(
     new AdaptiveRulesHolder(
       extensions.buildQueryStagePrepRules(session),
       extensions.buildRuntimeOptimizerRules(session),
-      extensions.buildQueryStageOptimizerRules(session))
+      extensions.buildQueryStageOptimizerRules(session),
+      extensions.buildQueryPostPlannerStrategyRules(session))
   }
 
   protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 21518085ca4c..8b4ac474f875 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       }
     }
   }
+
+  test("SPARK-46170: Support inject adaptive query post planner strategy rules 
in " +
+    "SparkSessionExtensions") {
+    val extensions = create { extensions =>
+      extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
+        .contains(MyQueryPostPlannerStrategyRule))
+      import session.implicits._
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
+        val input = Seq(10, 20, 10).toDF("c1")
+        val df = input.groupBy("c1").count()
+        df.collect()
+        assert(df.rdd.partitions.length == 1)
+        assert(collectFirst(df.queryExecution.executedPlan) {
+          case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
+        }.isDefined)
+        assert(collectFirst(df.queryExecution.executedPlan) {
+          case _: SortExec => true
+        }.isDefined)
+      }
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1217,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
     }
   }
 }
+
+object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformUp {
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
+        ShuffleExchangeExec(SinglePartition, h)
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
+        SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
+    }
+  }
+}
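
A hedged follow-up sketch (not part of the commit above): wiring an extensions class, such as the hypothetical com.example.MyExtensions from the sketch in the commit message, when building a session. The injected post planner strategy rules only take effect under adaptive query execution.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.extensions", "com.example.MyExtensions")
      .config("spark.sql.adaptive.enabled", "true")  // post planner strategy rules only run within AQE
      .getOrCreate()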


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
