Repository: spark
Updated Branches:
  refs/heads/branch-2.0 05bb5b6f6 -> fbc73f731


[SPARK-14785] [SQL] Support correlated scalar subqueries

## What changes were proposed in this pull request?
In this PR we add support for correlated scalar subqueries. An example of such 
a query is:
```SQL
select * from tbl1 a
where a.value > (select max(value) from tbl2 b where b.key = a.key)
```
The implementation adds the `RewriteCorrelatedScalarSubquery` rule to the
Optimizer. This rule rewrites these subqueries into `LEFT OUTER` joins. It
currently supports rewriting subqueries in `Project`, `Aggregate`, and `Filter`
logical plans.

I could not find well-defined semantics for the use of scalar subqueries in
an `Aggregate`. The current implementation evaluates the scalar subquery
*before* aggregation. This means that you either have to make the scalar
subquery part of the grouping expressions, or aggregate its result further.
I am open to suggestions on this.

The implementation currently enforces the uniqueness of a scalar subquery's
result by requiring that the subquery is aggregated, and that the resulting
column is wrapped in an `AggregateExpression`.

## How was this patch tested?
Added tests to `SubquerySuite` and updated the affected error-message tests in
`AnalysisErrorSuite`.

Author: Herman van Hovell <[email protected]>

Closes #12822 from hvanhovell/SPARK-14785.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fbc73f73
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fbc73f73
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fbc73f73

Branch: refs/heads/branch-2.0
Commit: fbc73f73186873cfd60581e58aff4a8d919e39b4
Parents: 05bb5b6
Author: Herman van Hovell <[email protected]>
Authored: Mon May 2 16:32:31 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Mon May 2 16:33:51 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 11 ++-
 .../sql/catalyst/analysis/CheckAnalysis.scala   | 42 +++++++++-
 .../sql/catalyst/expressions/subquery.scala     | 39 ++++++----
 .../sql/catalyst/optimizer/Optimizer.scala      | 82 ++++++++++++++++++--
 .../plans/logical/basicLogicalOperators.scala   |  2 +-
 .../catalyst/analysis/AnalysisErrorSuite.scala  | 11 +--
 .../org/apache/spark/sql/SubquerySuite.scala    | 47 +++++++++++
 7 files changed, 195 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2f8ab3f..59af5b7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1081,10 +1081,10 @@ class Analyzer(
       // Step 2: Pull out the predicates if the plan is resolved.
       if (current.resolved) {
         // Make sure the resolved query has the required number of output 
columns. This is only
-        // needed for IN expressions.
+        // needed for Scalar and IN subqueries.
         if (requiredColumns > 0 && requiredColumns != current.output.size) {
-          failAnalysis(s"The number of fields in the value ($requiredColumns) 
does not " +
-            s"match with the number of columns in the subquery 
(${current.output.size})")
+          failAnalysis(s"The number of columns in the subquery 
(${current.output.size}) " +
+            s"does not match the required number of columns 
($requiredColumns)")
         }
         // Pullout predicates and construct a new plan.
         f.tupled(rewriteSubQuery(current, plans))
@@ -1099,8 +1099,11 @@ class Analyzer(
      */
     private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): 
LogicalPlan = {
       plan transformExpressions {
+        case s @ ScalarSubquery(sub, conditions, exprId)
+            if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
+          failAnalysis(s"Scalar subquery must return only one column, but got 
${sub.output.size}")
         case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
-          resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
+          resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, exprId) =>
           resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, 
exprId))
         case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6e3a14d..800bf01 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
+import org.apache.spark.sql.catalyst.plans.UsingJoin
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -60,9 +60,6 @@ trait CheckAnalysis extends PredicateHelper {
             val from = operator.inputSet.map(_.name).mkString(", ")
             a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: 
[$from]")
 
-          case ScalarSubquery(_, conditions, _) if conditions.nonEmpty =>
-            failAnalysis("Correlated scalar subqueries are not supported.")
-
           case e: Expression if e.checkInputDataTypes().isFailure =>
             e.checkInputDataTypes() match {
               case TypeCheckResult.TypeCheckFailure(message) =>
@@ -104,6 +101,36 @@ trait CheckAnalysis extends PredicateHelper {
                 failAnalysis(s"Window specification $s is not valid because 
$m")
               case None => w
             }
+
+          case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty 
=>
+            // Make sure we are using equi-joins.
+            conditions.foreach {
+              case _: EqualTo | _: EqualNullSafe => // ok
+              case e => failAnalysis(
+                s"The correlated scalar subquery can only contain equality 
predicates: $e")
+            }
+
+            // Make sure correlated scalar subqueries contain one row for 
every outer row by
+            // enforcing that they are aggregates which contain exactly one 
aggregate expressions.
+            // The analyzer has already checked that subquery contained only 
one output column, and
+            // added all the grouping expressions to the aggregate.
+            def checkAggregate(a: Aggregate): Unit = {
+              val aggregates = a.expressions.flatMap(_.collect {
+                case a: AggregateExpression => a
+              })
+              if (aggregates.isEmpty) {
+                failAnalysis("The output of a correlated scalar subquery must 
be aggregated")
+              }
+            }
+
+            query match {
+              case a: Aggregate => checkAggregate(a)
+              case Filter(_, a: Aggregate) => checkAggregate(a)
+              case Project(_, a: Aggregate) => checkAggregate(a)
+              case Project(_, Filter(_, a: Aggregate)) => checkAggregate(a)
+              case fail => failAnalysis(s"Correlated scalar subqueries must be 
Aggregated: $fail")
+            }
+            s
         }
 
         operator match {
@@ -220,6 +247,13 @@ trait CheckAnalysis extends PredicateHelper {
                 | but one table has '${firstError.output.length}' columns and 
another table has
                 | '${s.children.head.output.length}' columns""".stripMargin)
 
+          case p if 
p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
+            p match {
+              case _: Filter | _: Aggregate | _: Project => // Ok
+              case other => failAnalysis(
+                s"Correlated scalar sub-queries can only be used in a 
Filter/Aggregate/Project: $p")
+            }
+
           case p if 
p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
             failAnalysis(s"Predicate sub-queries can only be used in a Filter: 
$p")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index eed062f..5001f9a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression {
   protected def conditionString: String = children.mkString("[", " && ", "]")
 }
 
+object SubqueryExpression {
+  def hasCorrelatedSubquery(e: Expression): Boolean = {
+    e.find {
+      case e: SubqueryExpression if e.children.nonEmpty => true
+      case _ => false
+    }.isDefined
+  }
+}
+
 /**
  * A subquery that will return only one row and one column. This will be 
converted into a physical
  * scalar subquery during planning.
@@ -55,28 +64,26 @@ case class ScalarSubquery(
     children: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId)
   extends SubqueryExpression with Unevaluable {
-
-  override def plan: LogicalPlan = SubqueryAlias(toString, query)
-
   override lazy val resolved: Boolean = childrenResolved && query.resolved
-
-  override def dataType: DataType = query.schema.fields.head.dataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (query.schema.length != 1) {
-      TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one 
column, but got " +
-        query.schema.length.toString)
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
+  override lazy val references: AttributeSet = {
+    if (query.resolved) super.references -- query.outputSet
+    else super.references
   }
-
+  override def dataType: DataType = query.schema.fields.head.dataType
   override def foldable: Boolean = false
   override def nullable: Boolean = true
-
+  override def plan: LogicalPlan = SubqueryAlias(toString, query)
   override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = 
plan)
+  override def toString: String = s"scalar-subquery#${exprId.id} 
$conditionString"
+}
 
-  override def toString: String = s"subquery#${exprId.id} $conditionString"
+object ScalarSubquery {
+  def hasCorrelatedScalarSubquery(e: Expression): Boolean = {
+    e.find {
+      case e: ScalarSubquery if e.children.nonEmpty => true
+      case _ => false
+    }.isDefined
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e1c969f..a3ab89d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
 import scala.collection.immutable.HashSet
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, 
DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
@@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, 
conf: CatalystConf)
       EliminateSorts,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
+      RewriteCorrelatedScalarSubquery,
       EliminateSerialization) ::
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
@@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
     assert(input.size >= 2)
     if (input.size == 2) {
       val (joinConditions, others) = conditions.partition(
-        e => !PredicateSubquery.hasPredicateSubquery(e))
+        e => !SubqueryExpression.hasCorrelatedSubquery(e))
       val join = Join(input(0), input(1), Inner, 
joinConditions.reduceLeftOption(And))
       if (others.nonEmpty) {
         Filter(others.reduceLeft(And), join)
@@ -1101,7 +1103,7 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
 
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && 
!PredicateSubquery.hasPredicateSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && 
!SubqueryExpression.hasCorrelatedSubquery(e))
       val joined = Join(left, right, Inner, 
joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan
@@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] 
with PredicateHelper {
    * Returns whether the expression returns null or false when all inputs are 
nulls.
    */
   private def canFilterOutNull(e: Expression): Boolean = {
-    if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return 
false
+    if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) 
return false
     val attributes = e.references.toSeq
     val emptyRow = new GenericInternalRow(attributes.length)
     val v = BindReferences.bindReference(e, attributes).eval(emptyRow)
@@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
     case f @ Filter(filterCondition, Join(left, right, joinType, 
joinCondition)) =>
       val (leftFilterConditions, rightFilterConditions, commonFilterCondition) 
=
         split(splitConjunctivePredicates(filterCondition), left, right)
-
       joinType match {
         case Inner =>
           // push down the single side `where` condition into respective sides
@@ -1212,7 +1213,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(e => 
!PredicateSubquery.hasPredicateSubquery(e))
+            commonFilterCondition.partition(e => 
!SubqueryExpression.hasCorrelatedSubquery(e))
           val newJoinCond = (newJoinConditions ++ 
joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, Inner, newJoinCond)
@@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends 
Rule[LogicalPlan] with PredicateHelper {
       }
   }
 }
+
+/**
+ * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT 
OUTER joins.
+ */
+object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
+  /**
+   * Extract all correlated scalar subqueries from an expression. The 
subqueries are collected using
+   * the given collector. The expression is rewritten and returned.
+   */
+  private def extractCorrelatedScalarSubqueries[E <: Expression](
+      expression: E,
+      subqueries: ArrayBuffer[ScalarSubquery]): E = {
+    val newExpression = expression transform {
+      case s: ScalarSubquery if s.children.nonEmpty =>
+        subqueries += s
+        s.query.output.head
+    }
+    newExpression.asInstanceOf[E]
+  }
+
+  /**
+   * Construct a new child plan by left joining the given subqueries to a base 
plan.
+   */
+  private def constructLeftJoins(
+      child: LogicalPlan,
+      subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
+    subqueries.foldLeft(child) {
+      case (currentChild, ScalarSubquery(query, conditions, _)) =>
+        Project(
+          currentChild.output :+ query.output.head,
+          Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+    }
+  }
+
+  /**
+   * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing 
correlated scalar
+   * subqueries.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a @ Aggregate(grouping, expressions, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newExpressions = 
expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      if (subqueries.nonEmpty) {
+        // We currently only allow correlated subqueries in an aggregate if 
they are part of the
+        // grouping expressions. As a result we need to replace all the scalar 
subqueries in the
+        // grouping expressions by their result.
+        val newGrouping = grouping.map { e =>
+          
subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
+        }
+        Aggregate(newGrouping, newExpressions, constructLeftJoins(child, 
subqueries))
+      } else {
+        a
+      }
+    case p @ Project(expressions, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newExpressions = 
expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      if (subqueries.nonEmpty) {
+        Project(newExpressions, constructLeftJoins(child, subqueries))
+      } else {
+        p
+      }
+    case f @ Filter(condition, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newCondition = extractCorrelatedScalarSubqueries(condition, 
subqueries)
+      if (subqueries.nonEmpty) {
+        Project(f.output, Filter(newCondition, constructLeftJoins(child, 
subqueries)))
+      } else {
+        f
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 830a7ac..7b4615d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan)
 
   override protected def validConstraints: Set[Expression] = {
     val predicates = splitConjunctivePredicates(condition)
-      .filterNot(PredicateSubquery.hasPredicateSubquery)
+      .filterNot(SubqueryExpression.hasCorrelatedSubquery)
     child.constraints.union(predicates.toSet)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 10bff3d..2e88f61 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -111,7 +111,8 @@ class AnalysisErrorSuite extends AnalysisTest {
     "scalar subquery with 2 columns",
      testRelation.select(
        (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + 
Literal(1)).as('a)),
-     "Scalar subquery must return only one column, but got 2" :: Nil)
+       "The number of columns in the subquery (2)" ::
+       "does not match the required number of columns (1)":: Nil)
 
   errorTest(
     "scalar subquery with no column",
@@ -499,12 +500,4 @@ class AnalysisErrorSuite extends AnalysisTest {
       LocalRelation(a))
     assertAnalysisError(plan3, "Accessing outer query column is not allowed 
in" :: Nil)
   }
-
-  test("Correlated Scalar Subquery") {
-    val a = AttributeReference("a", IntegerType)()
-    val b = AttributeReference("b", IntegerType)()
-    val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), 
LocalRelation(b)))
-    val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), 
LocalRelation(a))
-    assertAnalysisError(plan, "Correlated scalar subqueries are not 
supported." :: Nil)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc73f73/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index ff3f9bb..80bb4e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -234,4 +234,51 @@ class SubquerySuite extends QueryTest with 
SharedSQLContext {
       sql("select a from l group by 1 having exists (select 1 from r where d < 
min(b))"),
       Row(null) :: Row(1) :: Row(3) :: Nil)
   }
+
+  test("correlated scalar subquery in where") {
+    checkAnswer(
+      sql("select * from l where b < (select max(d) from r where a = c)"),
+      Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+  }
+
+  test("correlated scalar subquery in select") {
+    checkAnswer(
+      sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l 
l1"),
+      Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) 
::
+      Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil)
+  }
+
+  test("correlated scalar subquery in select (null safe)") {
+    checkAnswer(
+      sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from 
l l1"),
+      Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) 
::
+        Row(null, 5.0) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+  }
+
+  test("correlated scalar subquery in aggregate") {
+    checkAnswer(
+      sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group 
by 1, 2"),
+      Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, 
null) :: Nil)
+  }
+
+  test("non-aggregated correlated scalar subquery") {
+    val msg1 = intercept[AnalysisException] {
+      sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1")
+    }
+    assert(msg1.getMessage.contains("Correlated scalar subqueries must be 
Aggregated"))
+
+    val msg2 = intercept[AnalysisException] {
+      sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b 
from l l1")
+    }
+    assert(msg2.getMessage.contains(
+      "The output of a correlated scalar subquery must be aggregated"))
+  }
+
+  test("non-equal correlated scalar subquery") {
+    val msg1 = intercept[AnalysisException] {
+      sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1")
+    }
+    assert(msg1.getMessage.contains(
+      "The correlated scalar subquery can only contain equality predicates"))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to