This is an automated email from the ASF dual-hosted git repository.
lixiao pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 63b7a07 [SPARK-26366][SQL] ReplaceExceptWithFilter should consider NULL as False
63b7a07 is described below
commit 63b7a074ea2fd223b03a71588a237333ae279d1f
Author: Marco Gaido <[email protected]>
AuthorDate: Tue Dec 18 23:21:52 2018 -0800
[SPARK-26366][SQL] ReplaceExceptWithFilter should consider NULL as False
## What changes were proposed in this pull request?
In `ReplaceExceptWithFilter` we do not properly handle the case in which the
condition evaluates to NULL. Since negating NULL still returns NULL, the
assumption that negating the condition returns all the rows which did not
satisfy it does not hold, and rows for which the condition returns NULL may be
missing from the result. This happens when the constraints inferred by
`InferFiltersFromConstraints` are not enough, as is the case with `OR`
conditions. A sketch of the problem follows.
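For illustration, here is a hypothetical spark-shell sketch (not part of this patch; the data and column names are made up):

```scala
// Hypothetical spark-shell sketch of the bug; assumes a SparkSession `spark`.
import org.apache.spark.sql.functions.col
import spark.implicits._

val df = Seq((0, "a"), (1, null: String)).toDF("a", "b")
// For the row (1, null), `b > "c"` evaluates to NULL, so the whole OR is NULL.
val right = df.filter(col("a") === 0 || col("b") > "c")
// EXCEPT must return (1, null), since that row is not in `right`. The old rewrite
// produced Filter(Not(cond)), and Not(NULL) is still NULL, so the row was dropped.
df.except(right).show()
```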
The rule also had problems with non-deterministic conditions: in such a
scenario, the rewrite would change the probability of the output, as the sketch
below illustrates.
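Again as an illustrative sketch (not taken from the patch):

```scala
// Sketch of the non-deterministic case; assumes a SparkSession `spark`.
import org.apache.spark.sql.functions.rand

val t = spark.range(100).toDF("a")
// Each row of `sampled` is kept with probability 0.5.
val sampled = t.filter(rand() < 0.5)
// Rewriting t.except(sampled) as t.filter(!(rand() < 0.5)) would draw a fresh
// random value per row, so the filter would not be the complement of `sampled`
// and the output probabilities would change.
t.except(sampled)
```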
The PR fixes these problems by:
- returning False for the condition when it is NULL (in this way we do return
all the rows which did not satisfy it); see the sketch after this list;
- avoiding any transformation when the condition is non-deterministic.
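A sketch of the fixed semantics, reusing `df` and the condition from the first sketch:

```scala
// Sketch of the fixed rewrite; `df` comes from the first sketch above.
import org.apache.spark.sql.functions.{coalesce, col, lit, not}

val cond = col("a") === 0 || col("b") > "c"
// Coalesce(cond, false) maps a NULL condition to false, so Not(...) becomes true
// and the row (1, null) is returned, matching EXCEPT semantics.
df.filter(not(coalesce(cond, lit(false)))).distinct().show()
```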
## How was this patch tested?
Added UTs.
Closes #23315 from mgaido91/SPARK-26366.
Authored-by: Marco Gaido <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
(cherry picked from commit 834b8609793525a5a486013732d8c98e1c6e6504)
Signed-off-by: gatorsmile <[email protected]>
---
.../optimizer/ReplaceExceptWithFilter.scala | 32 +++++++++-------
.../catalyst/optimizer/ReplaceOperatorSuite.scala | 44 ++++++++++++++++------
.../scala/org/apache/spark/sql/DatasetSuite.scala | 11 ++++++
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 38 +++++++++++++++++++
4 files changed, 101 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index efd3944..4996d24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Note:
  * Before flipping the filter condition of the right node, we should:
  * 1. Combine all its [[Filter]]s.
- * 2. Apply the InferFiltersFromConstraints rule (to take NULL values in the condition into account).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take NULL values in the condition into account).
  */
 object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
     plan.transform {
       case e @ Except(left, right, false) if isEligible(left, right) =>
-        val newCondition = transformCondition(left, skipProject(right))
-        newCondition.map { c =>
-          Distinct(Filter(Not(c), left))
-        }.getOrElse {
+        val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+        if (filterCondition.deterministic) {
+          transformCondition(left, filterCondition).map { c =>
+            Distinct(Filter(Not(c), left))
+          }.getOrElse {
+            e
+          }
+        } else {
           e
         }
     }
   }

-  private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
-    val filterCondition =
-      InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
-    val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
-    if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
-      Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+  private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+    val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+    if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+      val rewrittenCondition = condition.transform {
+        case a: AttributeReference => attributeNameMap(a.name)
+      }
+      // We need to treat the condition as false when it is NULL; otherwise we would not return
+      // the rows for which it evaluates to NULL, even though they are filtered out by the
+      // Except right plan.
+      Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
     } else {
       None
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 3b1b2d5..c8e15c7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType

 class ReplaceOperatorSuite extends PlanTest {
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
         Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze

     comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)), table1)).analyze
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+          table1)).analyze

     comparePlans(optimized, correctAnswer)
   }
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
         Project(Seq(attributeA, attributeB), table1))).analyze

     comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
         Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze

     comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
         Project(Seq(attributeA, attributeB),
           Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
@@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest {
     comparePlans(optimized, query)
   }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should properly handle NULL") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+    val except = Except(basePlan, otherPlan, false)
+    val result = OptimizeIn(Optimize.execute(except.analyze))
+    val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+      Filter(!Coalesce(Seq(
+        'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
+        Literal.FalseLiteral)),
+        basePlan)).analyze
+    comparePlans(result, correctAnswer)
+  }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should not transform non-deterministic conditions") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a > rand(1L))
+    val except = Except(basePlan, otherPlan, false)
+    val result = Optimize.execute(except.analyze)
+    val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+      a1 <=> a2 }.reduce(_ && _)
+    val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+    comparePlans(result, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f6f51b5..50406bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1556,6 +1556,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
       Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
   }
+
+  test("SPARK-26366: return nulls which are not filtered in except") {
+    val inputDF = sqlContext.createDataFrame(
+      sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true))))
+
+    val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+    checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+  }
 }
case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dbb0790..beb1753 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2909,6 +2909,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-26366: verify ReplaceExceptWithFilter") {
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+        val df = spark.createDataFrame(
+          sparkContext.parallelize(Seq(Row(0, 3, 5),
+            Row(0, 3, null),
+            Row(null, 3, 5),
+            Row(0, null, 5),
+            Row(0, null, null),
+            Row(null, null, 5),
+            Row(null, 3, null),
+            Row(null, null, null))),
+          StructType(Seq(StructField("c1", IntegerType),
+            StructField("c2", IntegerType),
+            StructField("c3", IntegerType))))
+        val where = "c2 >= 3 OR c1 >= 0"
+        val whereNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+            |OR (c1 IS NOT NULL AND c1 >= 0)
+          """.stripMargin
+
+        val df_a = df.filter(where)
+        val df_b = df.filter(whereNullSafe)
+        checkAnswer(df.except(df_a), df.except(df_b))
+
+        val whereWithIn = "c2 >= 3 OR c1 in (2)"
+        val whereWithInNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+          """.stripMargin
+        val dfIn_a = df.filter(whereWithIn)
+        val dfIn_b = df.filter(whereWithInNullSafe)
+        checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+      }
+    }
+  }
 }

 case class Foo(bar: Option[String])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]