This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c05236237ed [SPARK-42003][SQL] Reduce duplicate code in ResolveGroupByAll
c05236237ed is described below

commit c05236237ed7c0ad7dfbe2a185bd96acf51a2c4f
Author: Gengliang Wang <[email protected]>
AuthorDate: Thu Jan 12 17:03:45 2023 +0800

    [SPARK-42003][SQL] Reduce duplicate code in ResolveGroupByAll
    
    ### What changes were proposed in this pull request?
    
    Reduce duplicate code in ResolveGroupByAll by moving the group by expression inference into a new method.
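
    For context, here is a minimal sketch of the GROUP BY ALL behavior this rule resolves. The table `t`, columns `i` and `j`, and the `spark` session are hypothetical, assuming something like a spark-shell where `spark` is the active SparkSession:

    ```scala
    // The non-aggregate expression `i` in the select list becomes the inferred
    // grouping expression, so ALL is rewritten to GROUP BY i.
    spark.sql("SELECT i, sum(j) FROM t GROUP BY ALL")

    // Every select expression contains an aggregate, so no grouping column can be
    // inferred; checkAnalysis reports the UNRESOLVED_ALL_IN_GROUP_BY error class.
    spark.sql("SELECT i + sum(j) FROM t GROUP BY ALL")
    ```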
    
    ### Why are the changes needed?
    
    Code clean up
    
    ### Does this PR introduce _any_ user-facing change?
    
    No

    ### How was this patch tested?
    
    Existing UT
    
    Closes #39523 from gengliangwang/refactorAll.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/ResolveGroupByAll.scala  | 36 +++++++++++++++-------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
index d45ea412031..8c6ba20cd1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
@@ -47,6 +47,24 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate.
+   * The result is optional. If Spark fails to infer the grouping columns, it is None.
+   * Otherwise, it contains all the non-aggregate expressions from the project list of the input
+   * Aggregate.
+   */
+  private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = {
+    val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+    // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
+    // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
+    // not automatically infer the grouping column to be "i".
+    if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      None
+    } else {
+      Some(groupingExprs)
+    }
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) {
     case a: Aggregate
@@ -54,18 +72,15 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
       // Only makes sense to do the rewrite once all the aggregate expressions have been resolved.
       // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping
       // expression list (because we don't know they would be aggregate expressions until resolved).
-      val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+      val groupingExprs = getGroupingExpressions(a)
 
-      // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
-      // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
-      // not automatically infer the grouping column to be "i".
-      if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
-        // Case (2): don't replace the ALL. We will eventually tell the user in checkAnalysis
-        // that we cannot resolve the all in group by.
+      if (groupingExprs.isEmpty) {
+        // Don't replace the ALL when we fail to infer the grouping columns. We will eventually
+        // tell the user in checkAnalysis that we cannot resolve the all in group by.
         a
       } else {
-        // Case (1): this is a valid global aggregate.
-        a.copy(groupingExpressions = groupingExprs)
+        // This is a valid GROUP BY ALL aggregate.
+        a.copy(groupingExpressions = groupingExprs.get)
       }
   }
 
@@ -94,8 +109,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
    */
   def checkAnalysis(operator: LogicalPlan): Unit = operator match {
     case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) =>
-      val noAgg = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
-      if (noAgg.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      if (getGroupingExpressions(a).isEmpty) {
         operator.failAnalysis(
           errorClass = "UNRESOLVED_ALL_IN_GROUP_BY",
           messageParameters = Map.empty)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
