This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 01c294b05f3a [SPARK-45760][SQL] Add With expression to avoid duplicating expressions
01c294b05f3a is described below
commit 01c294b05f3a9b7bd87cda0ee8b0160f5f58bb24
Author: Wenchen Fan <[email protected]>
AuthorDate: Wed Nov 8 00:57:31 2023 +0800
[SPARK-45760][SQL] Add With expression to avoid duplicating expressions
### What changes were proposed in this pull request?
Sometimes we need to duplicate expressions when rewriting the plan. It's OK
for small queries, as codegen has common-subexpression-elimination (CSE) to avoid
evaluating the same expression more than once. However, when the query is big, duplicating
expressions can lead to a very big expression tree and make catalyst rules very
slow, or even cause OOM when updating a leaf node (which requires copying all tree nodes).
This PR introduces a new `With` expression to do expression-level CTE, plus an optimizer
rule (`RewriteWithExpression`) that adds a Project to pre-evaluate the common expressions
(or inlines them if they are cheap), so that they appear only once on the query plan tree
and are evaluated only once. `NullIf` now uses this new expression to avoid duplicating the
`left` child expression.
### Why are the changes needed?
To make Catalyst more efficient.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
A new test suite (`RewriteWithExpressionSuite`); the affected Spark Connect explain golden files are updated.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43623 from cloud-fan/with.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../explain-results/function_count_if.explain | 5 +-
.../explain-results/function_regexp_substr.explain | 5 +-
.../sql/connect/ProtoToParsedPlanTestSuite.scala | 15 +-
.../spark/sql/catalyst/expressions/With.scala | 63 +++++++++
.../sql/catalyst/expressions/nullExpressions.scala | 6 +-
.../spark/sql/catalyst/optimizer/Optimizer.scala | 3 +
.../catalyst/optimizer/RewriteWithExpression.scala | 90 ++++++++++++
.../spark/sql/catalyst/trees/TreePatterns.scala | 2 +
.../optimizer/RewriteWithExpressionSuite.scala | 157 +++++++++++++++++++++
9 files changed, 338 insertions(+), 8 deletions(-)
diff --git
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
index 1c23bbf6bce5..f2ada15eccb7 100644
---
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
+++
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
@@ -1,2 +1,3 @@
-Aggregate [count(if (((a#0 > 0) = false)) null else (a#0 > 0)) AS count_if((a
> 0))#0L]
-+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0)
AS count_if((a > 0))#0L]
++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+ +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
index 69fc760c8291..1811f770f829 100644
---
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
+++
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
@@ -1,2 +1,3 @@
-Project [if ((regexp_extract(g#0, \d{2}(a|b|m), 0) = )) null else
regexp_extract(g#0, \d{2}(a|b|m), 0) AS regexp_substr(g, \d{2}(a|b|m))#0]
-+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS
regexp_substr(g, \d{2}(a|b|m))#0]
++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0,
\d{2}(a|b|m), 0) AS _common_expr_0#0]
+ +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index 9fdaffcba670..e0c4e21503e9 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -29,7 +29,9 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution,
Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
+import org.apache.spark.sql.catalyst.optimizer.{ReplaceExpressions,
RewriteWithExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.SessionHolder
@@ -181,8 +183,15 @@ class ProtoToParsedPlanTestSuite
val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
val catalystPlan =
analyzer.executeAndCheck(planner.transformRelation(relation), new
QueryPlanningTracker)
- val actual =
-
removeMemoryAddress(normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString)
+ val finalAnalyzedPlan = {
+ object Helper extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Finish Analysis", Once, ReplaceExpressions) ::
+ Batch("Rewrite With expression", Once, RewriteWithExpression) ::
Nil
+ }
+ Helper.execute(catalystPlan)
+ }
+ val actual =
removeMemoryAddress(normalizeExprIds(finalAnalyzedPlan).treeString)
val goldenFile =
goldenFilePath.resolve(relativePath).getParent.resolve(name + ".explain")
Try(readGoldenFile(goldenFile)) match {
case Success(expected) if expected == actual => // Test passes.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
new file mode 100644
index 000000000000..bfed63af1740
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF,
TreePattern, WITH_EXPRESSION}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * An expression-level counterpart of a CTE. `defs` holds common expressions that `child` can refer
+ * to through [[CommonExpressionRef]] nodes (matched by id). Each common expression is guaranteed
+ * to be evaluated only once, however many times it is referenced. The node itself is
+ * [[Unevaluable]]: the `RewriteWithExpression` optimizer rule replaces it before execution.
+ */
+case class With(child: Expression, defs: Seq[CommonExpressionDef])
+  extends Expression with Unevaluable {
+
+  override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION)
+
+  // Type and nullability are inherited from the wrapped expression.
+  override def dataType: DataType = child.dataType
+  override def nullable: Boolean = child.nullable
+
+  // Child layout: the wrapped expression first, then every definition in declaration order.
+  override def children: Seq[Expression] = child +: defs
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): Expression = {
+    val newChild = newChildren.head
+    val newDefs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef])
+    copy(child = newChild, defs = newDefs)
+  }
+}
+
+/**
+ * Wraps a single common expression of a [[With]] together with the `id` that
+ * [[CommonExpressionRef]] nodes use to point back at it. When no id is supplied, a fresh one is
+ * drawn from the companion object; `copy` (and therefore child replacement) keeps the current id.
+ */
+case class CommonExpressionDef(child: Expression, id: Long = CommonExpressionDef.newId)
+  extends UnaryExpression with Unevaluable {
+
+  override def dataType: DataType = child.dataType
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = {
+    copy(child = newChild)
+  }
+}
+
+/**
+ * Points to a [[CommonExpressionDef]] through its `id`. The data type and nullability are fixed
+ * when the reference is created (the auxiliary constructor copies them from the definition), so
+ * only resolved common expressions can be referenced.
+ */
+case class CommonExpressionRef(id: Long, dataType: DataType, nullable: Boolean)
+  extends LeafExpression with Unevaluable {
+
+  override val nodePatterns: Seq[TreePattern] = Seq(COMMON_EXPR_REF)
+
+  /** Creates a reference to `exprDef`, capturing its id, data type and nullability. */
+  def this(exprDef: CommonExpressionDef) =
+    this(exprDef.id, exprDef.dataType, exprDef.nullable)
+}
+
+object CommonExpressionDef {
+  /** Shared atomic counter that hands out [[CommonExpressionDef]] ids. */
+  private[sql] val curId: java.util.concurrent.atomic.AtomicLong =
+    new java.util.concurrent.atomic.AtomicLong()
+
+  /** Returns the current counter value and advances it by one, so each call yields a new id. */
+  def newId: Long = curId.getAndIncrement()
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 948cb6fbedd3..0e9e375b8acf 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -154,7 +154,11 @@ case class NullIf(left: Expression, right: Expression,
replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) = {
- this(left, right, If(EqualTo(left, right), Literal.create(null,
left.dataType), left))
+ this(left, right, {
+ // `left` is needed twice in the replacement (compared with `right`, and returned otherwise).
+ // Wrap it in a `CommonExpressionDef` and refer to it via `CommonExpressionRef` inside a
+ // `With`, instead of duplicating the `left` subtree.
+ val commonExpr = CommonExpressionDef(left)
+ val ref = new CommonExpressionRef(commonExpr)
+ With(If(EqualTo(ref, right), Literal.create(null, left.dataType), ref),
Seq(commonExpr))
+ })
}
override def parameters: Seq[Expression] = Seq(left, right)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 48ecb9aee211..decef766ae97 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -147,6 +147,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
val batches = (
Batch("Finish Analysis", Once, FinishAnalysis) ::
+ // We must run this batch after `ReplaceExpressions`, as
`RuntimeReplaceable` expression
+ // may produce `With` expressions that need to be rewritten.
+ Batch("Rewrite With expression", Once, RewriteWithExpression) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
new file mode 100644
index 000000000000..c5bd71b4a7d1
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef,
CommonExpressionRef, Expression, With}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF,
WITH_EXPRESSION}
+
+/**
+ * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
+ * inlines them if they are cheap (or if no single child provides all the columns they reference).
+ *
+ * Common expressions are pre-evaluated as `_common_expr_<index>` aliases in a `Project` placed on
+ * top of the child whose output covers all their references. If this changes the node's output
+ * (e.g. for `Filter` or `Join`, which pass child columns through), a `Project` restoring the
+ * original output is added on top.
+ *
+ * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
+ * usage, we should support aggregate/window functions as well.
+ */
+object RewriteWithExpression extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
+      case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // Current (possibly wrapped) children of `p`. Each `With` rewritten below may push its
+        // common expressions into a `Project` on top of one of them, so this accumulates across
+        // every `With` found in the expressions of `p`.
+        var newChildren = p.children
+        val rewritten: LogicalPlan = p.transformExpressionsUp {
+          case With(body, defs) =>
+            // Maps each common expression id of this `With` to what its references become:
+            // either the inlined expression or the attribute of a pre-evaluated alias.
+            val refToExpr = mutable.HashMap.empty[Long, Expression]
+            val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
+
+            defs.zipWithIndex.foreach { case (CommonExpressionDef(commonExpr, id), index) =>
+              if (CollapseProject.isCheap(commonExpr)) {
+                refToExpr(id) = commonExpr
+              } else {
+                val childProjectionIndex = newChildren.indexWhere(
+                  c => commonExpr.references.subsetOf(c.outputSet))
+                if (childProjectionIndex == -1) {
+                  // When we cannot rewrite the common expressions, force to inline them so that
+                  // the query can still run. This can happen if the join condition contains
+                  // `With` and the common expression references columns from both join sides.
+                  // TODO: things can go wrong if the common expression is nondeterministic. We
+                  //       don't fix it for now to match the old buggy behavior when certain
+                  //       `RuntimeReplaceable` did not use the `With` expression.
+                  // TODO: we should calculate the ref count and also inline the common
+                  //       expression if its ref count is 1.
+                  refToExpr(id) = commonExpr
+                } else {
+                  val alias = Alias(commonExpr, s"_common_expr_$index")()
+                  childProjections(childProjectionIndex) += alias
+                  refToExpr(id) = alias.toAttribute
+                }
+              }
+            }
+
+            newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
+              if (projections.nonEmpty) Project(child.output ++ projections, child) else child
+            }
+
+            body.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
+              // `body` may also contain references to common expressions of an enclosing
+              // `With`, which is rewritten later because this traversal is bottom-up. Only
+              // replace references defined by the current `With`; looking up any other id in
+              // `refToExpr` would throw `NoSuchElementException`.
+              case ref: CommonExpressionRef if refToExpr.contains(ref.id) => refToExpr(ref.id)
+            }
+        }
+
+        val newPlan = rewritten.withNewChildren(newChildren)
+        // Pre-evaluated aliases add columns to the children; restore the original output when
+        // they leak through (e.g. `Filter`, `Join`).
+        if (p.output == newPlan.output) newPlan else Project(p.output, newPlan)
+    }
+  }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 8b714d5a5d28..9b3337d1a940 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -36,6 +36,7 @@ object TreePattern extends Enumeration {
val CASE_WHEN: Value = Value
val CAST: Value = Value
val COALESCE: Value = Value
+ val COMMON_EXPR_REF: Value = Value
val CONCAT: Value = Value
val COUNT: Value = Value
val CREATE_NAMED_STRUCT: Value = Value
@@ -132,6 +133,7 @@ object TreePattern extends Enumeration {
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WINDOW_GROUP_LIMIT: Value = Value
+ val WITH_EXPRESSION: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value
// Unresolved expression patterns (Alphabetically ordered)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
new file mode 100644
index 000000000000..c625379eb5ff
--- /dev/null
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
CommonExpressionDef, CommonExpressionRef, With}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * Tests for [[RewriteWithExpression]]: cheap common expressions are inlined, non-cheap ones are
+ * pre-evaluated as `_common_expr_<index>` aliases in a `Project` on the child that provides all
+ * their references, and they are inlined when no single child does.
+ */
+class RewriteWithExpressionSuite extends PlanTest {
+
+ // Runs only the rule under test, once.
+ object Optimizer extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("Rewrite With expression", Once,
RewriteWithExpression) :: Nil
+ }
+
+ private val testRelation = LocalRelation($"a".int, $"b".int)
+ private val testRelation2 = LocalRelation($"x".int, $"y".int)
+
+ // A bare attribute is cheap, so the references are inlined and no Project is added.
+ test("simple common expression") {
+ val a = testRelation.output.head
+ val commonExprDef = CommonExpressionDef(a)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val plan = testRelation.select(With(ref + ref,
Seq(commonExprDef)).as("col"))
+ comparePlans(Optimizer.execute(plan), testRelation.select((a +
a).as("col")))
+ }
+
+ // `a + a` is pre-evaluated once in a Project below the original one.
+ test("non-cheap common expression") {
+ val a = testRelation.output.head
+ val commonExprDef = CommonExpressionDef(a + a)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val plan = testRelation.select(With(ref * ref,
Seq(commonExprDef)).as("col"))
+ val commonExprName = "_common_expr_0"
+ comparePlans(
+ Optimizer.execute(plan),
+ testRelation
+ .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+ .select(($"$commonExprName" * $"$commonExprName").as("col"))
+ .analyze
+ )
+ }
+
+ // The inner With (inside the outer common expression) is rewritten first, producing the lower
+ // Project; the outer common expression is then pre-evaluated on top of it. Both aliases are
+ // named `_common_expr_0` because the name only encodes the index within its own With.
+ test("nested WITH expression") {
+ val a = testRelation.output.head
+ val commonExprDef = CommonExpressionDef(a + a)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val innerExpr = With(ref + ref, Seq(commonExprDef))
+ val innerCommonExprName = "_common_expr_0"
+
+ val b = testRelation.output.last
+ val outerCommonExprDef = CommonExpressionDef(innerExpr + b)
+ val outerRef = new CommonExpressionRef(outerCommonExprDef)
+ val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef))
+ val outerCommonExprName = "_common_expr_0"
+
+ val plan = testRelation.select(outerExpr.as("col"))
+ val rewrittenOuterExpr = ($"$innerCommonExprName" +
$"$innerCommonExprName" + b)
+ .as(outerCommonExprName)
+ // Refer to the outer alias by exprId, since its name is shared with the inner alias.
+ val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)(
+ exprId = rewrittenOuterExpr.exprId)
+ comparePlans(
+ Optimizer.execute(plan),
+ testRelation
+ .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*)
+ .select((testRelation.output :+ $"$innerCommonExprName" :+
rewrittenOuterExpr): _*)
+ .select((outerExprAttr * outerExprAttr).as("col"))
+ .analyze
+ )
+ }
+
+ // Filter passes child columns through, so a final Project restores the original output.
+ test("WITH expression in filter") {
+ val a = testRelation.output.head
+ val commonExprDef = CommonExpressionDef(a + a)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val plan = testRelation.where(With(ref < 10 && ref > 0,
Seq(commonExprDef)))
+ val commonExprName = "_common_expr_0"
+ comparePlans(
+ Optimizer.execute(plan),
+ testRelation
+ .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+ .where($"$commonExprName" < 10 && $"$commonExprName" > 0)
+ .select(testRelation.output: _*)
+ .analyze
+ )
+ }
+
+ // The common expression only references the left side, so it is pre-evaluated there.
+ test("WITH expression in join condition: only reference left child") {
+ val a = testRelation.output.head
+ val commonExprDef = CommonExpressionDef(a + a)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+ val plan = testRelation.join(testRelation2, condition = Some(condition))
+ val commonExprName = "_common_expr_0"
+ comparePlans(
+ Optimizer.execute(plan),
+ testRelation
+ .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+ .join(testRelation2, condition = Some($"$commonExprName" < 10 &&
$"$commonExprName" > 0))
+ .select((testRelation.output ++ testRelation2.output): _*)
+ .analyze
+ )
+ }
+
+ // The common expression only references the right side, so it is pre-evaluated there.
+ test("WITH expression in join condition: only reference right child") {
+ val x = testRelation2.output.head
+ val commonExprDef = CommonExpressionDef(x + x)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+ val plan = testRelation.join(testRelation2, condition = Some(condition))
+ val commonExprName = "_common_expr_0"
+ comparePlans(
+ Optimizer.execute(plan),
+ testRelation
+ .join(
+ testRelation2.select((testRelation2.output :+ (x +
x).as(commonExprName)): _*),
+ condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)
+ )
+ .select((testRelation.output ++ testRelation2.output): _*)
+ .analyze
+ )
+ }
+
+ // No single child provides all references of `a + x`, so it is inlined at every use.
+ test("WITH expression in join condition: reference both children") {
+ val a = testRelation.output.head
+ val x = testRelation2.output.head
+ val commonExprDef = CommonExpressionDef(a + x)
+ val ref = new CommonExpressionRef(commonExprDef)
+ val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+ val plan = testRelation.join(testRelation2, condition = Some(condition))
+ comparePlans(
+ Optimizer.execute(plan),
+ testRelation
+ .join(
+ testRelation2,
+ // Can't pre-evaluate, have to inline
+ condition = Some((a + x) < 10 && (a + x) > 0)
+ )
+ )
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]