This is an automated email from the ASF dual-hosted git repository.
dtenedor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 04624f62db0e [SPARK-54707] [SQL] Refactor `PIVOT` resolution main
logic to the `PivotTransformer`
04624f62db0e is described below
commit 04624f62db0e1aa850a2b4e722e2bc402454a4cb
Author: mihailoale-db <[email protected]>
AuthorDate: Tue Dec 16 14:42:23 2025 -0800
[SPARK-54707] [SQL] Refactor `PIVOT` resolution main logic to the
`PivotTransformer`
### What changes were proposed in this pull request?
In this PR I propose moving the main `PIVOT` resolution logic into
`PivotTransformer`, so that it can later be reused to implement `PIVOT` in the
single-pass resolver.
### Why are the changes needed?
To ease the development of the single-pass resolver.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53474 from mihailoale-db/pivotrefactor.
Authored-by: mihailoale-db <[email protected]>
Signed-off-by: Daniel Tenedorio <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 104 +---------
.../sql/catalyst/analysis/PivotTransformer.scala | 223 +++++++++++++++++++++
2 files changed, 233 insertions(+), 94 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 311c1c946fbe..ad9a48979b11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -897,100 +897,16 @@ class Analyzer(
}
// Check all aggregate expressions.
aggregates.foreach(checkValidAggregateExpression)
- // Check all pivot values are literal and match pivot column data type.
- val evalPivotValues = pivotValues.map { value =>
- val foldable = trimAliases(value).foldable
- if (!foldable) {
- throw QueryCompilationErrors.nonLiteralPivotValError(value)
- }
- if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
- throw QueryCompilationErrors.pivotValDataTypeMismatchError(value,
pivotColumn)
- }
- Cast(value, pivotColumn.dataType,
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
- }
- // Group-by expressions coming from SQL are implicit and need to be
deduced.
- val groupByExprs = groupByExprsOpt.getOrElse {
- val pivotColAndAggRefs = pivotColumn.references ++
AttributeSet(aggregates)
- child.output.filterNot(pivotColAndAggRefs.contains)
- }
- val singleAgg = aggregates.size == 1
- def outputName(value: Expression, aggregate: Expression): String = {
- val stringValue = value match {
- case n: NamedExpression => n.name
- case _ =>
- val utf8Value =
- Cast(value, StringType,
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
- Option(utf8Value).map(_.toString).getOrElse("null")
- }
- if (singleAgg) {
- stringValue
- } else {
- val suffix = aggregate match {
- case n: NamedExpression => n.name
- case _ => toPrettySQL(aggregate)
- }
- stringValue + "_" + suffix
- }
- }
- if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
- // Since evaluating |pivotValues| if statements for each input row
can get slow this is an
- // alternate plan that instead uses two steps of aggregation.
- val namedAggExps: Seq[NamedExpression] = aggregates.map(a =>
Alias(a, a.sql)())
- val namedPivotCol = pivotColumn match {
- case n: NamedExpression => n
- case _ => Alias(pivotColumn, "__pivot_col")()
- }
- val bigGroup = groupByExprs :+ namedPivotCol
- val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
- val pivotAggs = namedAggExps.map { a =>
- Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute,
evalPivotValues)
- .toAggregateExpression()
- , "__pivot_" + a.sql)()
- }
- val groupByExprsAttr = groupByExprs.map(_.toAttribute)
- val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++
pivotAggs, firstAgg)
- val pivotAggAttribute = pivotAggs.map(_.toAttribute)
- val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value,
i) =>
- aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt)
=>
- Alias(ExtractValue(pivotAtt, Literal(i), resolver),
outputName(value, aggregate))()
- }
- }
- Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
- } else {
- val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap {
value =>
- def ifExpr(e: Expression) = {
- If(
- EqualNullSafe(
- pivotColumn,
- Cast(value, pivotColumn.dataType,
Some(conf.sessionLocalTimeZone))),
- e, Literal(null))
- }
- aggregates.map { aggregate =>
- val filteredAggregate = aggregate.transformDown {
- // Assumption is the aggregate function ignores nulls. This is
true for all current
- // AggregateFunction's with the exception of First and Last in
their default mode
- // (which we handle) and possibly some Hive UDAF's.
- case First(expr, _) =>
- First(ifExpr(expr), true)
- case Last(expr, _) =>
- Last(ifExpr(expr), true)
- case a: ApproximatePercentile =>
- // ApproximatePercentile takes two literals for accuracy and
percentage which
- // should not be wrapped by if-else.
- a.withNewChildren(ifExpr(a.first) :: a.second :: a.third ::
Nil)
- case a: AggregateFunction =>
- a.withNewChildren(a.children.map(ifExpr))
- }.transform {
- // We are duplicating aggregates that are now computing a
different value for each
- // pivot value.
- // TODO: Don't construct the physical container until after
analysis.
- case ae: AggregateExpression => ae.copy(resultId =
NamedExpression.newExprId)
- }
- Alias(filteredAggregate, outputName(value, aggregate))()
- }
- }
- Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
- }
+ PivotTransformer(
+ child = child,
+ pivotValues = pivotValues,
+ pivotColumn = pivotColumn,
+ groupByExpressionsOpt = groupByExprsOpt,
+ aggregates = aggregates,
+ childOutput = child.output,
+ newAlias = (child, name) =>
+ Alias(child, name.get)()
+ )
}
// Support any aggregate expression that can appear in an Aggregate plan
except Pandas UDF.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PivotTransformer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PivotTransformer.scala
new file mode 100644
index 000000000000..7f22dab71e3b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PivotTransformer.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ AliasHelper,
+ Attribute,
+ AttributeSet,
+ Cast,
+ EmptyRow,
+ EqualNullSafe,
+ Expression,
+ ExtractValue,
+ If,
+ Literal,
+ NamedExpression
+}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{
+ AggregateExpression,
+ AggregateFunction,
+ ApproximatePercentile,
+ First,
+ Last,
+ PivotFirst
+}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan,
Project}
+import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StringType
+
+/**
+ * Object used to transform a [[Pivot]] node into plain logical operators: either a single
+ * [[Aggregate]], or two stacked [[Aggregate]]s topped by a [[Project]]. Which shape is produced
+ * depends on the data types of the pivot aggregates (whether [[PivotFirst]] supports all of
+ * them), not on the structure of the plan below the [[Pivot]].
+ *
+ * Alias creation is delegated to the caller through the `newAlias` parameter of `apply`, so the
+ * rewrite itself does not depend on a particular resolver implementation.
+ */
+object PivotTransformer extends AliasHelper with SQLConfHelper {
+
+  /**
+   * Transform a pivot operation into an [[Aggregate]] or a combination of [[Aggregate]]s and
+   * [[Project]] operators.
+   *
+   * 1. Check that every pivot value is foldable (ignoring aliases) and castable to the pivot
+   *    column data type, and evaluate each value, cast to that type, to a constant.
+   *
+   * 2. Deduce group-by expressions. Group-by expressions coming from SQL are implicit and need to
+   *    be deduced by filtering out pivot column and aggregate references from child output.
+   *    In case of:
+   *    {{{
+   *      SELECT year, region, q1, q2, q3, q4
+   *      FROM sales
+   *      PIVOT (sum(sales) AS sales
+   *        FOR quarter
+   *        IN (1 AS q1, 2 AS q2, 3 AS q3, 4 AS q4));
+   *    }}}
+   *    where table `sales` has `year`, `quarter`, `region`, `sales` as columns.
+   *    In this example the pivot column is `quarter` and the aggregate `sum(sales)` references
+   *    `sales`; the remaining columns, `year` and `region`, become the grouping expressions.
+   *
+   * 3. Choose between two execution strategies based on aggregate data types:
+   *
+   *    a) If [[PivotFirst]] supports the data type of every aggregate (fast path):
+   *       Since evaluating `pivotValues` `IF` statements for each input row can get slow, use an
+   *       alternate plan that instead uses two steps of aggregation:
+   *       - First aggregation: group by original grouping expressions + pivot column, compute
+   *         aggregates. A pivot column that is not a [[NamedExpression]] is aliased as
+   *         `__pivot_col`.
+   *       - Second aggregation: group by original grouping expressions only, use [[PivotFirst]]
+   *         to collect the aggregated values for every pivot value.
+   *       - Final projection: extract individual pivot outputs by pivot-value position using
+   *         [[ExtractValue]].
+   *
+   *    b) Otherwise (standard path):
+   *       Create a single [[Aggregate]] with filtered aggregates for each pivot value. For each
+   *       aggregate and pivot value combination:
+   *       - Wrap aggregate function children with `If(pivotColumn <=> pivotValue, expr, null)`.
+   *         This relies on aggregate functions ignoring nulls.
+   *       - [[First]] and [[Last]] are rebuilt with their second argument set to `true` so that
+   *         the injected nulls are ignored (their default mode does not ignore nulls).
+   *       - For [[ApproximatePercentile]] only the first child is wrapped: its accuracy and
+   *         percentage arguments must remain literals.
+   *
+   * Output columns are named after the pivot value (its name if it is a [[NamedExpression]],
+   * otherwise its value cast to string, or `null`), suffixed with `_<aggregate name>` when
+   * there is more than one aggregate.
+   *
+   * @param child the plan below the pivot.
+   * @param pivotValues the values listed in the `IN (...)` clause; each must be foldable.
+   * @param pivotColumn the expression being pivoted on (the `FOR` expression).
+   * @param groupByExpressionsOpt explicit grouping expressions, or `None` to deduce them from
+   *                              `childOutput` as described in step 2.
+   * @param aggregates the pivot aggregate expressions.
+   * @param childOutput output attributes of `child`, used only to deduce implicit grouping
+   *                    expressions.
+   * @param newAlias factory used to create every [[Alias]] in the resulting plan; it is always
+   *                 invoked with a defined (`Some`) name.
+   * @return the rewritten plan.
+   * @throws the error from `QueryCompilationErrors.nonLiteralPivotValError` if a pivot value is
+   *         not foldable, or from `QueryCompilationErrors.pivotValDataTypeMismatchError` if it
+   *         cannot be cast to the pivot column data type.
+   */
+  def apply(
+      child: LogicalPlan,
+      pivotValues: Seq[Expression],
+      pivotColumn: Expression,
+      groupByExpressionsOpt: Option[Seq[NamedExpression]],
+      aggregates: Seq[Expression],
+      childOutput: Seq[Attribute],
+      newAlias: (Expression, Option[String]) => Alias): LogicalPlan = {
+    // Validation happens up front for both strategies, even though only the fast path consumes
+    // the evaluated constants.
+    val evalPivotValues = pivotValues.map { value =>
+      val foldable = trimAliases(value).foldable
+      if (!foldable) {
+        throw QueryCompilationErrors.nonLiteralPivotValError(value)
+      }
+      if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
+        throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, pivotColumn)
+      }
+      Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
+    }
+    // Group-by expressions coming from SQL are implicit and need to be deduced.
+    val groupByExpressions = groupByExpressionsOpt.getOrElse {
+      val pivotColumnAndAggregatesRefs = pivotColumn.references ++ AttributeSet(aggregates)
+      childOutput.filterNot(pivotColumnAndAggregatesRefs.contains)
+    }
+    // Loop-invariant: decides whether output names get an aggregate suffix.
+    val isSingleAggregate = aggregates.size == 1
+
+    if (aggregates.forall(aggregate => PivotFirst.supportsDataType(aggregate.dataType))) {
+      // Fast path: two aggregations instead of one `If` per pivot value per input row.
+      val namedAggExps: Seq[NamedExpression] = aggregates.map { aggregate =>
+        newAlias(aggregate, Some(aggregate.sql))
+      }
+      val namedPivotCol = pivotColumn match {
+        case namedExpression: NamedExpression => namedExpression
+        case _ => newAlias(pivotColumn, Some("__pivot_col"))
+      }
+      val extendedGroupingExpressions = groupByExpressions :+ namedPivotCol
+      val firstAgg =
+        Aggregate(extendedGroupingExpressions, extendedGroupingExpressions ++ namedAggExps, child)
+      val pivotAggregates = namedAggExps.map { namedAggExp =>
+        newAlias(
+          PivotFirst(namedPivotCol.toAttribute, namedAggExp.toAttribute, evalPivotValues)
+            .toAggregateExpression(),
+          Some("__pivot_" + namedAggExp.sql)
+        )
+      }
+      val groupByExpressionsAttributes = groupByExpressions.map(_.toAttribute)
+      val secondAgg = Aggregate(
+        groupByExpressionsAttributes,
+        groupByExpressionsAttributes ++ pivotAggregates,
+        firstAgg
+      )
+      val pivotAggregatesAttributes = pivotAggregates.map(_.toAttribute)
+      // Each PivotFirst result is indexed by the position of the pivot value in `pivotValues`;
+      // project every (pivot value, aggregate) slot out into its own output column.
+      val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
+        aggregates.zip(pivotAggregatesAttributes).map { case (aggregate, pivotAtt) =>
+          newAlias(
+            ExtractValue(pivotAtt, Literal(i), conf.resolver),
+            Some(outputName(value, aggregate, isSingleAggregate))
+          )
+        }
+      }
+      Project(groupByExpressionsAttributes ++ pivotOutputs, secondAgg)
+    } else {
+      // Standard path: one filtered copy of every aggregate per pivot value.
+      val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+        aggregates.map { aggregate =>
+          val filteredAggregate = aggregate
+            .transformDown {
+              // Assumption is the aggregate function ignores nulls. This is true for all current
+              // AggregateFunctions with the exception of First and Last in their default mode
+              // (handled below) and possibly some Hive UDAFs.
+              case First(expression, _) =>
+                First(createIfExpression(expression, pivotColumn, value), true)
+              case Last(expression, _) =>
+                Last(createIfExpression(expression, pivotColumn, value), true)
+              case approximatePercentile: ApproximatePercentile =>
+                // Accuracy and percentage must stay literals, so only the input is wrapped.
+                approximatePercentile.withNewChildren(
+                  createIfExpression(approximatePercentile.first, pivotColumn, value) ::
+                    approximatePercentile.second ::
+                    approximatePercentile.third ::
+                    Nil
+                )
+              case aggregateFunction: AggregateFunction =>
+                aggregateFunction.withNewChildren(aggregateFunction.children.map {
+                  aggregateChild => createIfExpression(aggregateChild, pivotColumn, value)
+                })
+            }
+            .transform {
+              // The same aggregate is duplicated once per pivot value and each copy now
+              // computes a different value, so every copy needs its own result id.
+              // TODO: Don't construct the physical container until after analysis.
+              case aggregateExpression: AggregateExpression =>
+                aggregateExpression.copy(resultId = NamedExpression.newExprId)
+            }
+          newAlias(filteredAggregate, Some(outputName(value, aggregate, isSingleAggregate)))
+        }
+      }
+      Aggregate(groupByExpressions, groupByExpressions ++ pivotAggregates, child)
+    }
+  }
+
+  /**
+   * Builds the output column name for one (pivot value, aggregate) pair.
+   *
+   * The pivot value contributes its name when it is a [[NamedExpression]]; otherwise it is cast
+   * to string and evaluated, with a null result rendered as `null`. When there is more than one
+   * aggregate, `_` plus the aggregate's name (or its pretty SQL) is appended so that the columns
+   * stay distinguishable.
+   */
+  private def outputName(
+      value: Expression,
+      aggregate: Expression,
+      isSingleAggregate: Boolean): String = {
+    val stringValue = value match {
+      case namedExpression: NamedExpression => namedExpression.name
+      case _ =>
+        val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
+        Option(utf8Value).map(_.toString).getOrElse("null")
+    }
+    if (isSingleAggregate) {
+      stringValue
+    } else {
+      val suffix = aggregate match {
+        case namedExpression: NamedExpression => namedExpression.name
+        case _ => toPrettySQL(aggregate)
+      }
+      stringValue + "_" + suffix
+    }
+  }
+
+  /**
+   * Returns `If(pivotColumn <=> cast(value), expression, null)`, i.e. `expression` only for rows
+   * whose pivot column null-safely equals `value` cast to the pivot column type, and null
+   * otherwise.
+   */
+  private def createIfExpression(
+      expression: Expression,
+      pivotColumn: Expression,
+      value: Expression): Expression = {
+    If(
+      EqualNullSafe(
+        pivotColumn,
+        Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))
+      ),
+      expression,
+      Literal(null)
+    )
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]