This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c05236237ed [SPARK-42003][SQL] Reduce duplicate code in
ResolveGroupByAll
c05236237ed is described below
commit c05236237ed7c0ad7dfbe2a185bd96acf51a2c4f
Author: Gengliang Wang <[email protected]>
AuthorDate: Thu Jan 12 17:03:45 2023 +0800
[SPARK-42003][SQL] Reduce duplicate code in ResolveGroupByAll
### What changes were proposed in this pull request?
Reduce duplicate code in ResolveGroupByAll by moving the group by
expression inference into a new method.
### Why are the changes needed?
Code cleanup: the same grouping-expression inference was duplicated in both `apply` and `checkAnalysis`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing unit tests.
Closes #39523 from gengliangwang/refactorAll.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/ResolveGroupByAll.scala | 36 +++++++++++++++-------
1 file changed, 25 insertions(+), 11 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
index d45ea412031..8c6ba20cd1a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
@@ -47,6 +47,24 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
}
}
+ /**
+ * Returns all the grouping expressions inferred from a GROUP BY ALL
aggregate.
+ * The result is optional. If Spark fails to infer the grouping columns, it
is None.
+ * Otherwise, it contains all the non-aggregate expressions from the project
list of the input
+ * Aggregate.
+ */
+ private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = {
+ val groupingExprs =
a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+ // If the grouping exprs are empty, this could either be (1) a valid
global aggregate, or
+ // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
+ // not automatically infer the grouping column to be "i".
+ if (groupingExprs.isEmpty &&
a.aggregateExpressions.exists(containsAttribute)) {
+ None
+ } else {
+ Some(groupingExprs)
+ }
+ }
+
override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsUpWithPruning(
_.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) {
case a: Aggregate
@@ -54,18 +72,15 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
// Only makes sense to do the rewrite once all the aggregate expressions
have been resolved.
// Otherwise, we might incorrectly pull an actual aggregate expression
over to the grouping
// expression list (because we don't know they would be aggregate
expressions until resolved).
- val groupingExprs =
a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+ val groupingExprs = getGroupingExpressions(a)
- // If the grouping exprs are empty, this could either be (1) a valid
global aggregate, or
- // (2) we simply fail to infer the grouping columns. As an example, in
"i + sum(j)", we will
- // not automatically infer the grouping column to be "i".
- if (groupingExprs.isEmpty &&
a.aggregateExpressions.exists(containsAttribute)) {
- // Case (2): don't replace the ALL. We will eventually tell the user
in checkAnalysis
- // that we cannot resolve the all in group by.
+ if (groupingExprs.isEmpty) {
+ // Don't replace the ALL when we fail to infer the grouping columns.
We will eventually
+ // tell the user in checkAnalysis that we cannot resolve the all in
group by.
a
} else {
- // Case (1): this is a valid global aggregate.
- a.copy(groupingExpressions = groupingExprs)
+ // This is a valid GROUP BY ALL aggregate.
+ a.copy(groupingExpressions = groupingExprs.get)
}
}
@@ -94,8 +109,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
*/
def checkAnalysis(operator: LogicalPlan): Unit = operator match {
case a: Aggregate if a.aggregateExpressions.forall(_.resolved) &&
matchToken(a) =>
- val noAgg =
a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
- if (noAgg.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+ if (getGroupingExpressions(a).isEmpty) {
operator.failAnalysis(
errorClass = "UNRESOLVED_ALL_IN_GROUP_BY",
messageParameters = Map.empty)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]