This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c953610deaf [SPARK-40599][SQL] Relax multiTransform rule type to allow
alternatives to be any kinds of Seq
c953610deaf is described below
commit c953610deafda769feb85fbb936591ffc4448f8e
Author: Peter Toth <[email protected]>
AuthorDate: Thu Jan 19 23:54:06 2023 +0800
[SPARK-40599][SQL] Relax multiTransform rule type to allow alternatives to
be any kinds of Seq
### What changes were proposed in this pull request?
This is a follow-up PR to https://github.com/apache/spark/pull/38034. It
relaxes `multiTransformDown()`'s `rule` parameter type to accept any kinds of
`Seq` and make `MultiTransform.generateCartesianProduct()` helper public.
### Why are the changes needed?
API improvement.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs.
Closes #39652 from peter-toth/SPARK-40599-multitransform-follow-up.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/sql/catalyst/trees/TreeNode.scala | 70 +++++++++++++---------
.../spark/sql/catalyst/trees/TreeNodeSuite.scala | 31 +++++-----
2 files changed, 57 insertions(+), 44 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index dc64e5e2560..c8df2086a72 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -626,7 +626,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
* @return the stream of alternatives
*/
def multiTransformDown(
- rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+ rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}
@@ -639,10 +639,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
* as a lazy `Stream` to be able to limit the number of alternatives
generated at the caller side
* as needed.
*
- * The rule should not apply or can return a one element stream of original
node to indicate that
+ * The purpose of this function is to access the returned alternatives by the
rule only if they are
+ * needed so the rule can return a `Stream` whose elements are also lazily
calculated.
+ * E.g. `multiTransform*` calls can be nested with the help of
+ * `MultiTransform.generateCartesianProduct()`.
+ *
+ * The rule should not apply or can return a one element `Seq` of original
node to indicate that
* the original node without any transformation is a valid alternative.
*
- * The rule can return `Stream.empty` to indicate that the original node
should be pruned. In this
+ * The rule can return `Seq.empty` to indicate that the original node should
be pruned. In this
* case `multiTransform()` returns an empty `Stream`.
*
* Please consider the following examples of
`input.multiTransformDown(rule)`:
@@ -652,9 +657,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
*
* 1.
* We have a simple rule:
- * `a` => `Stream(1, 2)`
- * `b` => `Stream(10, 20)`
- * `Add(a, b)` => `Stream(11, 12, 21, 22)`
+ * `a` => `Seq(1, 2)`
+ * `b` => `Seq(10, 20)`
+ * `Add(a, b)` => `Seq(11, 12, 21, 22)`
*
* The output is:
* `Stream(11, 12, 21, 22)`
@@ -662,9 +667,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
* 2.
* In the previous example if we want to generate alternatives of `a` and
`b` too then we need to
* explicitly add the original `Add(a, b)` expression to the rule:
- * `a` => `Stream(1, 2)`
- * `b` => `Stream(10, 20)`
- * `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+ * `a` => `Seq(1, 2)`
+ * `b` => `Seq(10, 20)`
+ * `Add(a, b)` => `Seq(11, 12, 21, 22, Add(a, b))`
*
* The output is:
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
@@ -683,25 +688,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
def multiTransformDownWithPruning(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
- )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+ )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return Stream(this)
}
- // We could return `Stream(this)` if the `rule` doesn't apply and handle
both
+ // We could return `Seq(this)` if the `rule` doesn't apply and handle both
// - the doesn't apply
- // - and the rule returns a one element `Stream(originalNode)`
- // cases together. But, unfortunately it doesn't seem like there is a way
to match on a one
- // element stream without eagerly computing the tail head. So this
contradicts with the purpose
- // of only taking the necessary elements from the alternatives. I.e. the
- // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+ // - and the rule returns a one element `Seq(originalNode)`
+ // cases together. The returned `Seq` can be a `Stream` and unfortunately
it doesn't seem like
+ // there is a way to match on a one element stream without eagerly
computing the tail's head.
+ // This contradicts with the purpose of only taking the necessary elements
from the
+ // alternatives. I.e. the "multiTransformDown is lazy" test case in
`TreeNodeSuite` would fail.
// Please note that this behaviour has a downside as well that we can only
mark the rule on the
// original node ineffective if the rule didn't match.
var ruleApplied = true
val afterRules = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, (_: BaseType) => {
ruleApplied = false
- Stream.empty
+ Seq.empty
})
}
@@ -716,7 +721,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
}
} else {
// If the rule was applied then use the returned alternatives
- afterRules.map { afterRule =>
+ afterRules.toStream.map { afterRule =>
if (this fastEquals afterRule) {
this
} else {
@@ -728,7 +733,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
afterRulesStream.flatMap { afterRule =>
if (afterRule.containsChild.nonEmpty) {
- generateChildrenSeq(
+ MultiTransform.generateCartesianProduct(
afterRule.children.map(_.multiTransformDownWithPruning(cond,
ruleId)(rule)))
.map(afterRule.withNewChildren)
} else {
@@ -737,15 +742,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
}
}
- private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]):
Stream[Seq[T]] = {
- childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream,
childrenSeqStream) =>
- for {
- childrenSeq <- childrenSeqStream
- child <- childrenStream
- } yield child +: childrenSeq
- )
- }
-
/**
* Returns a copy of this node where `f` has been applied to all the nodes
in `children`.
*/
@@ -1368,3 +1364,21 @@ trait QuaternaryLike[T <: TreeNode[T]] { self:
TreeNode[T] =>
protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird:
T, newFourth: T): T
}
+
+object MultiTransform {
+
+ /**
+ * Returns the stream of `Seq` elements by generating the cartesian product
of sequences.
+ *
+ * @param elementSeqs a list of sequences to build the cartesian product from
+ * @return the stream of generated `Seq` elements
+ */
+ def generateCartesianProduct[T](elementSeqs: Seq[Seq[T]]): Stream[Seq[T]] = {
+ elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
+ for {
+ elementTail <- elementTails
+ element <- elements
+ } yield element +: elementTail
+ )
+ }
+}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index ac28917675e..e4adf59b392 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -987,10 +987,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("multiTransformDown generates all alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"),
Literal("d")))
val transformed = e.multiTransformDown {
- case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
- case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+ case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
+ case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
- Stream(Literal(100), Literal(200), Literal(300))
+ Seq(Literal(100), Literal(200), Literal(300))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Literal(300))
@@ -1003,7 +1003,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("multiTransformDown is lazy") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"),
Literal("d")))
val transformed = e.multiTransformDown {
- case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+ case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => newErrorAfterStream(Literal(10))
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
newErrorAfterStream(Literal(100))
}
@@ -1017,8 +1017,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
}
val transformed2 = e.multiTransformDown {
- case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
- case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+ case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
+ case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
newErrorAfterStream(Literal(100))
}
val expected2 = for {
@@ -1035,10 +1035,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper
{
test("multiTransformDown rule return this") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"),
Literal("d")))
val transformed = e.multiTransformDown {
- case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
- case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
- case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
- Stream(Literal(100), Literal(200), a)
+ case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s)
+ case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s)
+ case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
Seq(Literal(100), Literal(200), a)
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
@@ -1053,10 +1052,10 @@ class TreeNodeSuite extends SparkFunSuite with
SQLHelper {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"),
Literal("d")))
val transformed = e.multiTransformDown {
case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
- Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
- case StringLiteral("a") => Stream(Literal(1), Literal(2))
- case StringLiteral("b") => Stream(Literal(10), Literal(20))
- case Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200))
+ Seq(Literal(11), Literal(12), Literal(21), Literal(22), a)
+ case StringLiteral("a") => Seq(Literal(1), Literal(2))
+ case StringLiteral("b") => Seq(Literal(10), Literal(20))
+ case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100),
Literal(200))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200))
@@ -1072,12 +1071,12 @@ class TreeNodeSuite extends SparkFunSuite with
SQLHelper {
test("multiTransformDown can prune") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"),
Literal("d")))
val transformed = e.multiTransformDown {
- case StringLiteral("a") => Stream.empty
+ case StringLiteral("a") => Seq.empty
}
assert(transformed.isEmpty)
val transformed2 = e.multiTransformDown {
- case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+ case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty
}
assert(transformed2.isEmpty)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]