Repository: spark
Updated Branches:
  refs/heads/master 4d535d1f1 -> 560489f4e


[SPARK-13732][SPARK-13797][SQL] Remove projectList from Window and Eliminate 
useless Window

#### What changes were proposed in this pull request?

`projectList` is redundant: its value is always the same as `child.output`. 
Remove it from the class `Window`. The removal simplifies the code in the Analyzer 
and Optimizer.

This PR is based on the discussion started by cloud-fan in a separate PR:
https://github.com/apache/spark/pull/5604#discussion_r55140466

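This PR also eliminates no-op `Window` operators, i.e. a `Window` whose 
`windowExpressions` list is empty (for example, after column pruning has removed 
all unused window expressions) is replaced by its child.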

cloud-fan yhuai

#### How was this patch tested?

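Existing test cases cover it. In addition, three new test cases for column pruning 
on `Window` were added to `ColumnPruningSuite`.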

Author: gatorsmile <[email protected]>
Author: xiaoli <[email protected]>
Author: Xiao Li <[email protected]>

Closes #11565 from gatorsmile/removeProjListWindow.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/560489f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/560489f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/560489f4

Branch: refs/heads/master
Commit: 560489f4e16ff18b5e66e7de1bb84d890369a462
Parents: 4d535d1
Author: gatorsmile <[email protected]>
Authored: Fri Mar 11 11:59:18 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Mar 11 11:59:18 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 11 +---
 .../apache/spark/sql/catalyst/dsl/package.scala |  6 ++
 .../sql/catalyst/optimizer/Optimizer.scala      | 20 +++---
 .../catalyst/plans/logical/basicOperators.scala |  5 +-
 .../catalyst/optimizer/ColumnPruningSuite.scala | 68 +++++++++++++++++++-
 .../spark/sql/execution/SparkStrategies.scala   |  5 +-
 .../org/apache/spark/sql/execution/Window.scala |  6 +-
 7 files changed, 94 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9ab0a20..b654827 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -421,7 +421,7 @@ class Analyzer(
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
           (oldVersion, oldVersion.copy(generatorOutput = newOutput))
 
-        case oldVersion @ Window(_, windowExpressions, _, _, child)
+        case oldVersion @ Window(windowExpressions, _, _, child)
             if 
AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
               .nonEmpty =>
           (oldVersion, oldVersion.copy(windowExpressions = 
newAliases(windowExpressions)))
@@ -658,10 +658,6 @@ class Analyzer(
         case p: Project =>
           val missing = missingAttrs -- p.child.outputSet
           Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, 
missing))
-        case w: Window =>
-          val missing = missingAttrs -- w.child.outputSet
-          w.copy(projectList = w.projectList ++ missingAttrs,
-            child = addMissingAttr(w.child, missing))
         case a: Aggregate =>
           // all the missing attributes should be grouping expressions
           // TODO: push down AggregateExpression
@@ -1166,7 +1162,6 @@ class Analyzer(
         // Set currentChild to the newly created Window operator.
         currentChild =
           Window(
-            currentChild.output,
             windowExpressions,
             partitionSpec,
             orderSpec,
@@ -1436,10 +1431,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
       val cleanedAggs = 
aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
       Aggregate(grouping.map(trimAliases), cleanedAggs, child)
 
-    case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) 
=>
+    case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
       val cleanedWindowExprs =
         windowExprs.map(e => 
trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
-      Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+      Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
 
     // Operators that operate on objects should only have expressions from 
encoders, which should

http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 6346326..dc5264e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -268,6 +268,12 @@ package object dsl {
         Aggregate(groupingExprs, aliasedExprs, logicalPlan)
       }
 
+      def window(
+          windowExpressions: Seq[NamedExpression],
+          partitionSpec: Seq[Expression],
+          orderSpec: Seq[SortOrder]): LogicalPlan =
+        Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)
+
       def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, 
logicalPlan)
 
       def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, 
otherPlan)

http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 650b4ee..8577667 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -315,21 +315,17 @@ object SetOperationPushDown extends Rule[LogicalPlan] 
with PredicateHelper {
  *   - LeftSemiJoin
  */
 object ColumnPruning extends Rule[LogicalPlan] {
-  def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+  private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): 
Boolean =
     output1.size == output2.size &&
       output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Prunes the unused columns from project list of 
Project/Aggregate/Window/Expand
+    // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if (p2.outputSet -- 
p.references).nonEmpty =>
       p.copy(child = p2.copy(projectList = 
p2.projectList.filter(p.references.contains)))
     case p @ Project(_, a: Aggregate) if (a.outputSet -- 
p.references).nonEmpty =>
       p.copy(
         child = a.copy(aggregateExpressions = 
a.aggregateExpressions.filter(p.references.contains)))
-    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
-      p.copy(child = w.copy(
-        projectList = w.projectList.filter(p.references.contains),
-        windowExpressions = w.windowExpressions.filter(p.references.contains)))
     case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- 
a.references).nonEmpty =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
@@ -343,11 +339,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- 
mp.references).nonEmpty =>
       mp.copy(child = prunedChild(child, mp.references))
 
-    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
+    // Prunes the unused columns from child of Aggregate/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- 
a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
-    case w @ Window(_, _, _, _, child) if (child.outputSet -- 
w.references).nonEmpty =>
-      w.copy(child = prunedChild(child, w.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty 
=>
       e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- 
g.references).nonEmpty =>
@@ -381,6 +375,14 @@ object ColumnPruning extends Rule[LogicalPlan] {
         p
       }
 
+    // Prune unnecessary window expressions
+    case p @ Project(_, w: Window) if (w.windowOutputSet -- 
p.references).nonEmpty =>
+      p.copy(child = w.copy(
+        windowExpressions = w.windowExpressions.filter(p.references.contains)))
+
+    // Eliminate no-op Window
+    case w: Window if w.windowExpressions.isEmpty => w.child
+
     // Eliminate no-op Projects
     case p @ Project(projectList, child) if sameOutput(child.output, p.output) 
=> child
 

http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3bc246a..09ea3fe 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -434,14 +434,15 @@ case class Aggregate(
 }
 
 case class Window(
-    projectList: Seq[Attribute],
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: LogicalPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] =
-    projectList ++ windowExpressions.map(_.toAttribute)
+    child.output ++ windowExpressions.map(_.toAttribute)
+
+  def windowOutputSet: AttributeSet = 
AttributeSet(windowExpressions.map(_.toAttribute))
 }
 
 private[sql] object Expand {

http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 409e922..dd7d65d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, 
SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Complete, Count}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -33,7 +34,8 @@ class ColumnPruningSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
-      ColumnPruning) :: Nil
+      ColumnPruning,
+      CollapseProject) :: Nil
   }
 
   test("Column pruning for Generate when Generate.join = false") {
@@ -258,6 +260,68 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
   }
 
+  test("Column pruning on Window with useless aggregate functions") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.groupBy('a, 'c, 'd)('a, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition( 'a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).select('a, 'c)
+
+    val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 
'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning on Window with selected agg expressions") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.select('a, 'b, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition( 'a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c)
+
+    val correctAnswer =
+      input.select('a, 'b, 'c)
+        .window(WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition( 'a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window) :: Nil,
+          'a :: Nil, 'b.asc :: Nil)
+        .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning on Window in select") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.select('a, 'b, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition( 'a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).select('a, 'c)
+
+    val correctAnswer = input.select('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("Column pruning on Union") {
     val input1 = LocalRelation('a.int, 'b.string, 'c.double)
     val input2 = LocalRelation('c.int, 'd.string, 'e.double)

http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index debd04a..bae0750 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -344,9 +344,8 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, 
child) =>
-        execution.Window(
-          projectList, windowExprs, partitionSpec, orderSpec, 
planLater(child)) :: Nil
+      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
+        execution.Window(windowExprs, partitionSpec, orderSpec, 
planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
         execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: 
Nil
       case logical.LocalRelation(output, data) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/560489f4/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 84154a4..a4c0e1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -81,14 +81,14 @@ import 
org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf
  * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]].
  */
 case class Window(
-    projectList: Seq[Attribute],
     windowExpression: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: SparkPlan)
   extends UnaryNode {
 
-  override def output: Seq[Attribute] = projectList ++ 
windowExpression.map(_.toAttribute)
+  override def output: Seq[Attribute] =
+    child.output ++ windowExpression.map(_.toAttribute)
 
   override def requiredChildDistribution: Seq[Distribution] = {
     if (partitionSpec.isEmpty) {
@@ -275,7 +275,7 @@ case class Window(
     val unboundToRefMap = expressions.zip(references).toMap
     val patchedWindowExpression = 
windowExpression.map(_.transform(unboundToRefMap))
     UnsafeProjection.create(
-      projectList ++ patchedWindowExpression,
+      child.output ++ patchedWindowExpression,
       child.output)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to