Repository: spark
Updated Branches:
  refs/heads/branch-1.6 2561976dc -> 21e63929c


[SPARK-10707][SQL] Fix nullability computation in union output

Author: Mikhail Bautin <[email protected]>

Closes #9308 from mbautin/SPARK-10707.

(cherry picked from commit 4021a28ac30b65cb61cf1e041253847253a2d89f)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21e63929
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21e63929
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21e63929

Branch: refs/heads/branch-1.6
Commit: 21e63929c96f51b31b1dd929013ff609a6ee3072
Parents: 2561976
Author: Mikhail Bautin <[email protected]>
Authored: Mon Nov 23 22:26:08 2015 -0800
Committer: Reynold Xin <[email protected]>
Committed: Mon Nov 23 22:26:14 2015 -0800

----------------------------------------------------------------------
 .../catalyst/plans/logical/basicOperators.scala | 11 +++++--
 .../spark/sql/execution/basicOperators.scala    |  9 ++++--
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 31 ++++++++++++++++++++
 3 files changed, 46 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21e63929/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0c44448..737e62f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -92,8 +92,10 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
 }
 
 abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
-  // TODO: These aren't really the same attributes as nullability etc might change.
-  final override def output: Seq[Attribute] = left.output
+  override def output: Seq[Attribute] =
+    left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
+      leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
+    }
 
   final override lazy val resolved: Boolean =
     childrenResolved &&
@@ -115,7 +117,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) 
extends SetOperation(lef
 
 case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
 
-case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+  /** We don't use right.output because those rows get excluded from the set. */
+  override def output: Seq[Attribute] = left.output
+}
 
 case class Join(
   left: LogicalPlan,

http://git-wip-us.apache.org/repos/asf/spark/blob/21e63929/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e79092e..d57b8e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -130,8 +130,13 @@ case class Sample(
  * Union two plans, without a distinct. This is UNION ALL in SQL.
  */
 case class Union(children: Seq[SparkPlan]) extends SparkPlan {
-  // TODO: attributes output by union should be distinct for nullability purposes
-  override def output: Seq[Attribute] = children.head.output
+  override def output: Seq[Attribute] = {
+    children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
+      currentOutput.zip(child.output).map { case (a1, a2) =>
+        a1.withNullability(a1.nullable || a2.nullable)
+      }
+    }
+  }
   override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows)
   override def canProcessUnsafeRows: Boolean = true
   override def canProcessSafeRows: Boolean = true

http://git-wip-us.apache.org/repos/asf/spark/blob/21e63929/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 167aea8..bb82b56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1997,4 +1997,35 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
     verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
   }
+
+  test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
+    // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the
+    // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of
+    // UNION was incorrectly considered non-nullable:
+    checkAnswer(
+      sql("""SELECT count(v) FROM (
+            |  SELECT v FROM (
+            |    SELECT 'foo' AS v UNION ALL
+            |    SELECT NULL AS v
+            |  ) my_union WHERE isnull(v)
+            |) my_subview""".stripMargin),
+      Seq(Row(0)))
+  }
+
+  test("SPARK-10707: nullability should be correctly propagated through set operations (2)") {
+    // This test uses RAND() to stop column pruning for Union and checks the resulting isnull
+    // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v"
+    // column of the union was considered non-nullable.
+    checkAnswer(
+      sql(
+        """
+          |SELECT a FROM (
+          |  SELECT ISNULL(v) AS a, RAND() FROM (
+          |    SELECT 'foo' AS v UNION ALL SELECT null AS v
+          |  ) my_union
+          |) my_view
+        """.stripMargin),
+      Row(false) :: Row(true) :: Nil)
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to