This is an automated email from the ASF dual-hosted git repository.

dtenedor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 04624f62db0e [SPARK-54707] [SQL] Refactor `PIVOT` resolution main logic to the `PivotTransformer`
04624f62db0e is described below

commit 04624f62db0e1aa850a2b4e722e2bc402454a4cb
Author: mihailoale-db <[email protected]>
AuthorDate: Tue Dec 16 14:42:23 2025 -0800

    [SPARK-54707] [SQL] Refactor `PIVOT` resolution main logic to the `PivotTransformer`
    
    ### What changes were proposed in this pull request?
    In this issue I propose to refactor `PIVOT` resolution main logic to the `PivotTransformer` in order to reuse it later to implement `PIVOT` in the single-pass resolver.
    
    ### Why are the changes needed?
    To ease the development of the single-pass resolver.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53474 from mihailoale-db/pivotrefactor.
    
    Authored-by: mihailoale-db <[email protected]>
    Signed-off-by: Daniel Tenedorio <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 104 +---------
 .../sql/catalyst/analysis/PivotTransformer.scala   | 223 +++++++++++++++++++++
 2 files changed, 233 insertions(+), 94 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 311c1c946fbe..ad9a48979b11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -897,100 +897,16 @@ class Analyzer(
         }
         // Check all aggregate expressions.
         aggregates.foreach(checkValidAggregateExpression)
-        // Check all pivot values are literal and match pivot column data type.
-        val evalPivotValues = pivotValues.map { value =>
-          val foldable = trimAliases(value).foldable
-          if (!foldable) {
-            throw QueryCompilationErrors.nonLiteralPivotValError(value)
-          }
-          if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
-            throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, 
pivotColumn)
-          }
-          Cast(value, pivotColumn.dataType, 
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
-        }
-        // Group-by expressions coming from SQL are implicit and need to be 
deduced.
-        val groupByExprs = groupByExprsOpt.getOrElse {
-          val pivotColAndAggRefs = pivotColumn.references ++ 
AttributeSet(aggregates)
-          child.output.filterNot(pivotColAndAggRefs.contains)
-        }
-        val singleAgg = aggregates.size == 1
-        def outputName(value: Expression, aggregate: Expression): String = {
-          val stringValue = value match {
-            case n: NamedExpression => n.name
-            case _ =>
-              val utf8Value =
-                Cast(value, StringType, 
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
-              Option(utf8Value).map(_.toString).getOrElse("null")
-          }
-          if (singleAgg) {
-            stringValue
-          } else {
-            val suffix = aggregate match {
-              case n: NamedExpression => n.name
-              case _ => toPrettySQL(aggregate)
-            }
-            stringValue + "_" + suffix
-          }
-        }
-        if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
-          // Since evaluating |pivotValues| if statements for each input row 
can get slow this is an
-          // alternate plan that instead uses two steps of aggregation.
-          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => 
Alias(a, a.sql)())
-          val namedPivotCol = pivotColumn match {
-            case n: NamedExpression => n
-            case _ => Alias(pivotColumn, "__pivot_col")()
-          }
-          val bigGroup = groupByExprs :+ namedPivotCol
-          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
-          val pivotAggs = namedAggExps.map { a =>
-            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, 
evalPivotValues)
-              .toAggregateExpression()
-            , "__pivot_" + a.sql)()
-          }
-          val groupByExprsAttr = groupByExprs.map(_.toAttribute)
-          val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ 
pivotAggs, firstAgg)
-          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
-          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, 
i) =>
-            aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) 
=>
-              Alias(ExtractValue(pivotAtt, Literal(i), resolver), 
outputName(value, aggregate))()
-            }
-          }
-          Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
-        } else {
-          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { 
value =>
-            def ifExpr(e: Expression) = {
-              If(
-                EqualNullSafe(
-                  pivotColumn,
-                  Cast(value, pivotColumn.dataType, 
Some(conf.sessionLocalTimeZone))),
-                e, Literal(null))
-            }
-            aggregates.map { aggregate =>
-              val filteredAggregate = aggregate.transformDown {
-                // Assumption is the aggregate function ignores nulls. This is 
true for all current
-                // AggregateFunction's with the exception of First and Last in 
their default mode
-                // (which we handle) and possibly some Hive UDAF's.
-                case First(expr, _) =>
-                  First(ifExpr(expr), true)
-                case Last(expr, _) =>
-                  Last(ifExpr(expr), true)
-                case a: ApproximatePercentile =>
-                  // ApproximatePercentile takes two literals for accuracy and 
percentage which
-                  // should not be wrapped by if-else.
-                  a.withNewChildren(ifExpr(a.first) :: a.second :: a.third :: 
Nil)
-                case a: AggregateFunction =>
-                  a.withNewChildren(a.children.map(ifExpr))
-              }.transform {
-                // We are duplicating aggregates that are now computing a 
different value for each
-                // pivot value.
-                // TODO: Don't construct the physical container until after 
analysis.
-                case ae: AggregateExpression => ae.copy(resultId = 
NamedExpression.newExprId)
-              }
-              Alias(filteredAggregate, outputName(value, aggregate))()
-            }
-          }
-          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
-        }
+        PivotTransformer(
+          child = child,
+          pivotValues = pivotValues,
+          pivotColumn = pivotColumn,
+          groupByExpressionsOpt = groupByExprsOpt,
+          aggregates = aggregates,
+          childOutput = child.output,
+          newAlias = (child, name) =>
+            Alias(child, name.get)()
+        )
     }
 
     // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PivotTransformer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PivotTransformer.scala
new file mode 100644
index 000000000000..7f22dab71e3b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PivotTransformer.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{
+  Alias,
+  AliasHelper,
+  Attribute,
+  AttributeSet,
+  Cast,
+  EmptyRow,
+  EqualNullSafe,
+  Expression,
+  ExtractValue,
+  If,
+  Literal,
+  NamedExpression
+}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{
+  AggregateExpression,
+  AggregateFunction,
+  ApproximatePercentile,
+  First,
+  Last,
+  PivotFirst
+}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StringType
+
+/**
+ * Object used to transform [[Pivot]] node into a [[Project]] or [[Aggregate]] (based on the tree
+ * structure below the [[Pivot]]).
+ */
+object PivotTransformer extends AliasHelper with SQLConfHelper {
+
+  /**
+   * Transform a pivot operation into an [[Aggregate]] or a combination of [[Aggregate]]s and
+   * [[Project]] operators.
+   *
+   *  1. Check all pivot values are literal and match pivot column data type.
+   *
+   *  2. Deduce group-by expressions. Group-by expressions coming from SQL are implicit and need to
+   *     be deduced by filtering out pivot column and aggregate references from child output.
+   *     In case of:
+   *     {{{
+   *       SELECT year, region, q1, q2, q3, q4
+   *       FROM sales
+   *       PIVOT (sum(sales) AS sales
+   *         FOR quarter
+   *         IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4));
+   *     }}}
+   *     where table `sales` has `year`, `quarter`, `region`, `sales` as columns.
+   *     In this example: pivot column would be `quarter`, aggregate would be `sales` and because
+   *     of that, `year` and `region` would be grouping expressions.
+   *
+   *  3. Choose between two execution strategies based on aggregate data types:
+   *
+   *     a) If all aggregates support [[PivotFirst]] data types (fast path):
+   *        Since evaluating `pivotValues` `IF` statements for each input row can get slow, use an
+   *        alternate plan that instead uses two steps of aggregation:
+   *        - First aggregation: group by original grouping expressions + pivot column, compute
+   *          aggregates
+   *        - Second aggregation: group by original grouping expressions only, use [[PivotFirst]]
+   *          to extract values for each pivot value
+   *        - Final projection: extract individual pivot outputs using [[ExtractValue]]
+   *
+   *     b) Otherwise (standard path):
+   *        Create a single [[Aggregate]] with filtered aggregates for each pivot value. For each
+   *        aggregate and pivot value combination:
+   *        - Wrap aggregate children with `If(pivotColumn == pivotValue, expr, null)` expressions.
+   *        - Handle special cases for [[First]], [[Last]], and [[ApproximatePercentile]] which
+   *          have specific semantics around null handling.
+   */
+  def apply(
+      child: LogicalPlan,
+      pivotValues: Seq[Expression],
+      pivotColumn: Expression,
+      groupByExpressionsOpt: Option[Seq[NamedExpression]],
+      aggregates: Seq[Expression],
+      childOutput: Seq[Attribute],
+      newAlias: (Expression, Option[String]) => Alias): LogicalPlan = {
+    val evalPivotValues = pivotValues.map { value =>
+      val foldable = trimAliases(value).foldable
+      if (!foldable) {
+        throw QueryCompilationErrors.nonLiteralPivotValError(value)
+      }
+      if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
+        throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, 
pivotColumn)
+      }
+      Cast(value, pivotColumn.dataType, 
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
+    }
+    val groupByExpressions = groupByExpressionsOpt.getOrElse {
+      val pivotColumnAndAggregatesRefs = pivotColumn.references ++ 
AttributeSet(aggregates)
+      childOutput.filterNot(pivotColumnAndAggregatesRefs.contains)
+    }
+    if (aggregates.forall(aggregate => 
PivotFirst.supportsDataType(aggregate.dataType))) {
+      val namedAggExps: Seq[NamedExpression] = aggregates.map { aggregate =>
+        newAlias(aggregate, Some(aggregate.sql))
+      }
+      val namedPivotCol = pivotColumn match {
+        case namedExpression: NamedExpression => namedExpression
+        case _ =>
+          newAlias(pivotColumn, Some("__pivot_col"))
+      }
+      val extendedGroupingExpressions = groupByExpressions :+ namedPivotCol
+      val firstAgg =
+        Aggregate(extendedGroupingExpressions, extendedGroupingExpressions ++ 
namedAggExps, child)
+      val pivotAggregates = namedAggExps.map { a =>
+        newAlias(
+          PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues)
+            .toAggregateExpression(),
+          Some("__pivot_" + a.sql)
+        )
+      }
+      val groupByExpressionsAttributes = groupByExpressions.map(_.toAttribute)
+      val secondAgg =
+        Aggregate(
+          groupByExpressionsAttributes,
+          groupByExpressionsAttributes ++ pivotAggregates,
+          firstAgg
+        )
+      val pivotAggregatesAttributes = pivotAggregates.map(_.toAttribute)
+      val pivotOutputs = pivotValues.zipWithIndex.flatMap {
+        case (value, i) =>
+          aggregates.zip(pivotAggregatesAttributes).map {
+            case (aggregate, pivotAtt) =>
+              newAlias(
+                ExtractValue(pivotAtt, Literal(i), conf.resolver),
+                Some(outputName(value, aggregate, isSingleAggregate = 
aggregates.size == 1))
+              )
+          }
+      }
+      Project(groupByExpressionsAttributes ++ pivotOutputs, secondAgg)
+    } else {
+      val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value 
=>
+        aggregates.map { aggregate =>
+          val filteredAggregate = aggregate
+            .transformDown {
+              case First(expression, _) =>
+                First(createIfExpression(expression, pivotColumn, value), true)
+              case Last(expression, _) =>
+                Last(createIfExpression(expression, pivotColumn, value), true)
+              case approximatePercentile: ApproximatePercentile =>
+                approximatePercentile.withNewChildren(
+                  createIfExpression(approximatePercentile.first, pivotColumn, 
value) ::
+                  approximatePercentile.second ::
+                  approximatePercentile.third ::
+                  Nil
+                )
+              case aggregateFunction: AggregateFunction =>
+                
aggregateFunction.withNewChildren(aggregateFunction.children.map { child =>
+                  createIfExpression(child, pivotColumn, value)
+                })
+            }
+            .transform {
+              // TODO: Don't construct the physical container until after 
analysis.
+              case aggregateExpression: AggregateExpression =>
+                aggregateExpression.copy(resultId = NamedExpression.newExprId)
+            }
+          newAlias(
+            filteredAggregate,
+            Some(outputName(value, aggregate, isSingleAggregate = 
aggregates.size == 1))
+          )
+        }
+      }
+      Aggregate(groupByExpressions, groupByExpressions ++ pivotAggregates, 
child)
+    }
+  }
+
+  private def outputName(
+      value: Expression,
+      aggregate: Expression,
+      isSingleAggregate: Boolean): String = {
+    val stringValue = value match {
+      case namedExpression: NamedExpression => namedExpression.name
+      case _ =>
+        val utf8Value =
+          Cast(value, StringType, 
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
+        Option(utf8Value).map(_.toString).getOrElse("null")
+    }
+    if (isSingleAggregate) {
+      stringValue
+    } else {
+      val suffix = aggregate match {
+        case namedExpression: NamedExpression => namedExpression.name
+        case _ => toPrettySQL(aggregate)
+      }
+      stringValue + "_" + suffix
+    }
+  }
+
+  private def createIfExpression(
+      expression: Expression,
+      pivotColumn: Expression,
+      value: Expression) = {
+    If(
+      EqualNullSafe(
+        pivotColumn,
+        Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))
+      ),
+      expression,
+      Literal(null)
+    )
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to