Repository: spark Updated Branches: refs/heads/master 7dd9fc67a -> 269fc62b2
[SQL] Support transforming TreeNodes with Option children. Thanks goes to @marmbrus for his implementation. Author: Michael Armbrust <[email protected]> Author: Zongheng Yang <[email protected]> Closes #1074 from concretevitamin/option-treenode and squashes the following commits: ef27b85 [Zongheng Yang] Merge pull request #1 from marmbrus/pr/1074 73133c2 [Michael Armbrust] TreeNodes can't be inner classes. ab78420 [Zongheng Yang] Add a test. 2ccb721 [Michael Armbrust] Add support for transformation of optional children. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/269fc62b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/269fc62b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/269fc62b Branch: refs/heads/master Commit: 269fc62b20ee5f9cd60a8f133c29f662d17071b1 Parents: 7dd9fc6 Author: Michael Armbrust <[email protected]> Authored: Sun Jun 15 11:28:34 2014 +0200 Committer: Michael Armbrust <[email protected]> Committed: Sun Jun 15 11:28:34 2014 +0200 ---------------------------------------------------------------------- .../spark/sql/catalyst/trees/TreeNode.scala | 19 +++++++++++++- .../sql/catalyst/trees/TreeNodeSuite.scala | 27 ++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/269fc62b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0369129..cd04bdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -187,6 +187,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -231,6 +239,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -273,7 +289,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " + + s"Exception message: ${e.getMessage}.") } } http://git-wip-us.apache.org/repos/asf/spark/blob/269fc62b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1ddc41a..6344874 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,6 +22,17 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NullType} + +case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] +} class TreeNodeSuite extends FunSuite { test("top node changed") { @@ -75,4 +86,20 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } + + test("transform works on nodes with Option children") { + val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy2 = Dummy(None) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + + var actual = dummy1 transformDown toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy1 transformUp toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy2 transform toZero + assert(actual === Dummy(None)) + } + }
