Repository: spark Updated Branches: refs/heads/master 246012859 -> 01a7d33d0
[SPARK-18711][SQL] should disable subexpression elimination for LambdaVariable ## What changes were proposed in this pull request? This is kind of a long-standing bug; it was hidden until https://github.com/apache/spark/pull/15780, which may add `AssertNotNull` on top of `LambdaVariable` and thus enable subexpression elimination. However, subexpression elimination will evaluate the common expressions at the beginning, which is invalid for `LambdaVariable`. `LambdaVariable` usually represents a loop variable, which can't be evaluated ahead of the loop. This PR skips expressions containing `LambdaVariable` when doing subexpression elimination. ## How was this patch tested? updated test in `DatasetAggregatorSuite` Author: Wenchen Fan <[email protected]> Closes #16143 from cloud-fan/aggregator. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01a7d33d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01a7d33d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01a7d33d Branch: refs/heads/master Commit: 01a7d33d0851d82fd1bb477a58d9925fe8d727d8 Parents: 2460128 Author: Wenchen Fan <[email protected]> Authored: Mon Dec 5 11:37:13 2016 -0800 Committer: Herman van Hovell <[email protected]> Committed: Mon Dec 5 11:37:13 2016 -0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/EquivalentExpressions.scala | 6 +++++- .../scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/01a7d33d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index b8e2b67..6c246a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable /** * This class is used to compute equality of (sub)expression trees. Expressions can be added @@ -72,7 +73,10 @@ class EquivalentExpressions { root: Expression, ignoreLeaf: Boolean = true, skipReferenceToExpressions: Boolean = true): Unit = { - val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) || + // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the + // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. + root.find(_.isInstanceOf[LambdaVariable]).isDefined // There are some special expressions that we should not recurse into children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination. 
http://git-wip-us.apache.org/repos/asf/spark/blob/01a7d33d/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 36b2651..0e7eaa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -92,13 +92,13 @@ object NameAgg extends Aggregator[AggData, String, String] { } -object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] { +object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] { def zero: Seq[Int] = Nil def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2 - def finish(r: Seq[Int]): Seq[Int] = r + def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i) override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder() - override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder() } @@ -281,7 +281,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkDataset( ds.groupByKey(_.b).agg(SeqAgg.toColumn), - "a" -> Seq(1, 2) + "a" -> Seq(1 -> 1, 2 -> 2) ) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
