This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 62956c92cfc7 [SPARK-46707][SQL] Added throwable field to expressions to improve predicate pushdown
62956c92cfc7 is described below
commit 62956c92cfc74d7523328d168b6d837938cde763
Author: Kelvin Jiang <[email protected]>
AuthorDate: Thu Jan 18 19:25:24 2024 +0800
[SPARK-46707][SQL] Added throwable field to expressions to improve predicate pushdown
### What changes were proposed in this pull request?
This PR adds the field `throwable` to `Expression`. If an expression is
marked as throwable, we avoid pushing filters containing it through joins,
filters, and aggregations (i.e. operators that filter their input), since the
pushed filter could then be evaluated on rows those operators would have removed.
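As a minimal sketch of the opt-in pattern taken from this patch: the default
implementation propagates from children, and an expression that can itself
raise at evaluation time (here `Sequence`) overrides the field:

    // Default added to Expression: throwable if any child is throwable.
    lazy val throwable: Boolean = children.exists(_.throwable)

    // Override in Sequence: an explicit step can make evaluation throw when
    // the start and stop do not satisfy the step.
    override lazy val throwable: Boolean = stepOpt.isDefined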
### Why are the changes needed?
Currently, predicate pushdown can cause a filter to be evaluated on more rows
than before it was pushed down (e.g. when the filter is pushed through a
selective join). The filter may then be evaluated on a row that throws a
runtime error, which would not have happened prior to the pushdown.
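As a hedged illustration in the style of the tests added below (`x`, `y` and
their columns are test relations, not real data): `Sequence` with an explicit
step throws when start and stop do not satisfy the step, so such a filter must
not move below a selective join:

    // Pushing this filter below the join would evaluate Sequence on rows
    // (e.g. rows with x.a > x.b) that the join might have discarded,
    // turning a previously-passing query into a runtime error.
    val query = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))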
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added UTs.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44716 from kelvinjian-db/SPARK-46707-throwable.
Authored-by: Kelvin Jiang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/Expression.scala | 5 ++
.../expressions/collectionOperations.scala | 3 ++
.../spark/sql/catalyst/optimizer/Optimizer.scala | 27 +++++-----
.../catalyst/optimizer/FilterPushdownSuite.scala | 63 ++++++++++++++++++++++
4 files changed, 84 insertions(+), 14 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2cc813bd3055..484418f5e5a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -140,6 +140,11 @@ abstract class Expression extends TreeNode[Expression] {
*/
def stateful: Boolean = false
+ /**
+ * Returns true if the expression could potentially throw an exception when evaluated.
+ */
+ lazy val throwable: Boolean = children.exists(_.throwable)
+
/**
* Returns a copy of this expression where all stateful expressions are replaced with fresh
* uninitialized copies. If the expression contains no stateful expressions then the original
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 04f56eaf8c1e..5aa96dd1a6aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2983,6 +2983,9 @@ case class Sequence(
override def nullable: Boolean = children.exists(_.nullable)
+ // If step is defined, then an error will be thrown if the start and stop do not satisfy the step.
+ override lazy val throwable: Boolean = stepOpt.isDefined
+
override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
override def checkInputDataTypes(): TypeCheckResult = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8fcc7c7c26b4..4186c8c1db91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1549,10 +1549,11 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
// The query execution/optimization does not guarantee the expressions are evaluated in order.
- // We only can combine them if and only if both are deterministic.
+ // We only can combine them if and only if both are deterministic and the outer condition is not
+ // throwable (inner can be throwable as it was going to be evaluated first anyways).
case Filter(fc, nf @ Filter(nc, grandChild)) if nc.deterministic =>
- val (combineCandidates, nonDeterministic) =
- splitConjunctivePredicates(fc).partition(_.deterministic)
+ val (combineCandidates, rest) =
+ splitConjunctivePredicates(fc).partition(p => p.deterministic && !p.throwable)
val mergedFilter = (ExpressionSet(combineCandidates) --
ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match {
case Some(ac) =>
@@ -1560,7 +1561,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
case None =>
nf
}
- nonDeterministic.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
+ rest.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
}
}
@@ -1730,16 +1731,12 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
// For each filter, expand the alias and check if the filter can be evaluated using
// attributes produced by the aggregate operator's child operator.
- val (candidates, nonDeterministic) =
- splitConjunctivePredicates(condition).partition(_.deterministic)
-
- val (pushDown, rest) = candidates.partition { cond =>
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
val replaced = replaceAlias(cond, aliasMap)
- cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
+ cond.deterministic && !cond.throwable &&
+ cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
}
- val stayUp = rest ++ nonDeterministic
-
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val replaced = replaceAlias(pushDownPredicate, aliasMap)
@@ -1904,13 +1901,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
* @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
*/
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
- val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
+ val (pushDownCandidates, stayUp) =
+ condition.partition(cond => cond.deterministic && !cond.throwable)
val (leftEvaluateCondition, rest) =
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
val (rightEvaluateCondition, commonCondition) =
rest.partition(expr => expr.references.subsetOf(right.outputSet))
- (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
+ (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ stayUp)
}
private def canPushThrough(joinType: JoinType): Boolean = joinType match {
@@ -1933,8 +1931,9 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ // don't push throwable expressions into join condition
val (newJoinConditions, others) =
- commonFilterCondition.partition(canEvaluateWithinJoin)
+ commonFilterCondition.partition(cond => canEvaluateWithinJoin(cond) && !cond.throwable)
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
val join = Join(newLeft, newRight, joinType, newJoinCond, hint)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 2ebb43d4fba3..bd2ac28a049f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -1433,4 +1433,67 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer = RebalancePartitions(Seq.empty, testRelation.where($"a" > 3)).analyze
comparePlans(optimized, correctAnswer)
}
+
+ test("SPARK-46707: push down predicate with sequence (without step) through
joins") {
+ val x = testRelation.subquery("x")
+ val y = testRelation1.subquery("y")
+
+ // do not push down when sequence has step param
+ val queryWithStep = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+ .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+ .analyze
+ val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+ comparePlans(optimizedQueryWithStep, queryWithStep)
+
+ // push down when sequence does not have step param
+ val queryWithoutStep = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+ .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+ .analyze
+ val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+ val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+ .join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+ .analyze
+ comparePlans(optimizedQueryWithoutStep, correctAnswer)
+ }
+
+ test("SPARK-46707: push down predicate with sequence (without step) through
aggregates") {
+ val x = testRelation.subquery("x")
+
+ // do not push down when sequence has step param
+ val queryWithStep = x.groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+ .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+ .analyze
+ val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+ comparePlans(optimizedQueryWithStep, queryWithStep)
+
+ // push down when sequence does not have step param
+ val queryWithoutStep = x.groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+ .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+ .analyze
+ val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+ val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+ .groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+ .analyze
+ comparePlans(optimizedQueryWithoutStep, correctAnswer)
+ }
+
+ test("SPARK-46707: combine predicate with sequence (without step) with other
filters") {
+ val x = testRelation.subquery("x")
+
+ // do not combine when sequence has step param
+ val queryWithStep = x.where($"x.c" > 1)
+ .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+ .analyze
+ val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+ comparePlans(optimizedQueryWithStep, queryWithStep)
+
+ // combine when sequence does not have step param
+ val queryWithoutStep = x.where($"x.c" > 1)
+ .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+ .analyze
+ val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+ val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)) && $"x.c" > 1)
+ .analyze
+ comparePlans(optimizedQueryWithoutStep, correctAnswer)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]