This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d8d604bc07b [SPARK-40599][SQL] Add multiTransform methods to TreeNode
to generate alternatives
d8d604bc07b is described below
commit d8d604bc07bc3b8c98f73c4b10f93cb4eb7113be
Author: Peter Toth <[email protected]>
AuthorDate: Tue Jan 17 20:58:37 2023 +0800
[SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate
alternatives
### What changes were proposed in this pull request?
This PR introduces `TreeNode.multiTransform()` methods that recursively
transform a `TreeNode` (and so a tree) into multiple alternatives.
These functions are particularly useful when we want to transform an expression
through a projection in which subexpressions can be aliased with multiple
different attributes.
E.g. if we have a partitioning expression `HashPartitioning(a + b)` and we
have a `Project` node that aliases `a` as `a1` and `a2` and `b` as `b1` and
`b2` we can easily generate a stream of alternative transformations of the
original partitioning:
```
// This is a simplified test; some arguments are omitted to keep it concise
val partitioning = HashPartitioning(Add(a, b))
val aliases: Map[Expression, Seq[Attribute]] = ... // collect the alias map from project
val s = partitioning.multiTransformDown {
  case e: Expression if aliases.contains(e.canonicalized) =>
    aliases(e.canonicalized).toStream
}
s // Stream(HashPartitioning(Add(a1, b1)), HashPartitioning(Add(a2, b1)),
  //        HashPartitioning(Add(a1, b2)), HashPartitioning(Add(a2, b2)))
```
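The order in which alternatives appear follows from how the per-child streams of alternatives are combined. The following standalone sketch mirrors the `generateChildrenSeq` helper from this commit, using plain strings instead of Catalyst expressions. The names `ChildrenProductSketch` and `combine` are illustrative only and are not part of Spark.

```scala
object ChildrenProductSketch {
  // Combines one stream of alternatives per child into a stream of child lists.
  // The fold starts from the last child, so the first child's alternatives
  // vary fastest in the output.
  def combine[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] =
    childrenStreams.foldRight(Stream(Seq.empty[T])) { (childStream, acc) =>
      for {
        rest <- acc
        child <- childStream
      } yield child +: rest
    }

  // Two children with two alternatives each give four combinations.
  val alternatives: List[Seq[String]] =
    combine(Seq(Stream("a1", "a2"), Stream("b1", "b2"))).toList
}
```

With two alternatives for each of the two children, `alternatives` holds four combinations, and the alternatives of the first child change between consecutive elements before those of the second child do.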
The result of `multiTransform` is a lazy stream, so the caller can limit the
number of alternatives generated as needed.
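To illustrate what laziness buys the caller, here is a minimal sketch on a toy tree. It is not Spark's `TreeNode`: the `Node`, `Leaf` and `Plus` types, the `failingTail` helper and this simplified `multiTransformDown` (no pruning condition, rule ids or tags) are assumptions made for the example.

```scala
object LazyAlternativesSketch {
  // A toy expression tree standing in for Catalyst's TreeNode (hypothetical types).
  sealed trait Node
  case class Leaf(name: String) extends Node
  case class Plus(left: Node, right: Node) extends Node

  // Simplified pre-order multi-transform: no pruning condition, rule ids, or tags.
  def multiTransformDown(node: Node)(
      rule: PartialFunction[Node, Stream[Node]]): Stream[Node] = {
    // If the rule does not match, the node itself is the only alternative.
    val alternatives = rule.applyOrElse(node, (n: Node) => Stream(n))
    alternatives.flatMap {
      case Plus(l, r) =>
        // The right child is the outer generator, so left alternatives vary fastest.
        for {
          newRight <- multiTransformDown(r)(rule)
          newLeft <- multiTransformDown(l)(rule)
        } yield (Plus(newLeft, newRight): Node)
      case leaf => Stream(leaf)
    }
  }

  // A stream tail that fails as soon as anything forces it.
  def failingTail: Stream[Node] =
    throw new IllegalStateException("tail should not be forced")

  val rule: PartialFunction[Node, Stream[Node]] = {
    case Leaf("a") => Stream(Leaf("a1"), Leaf("a2"), Leaf("a3"))
    case Leaf("b") => Leaf("b1") #:: failingTail
  }

  val all: Stream[Node] = multiTransformDown(Plus(Leaf("a"), Leaf("b")))(rule)

  // Only three alternatives are requested, so the failing tail is never forced.
  val firstThree: List[Node] = all.take(3).toList
}
```

Taking three elements succeeds because only the first alternative of `b` is needed to produce them; asking for a fourth would force the tail of `b`'s alternatives and fail. This is the same property that the "multiTransformDown is lazy" test in the commit checks on real expressions.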
### Why are the changes needed?
`TreeNode.multiTransform()` is a useful helper method.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New UTs are added.
Closes #38034 from peter-toth/SPARK-40599-multitransform.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/sql/catalyst/trees/TreeNode.scala | 128 +++++++++++++++++++++
.../spark/sql/catalyst/trees/TreeNodeSuite.scala | 104 +++++++++++++++++
2 files changed, 232 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 9510aa4d9e7..dc64e5e2560 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
}
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * @param rule a function used to generate alternatives for a node
+   * @return the stream of alternatives
+   */
+  def multiTransformDown(
+      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
+  }
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * As it is very easy to generate an enormous number of alternatives when the input tree is huge or
+   * when the rule returns many alternatives for many nodes, this function returns the alternatives
+   * as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side
+   * as needed.
+   *
+   * The rule should not apply or can return a one-element stream of the original node to indicate
+   * that the original node without any transformation is a valid alternative.
+   *
+   * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
+   * case `multiTransformDown()` returns an empty `Stream`.
+   *
+   * Please consider the following examples of `input.multiTransformDown(rule)`:
+   *
+   * We have an input expression:
+   *    `Add(a, b)`
+   *
+   * 1.
+   * We have a simple rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22)`
+   *
+   * 2.
+   * In the previous example if we want to generate alternatives of `a` and `b` too then we need to
+   * explicitly add the original `Add(a, b)` expression to the rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+   *
+   * @param rule   a function used to generate alternatives for a node
+   * @param cond   a Lambda expression to prune tree traversals. If `cond.apply` returns false
+   *               on a TreeNode T, skips processing T and its subtree; otherwise, processes
+   *               T and its subtree recursively.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+   *               has been marked as ineffective on a TreeNode T, skips processing T and its
+   *               subtree. Do not pass it if the rule is not purely functional and reads a
+   *               varying initial state for different invocations.
+   * @return       the stream of alternatives
+   */
+  def multiTransformDownWithPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
+      return Stream(this)
+    }
+
+    // We could return `Stream(this)` if the `rule` doesn't apply and handle both
+    // - the rule doesn't apply
+    // - and the rule returns a one-element `Stream(originalNode)`
+    // cases together. But, unfortunately, there doesn't seem to be a way to match on a
+    // one-element stream without eagerly computing the tail's head. So this contradicts the purpose
+    // of only taking the necessary elements from the alternatives. I.e. the
+    // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+    // Please note that this behaviour has a downside as well: we can only mark the rule on the
+    // original node ineffective if the rule didn't match.
+    var ruleApplied = true
+    val afterRules = CurrentOrigin.withOrigin(origin) {
+      rule.applyOrElse(this, (_: BaseType) => {
+        ruleApplied = false
+        Stream.empty
+      })
+    }
+
+    val afterRulesStream = if (afterRules.isEmpty) {
+      if (ruleApplied) {
+        // If the rule returned with empty alternatives then prune
+        Stream.empty
+      } else {
+        // If the rule was not applied then keep the original node
+        this.markRuleAsIneffective(ruleId)
+        Stream(this)
+      }
+    } else {
+      // If the rule was applied then use the returned alternatives
+      afterRules.map { afterRule =>
+        if (this fastEquals afterRule) {
+          this
+        } else {
+          afterRule.copyTagsFrom(this)
+          afterRule
+        }
+      }
+    }
+
+    afterRulesStream.flatMap { afterRule =>
+      if (afterRule.containsChild.nonEmpty) {
+        generateChildrenSeq(
+            afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
+          .map(afterRule.withNewChildren)
+      } else {
+        Stream(afterRule)
+      }
+    }
+  }
+
+  private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
+    childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) =>
+      for {
+        childrenSeq <- childrenSeqStream
+        child <- childrenStream
+      } yield child +: childrenSeq
+    )
+  }
+
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes in `children`.
    */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 286d3dddae6..ac28917675e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -977,4 +977,108 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
assert(origin.context.summary.isEmpty)
}
}
+
+  private def newErrorAfterStream(es: Expression*) = {
+    es.toStream.append(
+      throw new NoSuchElementException("Stream should not return more elements")
+    )
+  }
+
+  test("multiTransformDown generates all alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), Literal(300))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200), Literal(300))
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), cd)
+    assert(transformed === expected)
+  }
+
+  test("multiTransformDown is lazy") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => newErrorAfterStream(Literal(10))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected = for {
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, Literal(10)), Literal(100))
+    // We don't access alternatives for `b` after 10 and for `c` after 100
+    assert(transformed.take(3) == expected)
+    intercept[NoSuchElementException] {
+      transformed.take(3 + 1).toList
+    }
+
+    val transformed2 = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected2 = for {
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), Literal(100))
+    // We don't access alternatives for `c` after 100
+    assert(transformed2.take(3 * 3) === expected2)
+    intercept[NoSuchElementException] {
+      transformed2.take(3 * 3 + 1).toList
+    }
+  }
+
+  test("multiTransformDown rule return this") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
+      case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
+      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), a)
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
+      b <- Seq(Literal(10), Literal(20), Literal("b"))
+      a <- Seq(Literal(1), Literal(2), Literal("a"))
+    } yield Add(Add(a, b), cd)
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
+    "transformed and itself is in the alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
+        Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
+      case StringLiteral("a") => Stream(Literal(1), Literal(2))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200))
+      ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++
+        (for {
+          b <- Seq(Literal(10), Literal(20))
+          a <- Seq(Literal(1), Literal(2))
+        } yield Add(a, b))
+    } yield Add(ab, cd)
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown can prune") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream.empty
+    }
+    assert(transformed.isEmpty)
+
+    val transformed2 = e.multiTransformDown {
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+    }
+    assert(transformed2.isEmpty)
+  }
}