Repository: spark
Updated Branches:
  refs/heads/master e249e6f8b -> 222dcf793


[SPARK-12660][SPARK-14967][SQL] Implement Except Distinct by Left Anti Join

#### What changes were proposed in this pull request?
Replaces the logical `Except` operator with a left-anti `Join` operator. This 
way, `EXCEPT` benefits from the join implementations (e.g. managed memory, 
code generation, broadcast joins).
```SQL
  SELECT a1, a2 FROM Tab1 EXCEPT SELECT b1, b2 FROM Tab2
  ==>  SELECT DISTINCT a1, a2 FROM Tab1 LEFT ANTI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
```
 Note:
 1. This rule is only applicable to EXCEPT DISTINCT. Do not use it for EXCEPT ALL.
 2. This rule must run after the attributes have been de-duplicated; otherwise,
    the generated join conditions will be incorrect.

This PR also corrects the existing behavior in Spark. Before this PR, duplicate 
rows from the left side were kept, which violates `EXCEPT DISTINCT` semantics:
```scala
  test("except") {
    val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
    val df_right = Seq(1, 3).toDF("id")

    checkAnswer(
      df_left.except(df_right),
      Row(2) :: Row(2) :: Row(4) :: Nil
    )
  }
```
After this PR, the result is corrected to `Row(2) :: Row(4) :: Nil`; the 
operator now strictly follows the SQL semantics of `EXCEPT DISTINCT`.

#### How was this patch tested?
Modified existing test cases and added new ones to verify the optimization rule 
and the results of the operators.

Author: gatorsmile <[email protected]>

Closes #12736 from gatorsmile/exceptByAntiJoin.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/222dcf79
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/222dcf79
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/222dcf79

Branch: refs/heads/master
Commit: 222dcf79377df33007d7a9780dafa2c740dbe6a3
Parents: e249e6f
Author: gatorsmile <[email protected]>
Authored: Fri Apr 29 15:30:36 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Apr 29 15:30:36 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  2 +
 .../sql/catalyst/analysis/CheckAnalysis.scala   | 11 ++-
 .../sql/catalyst/optimizer/Optimizer.scala      | 60 ++++++++---------
 .../plans/logical/basicLogicalOperators.scala   |  6 +-
 .../analysis/HiveTypeCoercionSuite.scala        |  8 ---
 .../optimizer/FilterPushdownSuite.scala         | 34 ----------
 .../optimizer/ReplaceOperatorSuite.scala        | 17 ++++-
 .../catalyst/optimizer/SetOperationSuite.scala  | 16 -----
 .../spark/sql/execution/SparkStrategies.scala   |  5 +-
 .../sql/execution/basicPhysicalOperators.scala  | 12 ----
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 70 ++++++++++++++++++--
 12 files changed, 132 insertions(+), 111 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e37d976..f6a65f7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -530,6 +530,8 @@ class Analyzer(
         j.copy(right = dedupRight(left, right))
       case i @ Intersect(left, right) if !i.duplicateResolved =>
         i.copy(right = dedupRight(left, right))
+      case i @ Except(left, right) if !i.duplicateResolved =>
+        i.copy(right = dedupRight(left, right))
 
       // When resolve `SortOrder`s in Sort based on child, don't report errors 
as
       // we still have chance to resolve it based on its descendants

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6b737d6..74f434e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -283,7 +283,16 @@ trait CheckAnalysis extends PredicateHelper {
                  |Failure when resolving conflicting references in Intersect:
                  |$plan
                  |Conflicting attributes: 
${conflictingAttributes.mkString(",")}
-                 |""".stripMargin)
+               """.stripMargin)
+
+          case e: Except if !e.duplicateResolved =>
+            val conflictingAttributes = 
e.left.outputSet.intersect(e.right.outputSet)
+            failAnalysis(
+              s"""
+                 |Failure when resolving conflicting references in Except:
+                 |$plan
+                 |Conflicting attributes: 
${conflictingAttributes.mkString(",")}
+               """.stripMargin)
 
           case o if !o.resolved =>
             failAnalysis(

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 54bf4a5..434c033 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -65,6 +65,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, 
conf: CatalystConf)
       CombineUnions) ::
     Batch("Replace Operators", fixedPoint,
       ReplaceIntersectWithSemiJoin,
+      ReplaceExceptWithAntiJoin,
       ReplaceDistinctWithAggregate) ::
     Batch("Aggregate", fixedPoint,
       RemoveLiteralFromGroupExpressions) ::
@@ -232,17 +233,12 @@ object LimitPushDown extends Rule[LogicalPlan] {
 }
 
 /**
- * Pushes certain operations to both sides of a Union or Except operator.
+ * Pushes certain operations to both sides of a Union operator.
  * Operations that are safe to pushdown are listed as follows.
  * Union:
  * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it 
is
  * safe to pushdown Filters and Projections through it. Once we add UNION 
DISTINCT,
  * we will not be able to pushdown Projections.
- *
- * Except:
- * It is not safe to pushdown Projections through it because we need to get the
- * intersect of rows by comparing the entire rows. It is fine to pushdown 
Filters
- * with deterministic condition.
  */
 object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
@@ -310,17 +306,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with 
PredicateHelper {
         Filter(pushToRight(deterministic, rewrites), child)
       }
       Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
-
-    // Push down filter through EXCEPT
-    case Filter(condition, Except(left, right)) =>
-      val (deterministic, nondeterministic) = 
partitionByDeterministic(condition)
-      val rewrites = buildRewrites(left, right)
-      Filter(nondeterministic,
-        Except(
-          Filter(deterministic, left),
-          Filter(pushToRight(deterministic, rewrites), right)
-        )
-      )
   }
 }
 
@@ -1007,16 +992,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with 
PredicateHelper {
         filter
       }
 
-    case filter @ Filter(condition, child)
-      if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
-      // Union/Intersect could change the rows, so non-deterministic predicate 
can't be pushed down
+    case filter @ Filter(condition, union: Union) =>
+      // Union could change the rows, so non-deterministic predicate can't be 
pushed down
       val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition 
{ cond =>
         cond.deterministic
       }
       if (pushDown.nonEmpty) {
         val pushDownCond = pushDown.reduceLeft(And)
-        val output = child.output
-        val newGrandChildren = child.children.map { grandchild =>
+        val output = union.output
+        val newGrandChildren = union.children.map { grandchild =>
           val newCond = pushDownCond transform {
             case e if output.exists(_.semanticEquals(e)) =>
               grandchild.output(output.indexWhere(_.semanticEquals(e)))
@@ -1024,21 +1008,16 @@ object PushDownPredicate extends Rule[LogicalPlan] with 
PredicateHelper {
           assert(newCond.references.subsetOf(grandchild.outputSet))
           Filter(newCond, grandchild)
         }
-        val newChild = child.withNewChildren(newGrandChildren)
+        val newUnion = union.withNewChildren(newGrandChildren)
         if (stayUp.nonEmpty) {
-          Filter(stayUp.reduceLeft(And), newChild)
+          Filter(stayUp.reduceLeft(And), newUnion)
         } else {
-          newChild
+          newUnion
         }
       } else {
         filter
       }
 
-    case filter @ Filter(condition, e @ Except(left, _)) =>
-      pushDownPredicate(filter, e.left) { predicate =>
-        e.copy(left = Filter(predicate, left))
-      }
-
     // two filters should be combine together by other rules
     case filter @ Filter(_, f: Filter) => filter
     // should not push predicates through sample, or will generate different 
results.
@@ -1423,6 +1402,27 @@ object ReplaceIntersectWithSemiJoin extends 
Rule[LogicalPlan] {
 }
 
 /**
+ * Replaces logical [[Except]] operator with a left-anti [[Join]] operator.
+ * {{{
+ *   SELECT a1, a2 FROM Tab1 EXCEPT SELECT b1, b2 FROM Tab2
+ *   ==>  SELECT DISTINCT a1, a2 FROM Tab1 LEFT ANTI JOIN Tab2 ON a1<=>b1 AND 
a2<=>b2
+ * }}}
+ *
+ * Note:
+ * 1. This rule is only applicable to EXCEPT DISTINCT. Do not use it for 
EXCEPT ALL.
+ * 2. This rule has to be done after de-duplicating the attributes; otherwise, 
the generated
+ *    join conditions will be incorrect.
+ */
+object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Except(left, right) =>
+      assert(left.output.size == right.output.size)
+      val joinCond = left.output.zip(right.output).map { case (l, r) => 
EqualNullSafe(l, r) }
+      Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
+  }
+}
+
+/**
  * Removes literals from group expressions in [[Aggregate]], as they have no 
effect to the result
  * but only makes the grouping key bigger.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a445ce6..b358e21 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -165,6 +165,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) 
extends SetOperation
 }
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends 
SetOperation(left, right) {
+
+  def duplicateResolved: Boolean = 
left.outputSet.intersect(right.outputSet).isEmpty
+
   /** We don't use right.output because those rows get excluded from the set. 
*/
   override def output: Seq[Attribute] = left.output
 
@@ -173,7 +176,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) 
extends SetOperation(le
   override lazy val resolved: Boolean =
     childrenResolved &&
       left.output.length == right.output.length &&
-      left.output.zip(right.output).forall { case (l, r) => l.dataType == 
r.dataType }
+      left.output.zip(right.output).forall { case (l, r) => l.dataType == 
r.dataType } &&
+      duplicateResolved
 
   override def statistics: Statistics = {
     Statistics(sizeInBytes = left.statistics.sizeInBytes)

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 18de8b1..b591861 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -488,14 +488,6 @@ class HiveTypeCoercionSuite extends PlanTest {
     assert(r1.right.isInstanceOf[Project])
     assert(r2.left.isInstanceOf[Project])
     assert(r2.right.isInstanceOf[Project])
-
-    val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except]
-    checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, 
ByteType, DoubleType))
-    checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, 
ByteType, DoubleType))
-
-    // Check if no Project is added
-    assert(r3.left.isInstanceOf[LocalRelation])
-    assert(r3.right.isInstanceOf[LocalRelation])
   }
 
   test("WidenSetOperationTypes for union") {

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index e2cc80c..e9b4bb0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -710,40 +710,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("intersect") {
-    val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
-
-    val originalQuery = Intersect(testRelation, testRelation2)
-      .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-
-    val correctAnswer = Intersect(
-      testRelation.where('a === 2L),
-      testRelation2.where('d === 2L))
-      .where('b + Rand(10).as("rnd") === 3)
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("except") {
-    val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
-
-    val originalQuery = Except(testRelation, testRelation2)
-      .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-
-    val correctAnswer = Except(
-      testRelation.where('a === 2L),
-      testRelation2)
-      .where('b + Rand(10).as("rnd") === 3)
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
   test("expand") {
     val agg = testRelation
       .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index f8ae5d9..f23e262 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
@@ -29,6 +29,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val batches =
       Batch("Replace Operators", FixedPoint(100),
         ReplaceDistinctWithAggregate,
+        ReplaceExceptWithAntiJoin,
         ReplaceIntersectWithSemiJoin) :: Nil
   }
 
@@ -46,6 +47,20 @@ class ReplaceOperatorSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("replace Except with Left-anti Join") {
+    val table1 = LocalRelation('a.int, 'b.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+
+    val query = Except(table1, table2)
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer =
+      Aggregate(table1.output, table1.output,
+        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("replace Distinct with Aggregate") {
     val input = LocalRelation('a.int, 'b.int)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index b08cdc8..83ca9d5 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest {
   val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
   val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
   val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
-  val testExcept = Except(testRelation, testRelation2)
 
   test("union: combine unions into one unions") {
     val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
@@ -56,15 +55,6 @@ class SetOperationSuite extends PlanTest {
     comparePlans(combinedUnionsOptimized, unionOptimized3)
   }
 
-  test("except: filter to each side") {
-    val exceptQuery = testExcept.where('c >= 5)
-    val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-    val exceptCorrectAnswer =
-      Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
-
-    comparePlans(exceptOptimized, exceptCorrectAnswer)
-  }
-
   test("union: filter to each side") {
     val unionQuery = testUnion.where('a === 1)
     val unionOptimized = Optimize.execute(unionQuery.analyze)
@@ -85,10 +75,4 @@ class SetOperationSuite extends PlanTest {
         testRelation3.select('g) :: Nil).analyze
     comparePlans(unionOptimized, unionCorrectAnswer)
   }
-
-  test("SPARK-10539: Project should not be pushed down through Intersect or 
Except") {
-    val exceptQuery = testExcept.select('a, 'b, 'c)
-    val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-    comparePlans(exceptOptimized, exceptQuery.analyze)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3955c5d..1eb1f8e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -297,6 +297,9 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case logical.Intersect(left, right) =>
         throw new IllegalStateException(
           "logical intersect operator should have been replaced by semi-join 
in the optimizer")
+      case logical.Except(left, right) =>
+        throw new IllegalStateException(
+          "logical except operator should have been replaced by anti-join in 
the optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>
         execution.DeserializeToObject(deserializer, objAttr, planLater(child)) 
:: Nil
@@ -347,8 +350,6 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.GlobalLimitExec(limit, planLater(child)) :: Nil
       case logical.Union(unionChildren) =>
         execution.UnionExec(unionChildren.map(planLater)) :: Nil
-      case logical.Except(left, right) =>
-        execution.ExceptExec(planLater(left), planLater(right)) :: Nil
       case g @ logical.Generate(generator, join, outer, _, _, child) =>
         execution.GenerateExec(
           generator, join = join, outer = outer, g.output, planLater(child)) 
:: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 77be613..d492fa7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -492,18 +492,6 @@ case class CoalesceExec(numPartitions: Int, child: 
SparkPlan) extends UnaryExecN
 }
 
 /**
- * Physical plan for returning a table with the elements from left that are 
not in right using
- * the built-in spark subtract function.
- */
-case class ExceptExec(left: SparkPlan, right: SparkPlan) extends 
BinaryExecNode {
-  override def output: Seq[Attribute] = left.output
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    left.execute().map(_.copy()).subtract(right.execute().map(_.copy()))
-  }
-}
-
-/**
  * A plan node that does nothing but lie about the output of its child.  Used 
to spice a
  * (hopefully structurally equivalent) tree from a different optimization 
sequence into an already
  * resolved tree.

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 5abd62c..f1b1c22 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -291,7 +291,7 @@ public class JavaDatasetSuite implements Serializable {
       unioned.collectAsList());
 
     Dataset<String> subtracted = ds.except(ds2);
-    Assert.assertEquals(Arrays.asList("abc", "abc"), 
subtracted.collectAsList());
+    Assert.assertEquals(Arrays.asList("abc"), subtracted.collectAsList());
   }
 
   private static <T> Set<T> toSet(List<T> records) {

http://git-wip-us.apache.org/repos/asf/spark/blob/222dcf79/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 681476b..f10d837 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -398,6 +398,66 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
       Row(4, "d") :: Nil)
     checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
     checkAnswer(upperCaseData.except(upperCaseData), Nil)
+
+    // check null equality
+    checkAnswer(
+      nullInts.except(nullInts.filter("0 = 1")),
+      nullInts)
+    checkAnswer(
+      nullInts.except(nullInts),
+      Nil)
+
+    // check if values are de-duplicated
+    checkAnswer(
+      allNulls.except(allNulls.filter("0 = 1")),
+      Row(null) :: Nil)
+    checkAnswer(
+      allNulls.except(allNulls),
+      Nil)
+
+    // check if values are de-duplicated
+    val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", 
"value")
+    checkAnswer(
+      df.except(df.filter("0 = 1")),
+      Row("id1", 1) ::
+      Row("id", 1) ::
+      Row("id1", 2) :: Nil)
+
+    // check if the empty set on the left side works
+    checkAnswer(
+      allNulls.filter("0 = 1").except(allNulls),
+      Nil)
+  }
+
+  test("except distinct - SQL compliance") {
+    val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
+    val df_right = Seq(1, 3).toDF("id")
+
+    checkAnswer(
+      df_left.except(df_right),
+      Row(2) :: Row(4) :: Nil
+    )
+  }
+
+  test("except - nullability") {
+    val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF()
+    assert(nonNullableInts.schema.forall(!_.nullable))
+
+    val df1 = nonNullableInts.except(nullInts)
+    checkAnswer(df1, Row(11) :: Nil)
+    assert(df1.schema.forall(!_.nullable))
+
+    val df2 = nullInts.except(nonNullableInts)
+    checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil)
+    assert(df2.schema.forall(_.nullable))
+
+    val df3 = nullInts.except(nullInts)
+    checkAnswer(df3, Nil)
+    assert(df3.schema.forall(_.nullable))
+
+    val df4 = nonNullableInts.except(nonNullableInts)
+    checkAnswer(df4, Nil)
+    assert(df4.schema.forall(!_.nullable))
   }
 
   test("intersect") {
@@ -433,23 +493,23 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
 
   test("intersect - nullability") {
     val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF()
-    assert(nonNullableInts.schema.forall(_.nullable == false))
+    assert(nonNullableInts.schema.forall(!_.nullable))
 
     val df1 = nonNullableInts.intersect(nullInts)
     checkAnswer(df1, Row(1) :: Row(3) :: Nil)
-    assert(df1.schema.forall(_.nullable == false))
+    assert(df1.schema.forall(!_.nullable))
 
     val df2 = nullInts.intersect(nonNullableInts)
     checkAnswer(df2, Row(1) :: Row(3) :: Nil)
-    assert(df2.schema.forall(_.nullable == false))
+    assert(df2.schema.forall(!_.nullable))
 
     val df3 = nullInts.intersect(nullInts)
     checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
-    assert(df3.schema.forall(_.nullable == true))
+    assert(df3.schema.forall(_.nullable))
 
     val df4 = nonNullableInts.intersect(nonNullableInts)
     checkAnswer(df4, Row(1) :: Row(3) :: Nil)
-    assert(df4.schema.forall(_.nullable == false))
+    assert(df4.schema.forall(!_.nullable))
   }
 
   test("udf") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to