Repository: spark
Updated Branches:
  refs/heads/master e3554605b -> 20b8f2c32


[SPARK-15370][SQL] Revert PR "Update RewriteCorrelatedSubquery rule"

This reverts commit 9770f6ee60f6834e4e1200234109120427a5cc0d.

Author: Herman van Hovell <[email protected]>

Closes #13626 from hvanhovell/SPARK-15370-revert.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20b8f2c3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20b8f2c3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20b8f2c3

Branch: refs/heads/master
Commit: 20b8f2c32af696c3856221c4c4fcd12c3f068af2
Parents: e355460
Author: Herman van Hovell <[email protected]>
Authored: Sun Jun 12 15:06:37 2016 -0700
Committer: Herman van Hovell <[email protected]>
Committed: Sun Jun 12 15:06:37 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/predicates.scala   |   7 +-
 .../sql/catalyst/optimizer/Optimizer.scala      | 198 +------------------
 .../org/apache/spark/sql/SubquerySuite.scala    |  81 --------
 3 files changed, 6 insertions(+), 280 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/20b8f2c3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a3b098a..8a6cf53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -69,11 +69,8 @@ trait PredicateHelper {
   protected def replaceAlias(
       condition: Expression,
       aliases: AttributeMap[Expression]): Expression = {
-    // Use transformUp to prevent infinite recursion when the replacement expression
-    // redefines the same ExprId,
-    condition.transformUp {
-      case a: Attribute =>
-        aliases.getOrElse(a, a)
+    condition.transform {
+      case a: Attribute => aliases.getOrElse(a, a)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/20b8f2c3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d115274..a12c2ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -528,8 +528,7 @@ object CollapseProject extends Rule[LogicalPlan] {
     // Substitute any attributes that are produced by the lower projection, so that we safely
     // eliminate it.
     // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
-    // Use transformUp to prevent infinite recursion.
-    val rewrittenUpper = upper.map(_.transformUp {
+    val rewrittenUpper = upper.map(_.transform {
       case a: Attribute => aliases.getOrElse(a, a)
     })
     // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
@@ -1784,128 +1783,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
   }
 
   /**
-   * Statically evaluate an expression containing zero or more placeholders, 
given a set
-   * of bindings for placeholder values.
-   */
-  private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : 
Option[Any] = {
-    val rewrittenExpr = expr transform {
-      case r @ AttributeReference(_, dataType, _, _) =>
-        bindings(r.exprId) match {
-          case Some(v) => Literal.create(v, dataType)
-          case None => Literal.default(NullType)
-        }
-    }
-    Option(rewrittenExpr.eval())
-  }
-
-  /**
-   * Statically evaluate an expression containing one or more aggregates on an 
empty input.
-   */
-  private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
-    // AggregateExpressions are Unevaluable, so we need to replace all 
aggregates
-    // in the expression with the value they would return for zero input 
tuples.
-    // Also replace attribute refs (for example, for grouping columns) with 
NULL.
-    val rewrittenExpr = expr transform {
-      case a @ AggregateExpression(aggFunc, _, _, resultId) =>
-        aggFunc.defaultResult.getOrElse(Literal.default(NullType))
-
-      case AttributeReference(_, _, _, _) => Literal.default(NullType)
-    }
-    Option(rewrittenExpr.eval())
-  }
-
-  /**
-   * Statically evaluate a scalar subquery on an empty input.
-   *
-   * <b>WARNING:</b> This method only covers subqueries that pass the checks 
under
-   * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
-   * CheckAnalysis become less restrictive, this method will need to change.
-   */
-  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
-    // Inputs to this method will start with a chain of zero or more 
SubqueryAlias
-    // and Project operators, followed by an optional Filter, followed by an
-    // Aggregate. Traverse the operators recursively.
-    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = {
-      lp match {
-        case SubqueryAlias(_, child) => evalPlan(child)
-        case Filter(condition, child) =>
-          val bindings = evalPlan(child)
-          if (bindings.isEmpty) bindings
-          else {
-            val exprResult = evalExpr(condition, bindings).getOrElse(false)
-              .asInstanceOf[Boolean]
-            if (exprResult) bindings else Map.empty
-          }
-
-        case Project(projectList, child) =>
-          val bindings = evalPlan(child)
-          if (bindings.isEmpty) {
-            bindings
-          } else {
-            projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
-          }
-
-        case Aggregate(_, aggExprs, _) =>
-          // Some of the expressions under the Aggregate node are the join 
columns
-          // for joining with the outer query block. Fill those expressions in 
with
-          // nulls and statically evaluate the remainder.
-          aggExprs.map(ne => ne match {
-            case AttributeReference(_, _, _, _) => (ne.exprId, None)
-            case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None)
-            case _ => (ne.exprId, evalAggOnZeroTups(ne))
-          }).toMap
-
-        case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
-      }
-    }
-
-    val resultMap = evalPlan(plan)
-
-    // By convention, the scalar subquery result is the leftmost field.
-    resultMap(plan.output.head.exprId)
-  }
-
-  /**
-   * Split the plan for a scalar subquery into the parts above the innermost 
query block
-   * (first part of returned value), the HAVING clause of the innermost query 
block
-   * (optional second part) and the parts below the HAVING CLAUSE (third part).
-   */
-  private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], 
Option[Filter], Aggregate) = {
-    val topPart = ArrayBuffer.empty[LogicalPlan]
-    var bottomPart : LogicalPlan = plan
-    while (true) {
-      bottomPart match {
-        case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) =>
-          return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate])
-
-        case aggPart@Aggregate(_, _, _) =>
-          // No HAVING clause
-          return (topPart, None, aggPart)
-
-        case p@Project(_, child) =>
-          topPart += p
-          bottomPart = child
-
-        case s@SubqueryAlias(_, child) =>
-          topPart += s
-          bottomPart = child
-
-        case Filter(_, op@_) =>
-          sys.error(s"Correlated subquery has unexpected operator $op below 
filter")
-
-        case op@_ => sys.error(s"Unexpected operator $op in correlated 
subquery")
-      }
-    }
-
-    sys.error("This line should be unreachable")
-  }
-
-
-
-  // Name of generated column used in rewrite below
-  val ALWAYS_TRUE_COLNAME = "alwaysTrue"
-
-  /**
    * Construct a new child plan by left joining the given subqueries to a base plan.
    */
   private def constructLeftJoins(
@@ -1913,76 +1790,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
     subqueries.foldLeft(child) {
       case (currentChild, ScalarSubquery(query, conditions, _)) =>
-        val origOutput = query.output.head
-
-        val resultWithZeroTups = evalSubqueryOnZeroTups(query)
-        if (resultWithZeroTups.isEmpty) {
-          // CASE 1: Subquery guaranteed not to have the COUNT bug
-          Project(
-            currentChild.output :+ origOutput,
-            Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
-        } else {
-          // Subquery might have the COUNT bug. Add appropriate corrections.
-          val (topPart, havingNode, aggNode) = splitSubquery(query)
-
-          // The next two cases add a leading column to the outer join input 
to make it
-          // possible to distinguish between the case when no tuples join and 
the case
-          // when the tuple that joins contains null values.
-          // The leading column always has the value TRUE.
-          val alwaysTrueExprId = NamedExpression.newExprId
-          val alwaysTrueExpr = Alias(Literal.TrueLiteral,
-            ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
-          val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
-            BooleanType)(exprId = alwaysTrueExprId)
-
-          val aggValRef = query.output.head
-
-          if (!havingNode.isDefined) {
-            // CASE 2: Subquery with no HAVING clause
-            Project(
-              currentChild.output :+
-                Alias(
-                  If(IsNull(alwaysTrueRef),
-                    Literal(resultWithZeroTups.get, origOutput.dataType),
-                    aggValRef), origOutput.name)(exprId = origOutput.exprId),
-              Join(currentChild,
-                Project(query.output :+ alwaysTrueExpr, query),
-                LeftOuter, conditions.reduceOption(And)))
-
-          } else {
-            // CASE 3: Subquery with HAVING clause. Pull the HAVING clause 
above the join.
-            // Need to modify any operators below the join to pass through all 
columns
-            // referenced in the HAVING clause.
-            var subqueryRoot : UnaryNode = aggNode
-            val havingInputs : Seq[NamedExpression] = aggNode.output
-
-            topPart.reverse.foreach(
-              _ match {
-                case Project(projList, _) =>
-                  subqueryRoot = Project(projList ++ havingInputs, 
subqueryRoot)
-                case s@SubqueryAlias(alias, _) => subqueryRoot = 
SubqueryAlias(alias, subqueryRoot)
-                case op@_ => sys.error(s"Unexpected operator $op in corelated 
subquery")
-              }
-            )
-
-            // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
-            //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS 
<type of aggVal>)
-            //      ELSE (aggregate value) END AS (original column name)
-            val caseExpr = Alias(CaseWhen(
-              Seq[(Expression, Expression)] (
-                (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, 
origOutput.dataType)),
-                (Not(havingNode.get.condition), Literal(null, 
aggValRef.dataType))
-              ), aggValRef
-            ), origOutput.name) (exprId = origOutput.exprId)
-
-            Project(
-              currentChild.output :+ caseExpr,
-              Join(currentChild,
-                Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-                LeftOuter, conditions.reduceOption(And)))
-
-          }
-        }
+        Project(
+          currentChild.output :+ query.output.head,
+          Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/20b8f2c3/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 1d9ff21..1a99fb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -490,85 +490,4 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
         """.stripMargin),
       Row(3) :: Nil)
   }
-
-  test("SPARK-15370: COUNT bug in WHERE clause (Filter)") {
-    // Case 1: Canonical example of the COUNT bug
-    checkAnswer(
-      sql("select l.a from l where (select count(*) from r where l.a = r.c) < 
l.a"),
-      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
-    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently 
uses
-    // a rewrite that is vulnerable to the COUNT bug
-    checkAnswer(
-      sql("select l.a from l where (select count(*) from r where l.a = r.c) = 
0"),
-      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
-    // Case 3: COUNT bug without a COUNT aggregate
-    checkAnswer(
-      sql("select l.a from l where (select sum(r.d) is null from r where l.a = 
r.c)"),
-      Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
-  }
-
-  test("SPARK-15370: COUNT bug in SELECT clause (Project)") {
-    checkAnswer(
-      sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"),
-      Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: 
Row(null, 0)
-        :: Row(null, 0) :: Row(6, 1) :: Nil)
-  }
-
-  test("SPARK-15370: COUNT bug in HAVING clause (Filter)") {
-    checkAnswer(
-      sql("select l.a as grp_a from l group by l.a " +
-        "having (select count(*) from r where grp_a = r.c) = 0 " +
-        "order by grp_a"),
-      Row(null) :: Row(1) :: Nil)
-  }
-
-  test("SPARK-15370: COUNT bug in Aggregate") {
-    checkAnswer(
-      sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) 
as cnt " +
-        "from l group by l.a order by aval"),
-      Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1)  :: Nil)
-  }
-
-  test("SPARK-15370: COUNT bug negative examples") {
-    // Case 1: Potential COUNT bug case that was working correctly prior to 
the fix
-    checkAnswer(
-      sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is 
null"),
-      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
-    // Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
-    checkAnswer(
-      sql("select l.a from l where (select count(*) from r where l.a = r.c) > 
0"),
-      Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
-    // Case 3: COUNT inside aggregate expression but no COUNT bug.
-    checkAnswer(
-      sql("select l.a from l where (select count(*) + sum(r.d) from r where 
l.a = r.c) = 0"),
-      Nil)
-  }
-
-  test("SPARK-15370: COUNT bug in subquery in subquery in subquery") {
-    checkAnswer(
-      sql("""select l.a from l
-            |where (
-            |    select cntPlusOne + 1 as cntPlusTwo from (
-            |        select cnt + 1 as cntPlusOne from (
-            |            select sum(r.c) s, count(*) cnt from r where l.a = 
r.c having cnt = 0
-            |        )
-            |    )
-            |) = 2""".stripMargin),
-      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
-  }
-
-  test("SPARK-15370: COUNT bug with nasty predicate expr") {
-    checkAnswer(
-      sql("select l.a from l where " +
-        "(select case when count(*) = 1 then null else count(*) end as cnt " +
-        "from r where l.a = r.c) = 0"),
-      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
-  }
-
-  test("SPARK-15370: COUNT bug with attribute ref in subquery input and output 
") {
-    checkAnswer(
-      sql("select l.b, (select (r.c + count(*)) is null from r where l.a = 
r.c) from l"),
-      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
-        Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, 
true) :: Nil)
-  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to