Repository: spark
Updated Branches:
  refs/heads/master 1208f72ac -> c4787a369


[SPARK-3194][SQL] Add AttributeSet to fix bugs with invalid comparisons of 
AttributeReferences

It is common to want to describe sets of attributes that are in various parts 
of a query plan.  However, the semantics of putting `AttributeReference` 
objects into a standard Scala `Set` result in subtle bugs when references 
differ cosmetically.  For example, with case-insensitive resolution it is 
possible to have two references to the same attribute whose names are not equal.

In this PR I introduce a new abstraction, an `AttributeSet`, which performs all 
comparisons using the globally unique `ExpressionId` instead of case class 
equality.  (There is already a related class, 
[`AttributeMap`](https://github.com/marmbrus/spark/blob/inMemStats/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala#L32).)
This new type of set is used to fix a bug in the optimizer where needed 
attributes were getting projected away underneath join operators.
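
To illustrate the difference, here is a minimal, self-contained sketch.  It 
does not use the real Catalyst classes: `ExprId`, `Attr`, `AttrEquals`, 
`AttrSet` and `AttrSetDemo` are simplified, hypothetical stand-ins, and `Attr` 
is assumed to use plain case class equality (so the name takes part in the 
comparison).

```scala
// Simplified stand-ins for illustration only; these are NOT the Catalyst classes.
final case class ExprId(id: Long)
final case class Attr(name: String, exprId: ExprId)

// Wrapper whose equality and hash code depend only on the expression id.
final class AttrEquals(val a: Attr) {
  override def hashCode(): Int = a.exprId.hashCode()
  override def equals(other: Any): Boolean = other match {
    case that: AttrEquals => a.exprId == that.a.exprId
    case _ => false
  }
}

// A set of attributes that compares members by expression id.
final class AttrSet private (private val baseSet: Set[AttrEquals]) {
  def contains(elem: Attr): Boolean = baseSet.contains(new AttrEquals(elem))

  def --(other: Iterable[Attr]): AttrSet =
    new AttrSet(baseSet -- other.map(new AttrEquals(_)))

  def size: Int = baseSet.size
}

object AttrSet {
  def apply(attrs: Seq[Attr]): AttrSet =
    new AttrSet(attrs.map(new AttrEquals(_)).toSet)
}

object AttrSetDemo {
  // Two references to the same attribute that differ only in capitalization.
  val lower = Attr("key", ExprId(1L))
  val upper = Attr("KEY", ExprId(1L))

  // A standard Set uses case class equality, so "key" is reported as missing.
  val plainMissing: Set[Attr] = Set(lower) -- Seq(upper)

  // The id-based set treats both references as the same attribute.
  val idMissing: AttrSet = AttrSet(Seq(lower)) -- Seq(upper)
}
```

With a standard `Set`, `plainMissing` still contains `lower`, which is the kind 
of spurious "missing attribute" described above; with the id-based set, 
`idMissing` is empty.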

I also took this opportunity to refactor the expression and query plan base 
classes.  In all but one instance the logic for computing the `references` of 
an `Expression` was the same.  Thus, I moved this logic into the base class.
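
The shape of that refactoring can be sketched as follows.  This is a 
simplified illustration rather than the Catalyst code: `Expr`, `AttrRef`, 
`Lit`, `Add` and `ReferencesDemo` are hypothetical names, and references are 
modelled as a plain `Set[Long]` of expression ids instead of an `AttributeSet`.

```scala
// Hypothetical miniature expression tree; not the Catalyst classes.
sealed trait Expr {
  def children: Seq[Expr]

  // Default shared by every node: the union of the children's references.
  def references: Set[Long] = children.flatMap(_.references).toSet
}

// The one node that needs its own definition: an attribute refers to itself.
final case class AttrRef(name: String, exprId: Long) extends Expr {
  val children: Seq[Expr] = Nil
  override def references: Set[Long] = Set(exprId)
}

// Leaf with no children: the default yields an empty set without an override.
final case class Lit(value: Any) extends Expr {
  val children: Seq[Expr] = Nil
}

// Interior node: inherits the default and needs no override either.
final case class Add(left: Expr, right: Expr) extends Expr {
  val children: Seq[Expr] = Seq(left, right)
}

object ReferencesDemo {
  val expr: Expr = Add(AttrRef("a", 1L), Add(AttrRef("b", 2L), Lit(3)))
  val refs: Set[Long] = expr.references
}
```

Only the attribute node overrides `references`; literals and interior nodes 
fall back to the base-class default, which is why so many one-line overrides 
could be deleted.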

For query plans the semantics of the `references` method were ill-defined 
(does it mean the attributes a plan outputs, those used by expression 
evaluation, or something else?).  As a result, this method wasn't really used 
very much, so I removed it.

TODO:
 - [x] Finish scala doc for `AttributeSet`
 - [x] Scan the code for other instances of `Set[Attribute]` and refactor them.
 - [x] Finish removing `references` from `QueryPlan`

Author: Michael Armbrust <[email protected]>

Closes #2109 from marmbrus/attributeSets and squashes the following commits:

1c0dae5 [Michael Armbrust] work on serialization bug.
9ba868d [Michael Armbrust] Merge remote-tracking branch 'origin/master' into 
attributeSets
3ae5288 [Michael Armbrust] review comments
40ce7f6 [Michael Armbrust] style
d577cc7 [Michael Armbrust] Scaladoc
cae5d22 [Michael Armbrust] remove more references implementations
d6e16be [Michael Armbrust] Remove more instances of "def references" and normal 
sets of attributes.
fc26b49 [Michael Armbrust] Add AttributeSet class, remove references from 
Expression.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c4787a36
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c4787a36
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c4787a36

Branch: refs/heads/master
Commit: c4787a3690a9ed3b8b2c6c294fc4a6915436b6f7
Parents: 1208f72
Author: Michael Armbrust <[email protected]>
Authored: Tue Aug 26 16:29:14 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue Aug 26 16:29:14 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   6 +-
 .../sql/catalyst/analysis/unresolved.scala      |   1 -
 .../sql/catalyst/expressions/AttributeSet.scala | 106 +++++++++++++++++++
 .../catalyst/expressions/BoundAttribute.scala   |   2 -
 .../sql/catalyst/expressions/Expression.scala   |   6 +-
 .../spark/sql/catalyst/expressions/Rand.scala   |   1 -
 .../sql/catalyst/expressions/ScalaUdf.scala     |   1 -
 .../sql/catalyst/expressions/SortOrder.scala    |   1 -
 .../sql/catalyst/expressions/WrapDynamic.scala  |   2 +-
 .../sql/catalyst/expressions/aggregates.scala   |  25 ++---
 .../sql/catalyst/expressions/arithmetic.scala   |   2 -
 .../sql/catalyst/expressions/complexTypes.scala |   2 +-
 .../sql/catalyst/expressions/generators.scala   |   2 -
 .../sql/catalyst/expressions/literals.scala     |   4 +-
 .../catalyst/expressions/namedExpressions.scala |   6 +-
 .../catalyst/expressions/nullFunctions.scala    |   3 -
 .../sql/catalyst/expressions/predicates.scala   |   6 +-
 .../spark/sql/catalyst/expressions/sets.scala   |   5 -
 .../catalyst/expressions/stringOperations.scala |   2 -
 .../sql/catalyst/optimizer/Optimizer.scala      |  12 ++-
 .../spark/sql/catalyst/plans/QueryPlan.scala    |   4 +-
 .../catalyst/plans/logical/LogicalPlan.scala    |  11 +-
 .../plans/logical/ScriptTransformation.scala    |   4 +-
 .../catalyst/plans/logical/basicOperators.scala |  29 +----
 .../catalyst/plans/logical/partitioning.scala   |   4 -
 .../catalyst/plans/physical/partitioning.scala  |   3 +-
 .../sql/catalyst/trees/TreeNodeSuite.scala      |   1 -
 .../scala/org/apache/spark/sql/SQLContext.scala |   7 +-
 .../columnar/InMemoryColumnarTableScan.scala    |   2 -
 .../apache/spark/sql/execution/SparkPlan.scala  |   3 +-
 .../spark/sql/execution/debug/package.scala     |   2 -
 .../apache/spark/sql/execution/pythonUdfs.scala |   2 -
 .../apache/spark/sql/hive/HiveStrategies.scala  |   8 +-
 .../org/apache/spark/sql/hive/hiveUdfs.scala    |   5 -
 .../hive/execution/HiveResolutionSuite.scala    |   9 +-
 35 files changed, 166 insertions(+), 123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c18d785..4a95240 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: 
FunctionRegistry, caseSensitive: Bool
       case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved 
&& p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case 
UnresolvedAttribute(name) => name })
         val resolved = unresolved.flatMap(child.resolveChildren)
-        val requiredAttributes = resolved.collect { case a: Attribute => a 
}.toSet
+        val requiredAttributes = AttributeSet(resolved.collect { case a: 
Attribute => a })
 
         val missingInProject = requiredAttributes -- p.output
         if (missingInProject.nonEmpty) {
@@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: 
FunctionRegistry, caseSensitive: Bool
         )
 
         logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
-        val missingInAggs = resolved -- a.outputSet
+        val resolved = unresolved.flatMap(groupingRelation.resolve)
+        val missingInAggs = resolved.filterNot(a.outputSet.contains)
         logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
         if (missingInAggs.nonEmpty) {
           // Add missing grouping exprs and then project them away after the 
sort.

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a0e2577..a2c61c6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: 
Seq[Expression]) extends E
   override def dataType = throw new UnresolvedException(this, "dataType")
   override def foldable = throw new UnresolvedException(this, "foldable")
   override def nullable = throw new UnresolvedException(this, "nullable")
-  override def references = children.flatMap(_.references).toSet
   override lazy val resolved = false
 
   // Unresolved functions are transient at compile time and don't get 
evaluated during execution.

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
new file mode 100644
index 0000000..c3a08bb
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+protected class AttributeEquals(val a: Attribute) {
+  override def hashCode() = a.exprId.hashCode()
+  override def equals(other: Any) = other match {
+    case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
+    case otherAttribute => false
+  }
+}
+
+object AttributeSet {
+  /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute 
Attributes]]. */
+  def apply(baseSet: Seq[Attribute]) = {
+    new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
+  }
+}
+
+/**
+ * A Set designed to hold [[AttributeReference]] objects, that performs 
equality checking using
+ * expression id instead of standard java equality.  Using expression id means 
that these
+ * sets will correctly test for membership, even when the AttributeReferences 
in question differ
+ * cosmetically (e.g., the names have different capitalizations).
+ *
+ * Note that we do not override equality for Attribute references as it is 
really weird when
+ * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic 
leads to broken tests,
+ * and also makes doing transformations hard (we always try to keep older trees 
instead of new ones
+ * when the transformation was a no-op).
+ */
+class AttributeSet private (val baseSet: Set[AttributeEquals])
+  extends Traversable[Attribute] with Serializable {
+
+  /** Returns true if the members of this AttributeSet and other are the same. 
*/
+  override def equals(other: Any) = other match {
+    case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
+    case _ => false
+  }
+
+  /** Returns true if this set contains an Attribute with the same expression 
id as `elem` */
+  def contains(elem: NamedExpression): Boolean =
+    baseSet.contains(new AttributeEquals(elem.toAttribute))
+
+  /** Returns a new [[AttributeSet]] that contains `elem` in addition to the 
current elements. */
+  def +(elem: Attribute): AttributeSet =  // scalastyle:ignore
+    new AttributeSet(baseSet + new AttributeEquals(elem))
+
+  /** Returns a new [[AttributeSet]] that does not contain `elem`. */
+  def -(elem: Attribute): AttributeSet =
+    new AttributeSet(baseSet - new AttributeEquals(elem))
+
+  /** Returns an iterator containing all of the attributes in the set. */
+  def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator
+
+  /**
+   * Returns true if the [[Attribute Attributes]] in this set are a subset of 
the Attributes in
+   * `other`.
+   */
+  def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)
+
+  /**
+   * Returns a new [[AttributeSet]] that does not contain any of the 
[[Attribute Attributes]] found
+   * in `other`.
+   */
+  def --(other: Traversable[NamedExpression]) =
+    new AttributeSet(baseSet -- other.map(a => new 
AttributeEquals(a.toAttribute)))
+
+  /**
+   * Returns a new [[AttributeSet]] that contains all of the [[Attribute 
Attributes]] found
+   * in `other`.
+   */
+  def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)
+
+  /**
+   * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] 
where `f` evaluates to
+   * true.
+   */
+  override def filter(f: Attribute => Boolean) = new 
AttributeSet(baseSet.filter(ae => f(ae.a)))
+
+  /**
+   * Returns a new [[AttributeSet]] that only contains [[Attribute 
Attributes]] that are found in
+   * `this` and `other`.
+   */
+  def intersect(other: AttributeSet) = new 
AttributeSet(baseSet.intersect(other.baseSet))
+
+  override def foreach[U](f: (Attribute) => U): Unit = 
baseSet.map(_.a).foreach(f)
+
+  // We must force toSeq to be strict, otherwise we end up with a 
[[Stream]] that captures all
+  // sorts of things in its closure.
+  override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 0913f15..54c6baf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, 
nullable: Boolean)
 
   type EvaluatedType = Any
 
-  override def references = Set.empty
-
   override def toString = s"input[$ordinal]"
 
   override def eval(input: Row): Any = input(ordinal)

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ba62dab..70507e7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def foldable: Boolean = false
   def nullable: Boolean
-  def references: Set[Attribute]
+  def references: AttributeSet = 
AttributeSet(children.flatMap(_.references.iterator))
 
   /** Returns the result of evaluating this expression on a given input Row */
   def eval(input: Row = null): EvaluatedType
@@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with 
trees.BinaryNode[Express
 
   override def foldable = left.foldable && right.foldable
 
-  override def references = left.references ++ right.references
-
   override def toString = s"($left $symbol $right)"
 }
 
@@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with 
trees.LeafNode[Expression]
 abstract class UnaryExpression extends Expression with 
trees.UnaryNode[Expression] {
   self: Product =>
 
-  override def references = child.references
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
index 38f836f..851db95 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType
 case object Rand extends LeafExpression {
   override def dataType = DoubleType
   override def nullable = false
-  override def references = Set.empty
 
   private[this] lazy val rand = new Random
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 95633dd..63ac2a6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, 
children: Seq[Expressi
 
   type EvaluatedType = Any
 
-  def references = children.flatMap(_.references).toSet
   def nullable = true
 
   /** This method has been generated by this script

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d2b7685..d00b2ac 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -31,7 +31,6 @@ case object Descending extends SortDirection
 case class SortOrder(child: Expression, direction: SortDirection) extends 
Expression 
     with trees.UnaryNode[Expression] {
 
-  override def references = child.references
   override def dataType = child.dataType
   override def nullable = child.nullable
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index eb88989..1eb5571 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends 
Expression {
   type EvaluatedType = DynamicRow
 
   def nullable = false
-  def references = children.toSet
+
   def dataType = DynamicType
 
   override def eval(input: Row): DynamicRow = input match {

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 613b87c..dbc0c29 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -78,7 +78,7 @@ abstract class AggregateFunction
 
   /** Base should return the generic aggregate expression that this function 
is computing */
   val base: AggregateExpression
-  override def references = base.references
+
   override def nullable = base.nullable
   override def dataType = base.dataType
 
@@ -89,7 +89,7 @@ abstract class AggregateFunction
 }
 
 case class Min(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"MIN($child)"
@@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: 
AggregateExpression) extends Aggr
 }
 
 case class Max(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"MAX($child)"
@@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: 
AggregateExpression) extends Aggr
 }
 
 case class Count(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"COUNT($child)"
@@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) 
extends PartialAggregate
   def this() = this(null)
 
   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
@@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) 
extends AggregateExpress
   def this() = this(null)
 
   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = ArrayType(expressions.head.dataType)
   override def toString = s"AddToHashSet(${expressions.mkString(",")})"
@@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) 
extends AggregateExpression
   def this() = this(null)
 
   override def children = inputSet :: Nil
-  override def references = inputSet.references
   override def nullable = false
   override def dataType = LongType
   override def toString = s"CombineAndCount($inputSet)"
@@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction(
 
 case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, 
relativeSD: Double)
 
 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, 
relativeSD: Double)
 
 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
   extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, 
relativeSD: Double = 0.05)
 }
 
 case class Average(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = DoubleType
   override def toString = s"AVG($child)"
@@ -304,7 +302,7 @@ case class Average(child: Expression) extends 
PartialAggregate with trees.UnaryN
 }
 
 case class Sum(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"SUM($child)"
@@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate 
with trees.UnaryNode[
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"
@@ -331,7 +329,6 @@ case class SumDistinct(child: Expression)
 }
 
 case class First(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
-  override def references = child.references
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"FIRST($child)"

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 5f8b6ae..aae86a3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -95,8 +95,6 @@ case class MaxOf(left: Expression, right: Expression) extends 
Expression {
 
   override def children = left :: right :: Nil
 
-  override def references = left.references ++ right.references
-
   override def dataType = left.dataType
 
   override def eval(input: Row): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index c1154eb..dafd745 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) 
extends Expression {
   /** `Null` is returned for invalid ordinals. */
   override def nullable = true
   override def foldable = child.foldable && ordinal.foldable
-  override def references = children.flatMap(_.references).toSet
+
   def dataType = child.dataType match {
     case ArrayType(dt, _) => dt
     case MapType(_, vt, _) => vt

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e99c5b4..9c86525 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -47,8 +47,6 @@ abstract class Generator extends Expression {
 
   override def nullable = false
 
-  override def references = children.flatMap(_.references).toSet
-
   /**
    * Should be overridden by specific generators.  Called only once for each 
instance to ensure
    * that rule application does not change the output schema of a generator.

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e15e16d..a8c2396 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends 
LeafExpression {
 
   override def foldable = true
   def nullable = value == null
-  def references = Set.empty
+
 
   override def toString = if (value != null) value.toString else "null"
 
@@ -66,8 +66,6 @@ case class MutableLiteral(var value: Any, nullable: Boolean = 
true) extends Leaf
 
   val dataType = Literal(value).dataType
 
-  def references = Set.empty
-
   def update(expression: Expression, input: Row) = {
     value = expression.eval(input)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 02d0476..7c4b9d4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression {
 
   def toAttribute = this
   def newInstance: Attribute
-  override def references = Set(this)
+
 }
 
 /**
@@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String)
 
   override def dataType = child.dataType
   override def nullable = child.nullable
-  override def references = child.references
+
 
   override def toAttribute = {
     if (resolved) {
@@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: 
DataType, nullable: Boolea
     (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: 
Seq[String] = Nil)
   extends Attribute with trees.LeafNode[Expression] {
 
+  override def references = AttributeSet(this :: Nil)
+
   override def equals(other: Any) = other match {
     case ar: AttributeReference => exprId == ar.exprId && dataType == 
ar.dataType
     case _ => false

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index e88c5d4..086d0a3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
   /** Coalesce is nullable if all of its children are nullable, or if it has 
no children. */
   def nullable = !children.exists(!_.nullable)
 
-  def references = children.flatMap(_.references).toSet
   // Coalesce is foldable if all children are foldable.
   override def foldable = !children.exists(!_.foldable)
 
@@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
 }
 
 case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
-  def references = child.references
   override def foldable = child.foldable
   def nullable = false
 
@@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
 }
 
 case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
-  def references = child.references
   override def foldable = child.foldable
   def nullable = false
   override def toString = s"IS NOT NULL $child"

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 5976b0d..1313ccd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate {
  */
 case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   def children = value +: list
-  def references = children.flatMap(_.references).toSet
+
   def nullable = true // TODO: Figure out correct nullability semantics of IN.
   override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
 
@@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 
   def children = predicate :: trueValue :: falseValue :: Nil
   override def nullable = trueValue.nullable || falseValue.nullable
-  def references = children.flatMap(_.references).toSet
+
   override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
   def dataType = {
     if (!resolved) {
@@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 case class CaseWhen(branches: Seq[Expression]) extends Expression {
   type EvaluatedType = Any
   def children = branches
-  def references = children.flatMap(_.references).toSet
+
   def dataType = {
     if (!resolved) {
       throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index e6c570b..3d4c4a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -26,8 +26,6 @@ import org.apache.spark.util.collection.OpenHashSet
 case class NewSet(elementType: DataType) extends LeafExpression {
   type EvaluatedType = Any
 
-  def references = Set.empty
-
   def nullable = false
 
   // We are currently only using these Expressions internally for aggregation.  However, if we ever
@@ -53,9 +51,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
   def nullable = set.nullable
 
   def dataType = set.dataType
-
-  def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
-
   def eval(input: Row): Any = {
     val itemEval = item.eval(input)
     val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 97fc3a3..c2a3a5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
     if (str.dataType == BinaryType) str.dataType else StringType
   }
 
-  def references = children.flatMap(_.references).toSet
-
   override def children = str :: pos :: len :: Nil
 
   @inline

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5f86d60..ddd4b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -65,8 +65,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Eliminate unneeded attributes from either side of a Join.
     case Project(projectList, Join(left, right, joinType, condition)) =>
       // Collect the list of all references required either above or to evaluate the condition.
-      val allReferences: Set[Attribute] =
-        projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)
+      val allReferences: AttributeSet =
+        AttributeSet(
+          projectList.flatMap(_.references.iterator)) ++
+          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
 
       /** Applies a projection only when the child is producing unnecessary attributes */
       def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
@@ -76,8 +78,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Eliminate unneeded attributes from right side of a LeftSemiJoin.
     case Join(left, right, LeftSemi, condition) =>
       // Collect the list of all references required to evaluate the condition.
-      val allReferences: Set[Attribute] =
-        condition.map(_.references).getOrElse(Set.empty)
+      val allReferences: AttributeSet =
+        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
 
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
@@ -104,7 +106,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
-  private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) =
+  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
       Project(allReferences.filter(c.outputSet.contains).toSeq, c)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0988b0c..1e177e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType}
 
@@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   /**
    * Returns the set of attributes that are output by this node.
    */
-  def outputSet: Set[Attribute] = output.toSet
+  def outputSet: AttributeSet = AttributeSet(output)
 
   /**
    * Runs [[transform]] with `rule` on all expressions present in this query operator.

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 278569f..8616ac4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -46,16 +46,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   )
 
   /**
-   * Returns the set of attributes that are referenced by this node
-   * during evaluation.
-   */
-  def references: Set[Attribute]
-
-  /**
    * Returns the set of attributes that this node takes as
    * input from its children.
    */
-  lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet
+  lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output))
 
   /**
    * Returns true if this expression and all its children have been resolved to a specific schema
@@ -126,9 +120,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
 
   override lazy val statistics: Statistics =
     throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
-
-  // Leaf nodes by definition cannot reference any input attributes.
-  override def references = Set.empty
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index d3f9d0f..4460c86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -30,6 +30,4 @@ case class ScriptTransformation(
     input: Seq[Expression],
     script: String,
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  def references = input.flatMap(_.references).toSet
-}
+    child: LogicalPlan) extends UnaryNode

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3cb4072..4adfb18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._
 
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   def output = projectList.map(_.toAttribute)
-  def references = projectList.flatMap(_.references).toSet
 }
 
 /**
@@ -59,14 +58,10 @@ case class Generate(
 
   override def output =
     if (join) child.output ++ generatorOutput else generatorOutput
-
-  override def references =
-    if (join) child.outputSet else generator.references
 }
 
 case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = condition.references
 }
 
 case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   override lazy val resolved =
     childrenResolved &&
     !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
-
-  override def references = Set.empty
 }
 
 case class Join(
@@ -86,8 +79,6 @@ case class Join(
   joinType: JoinType,
   condition: Option[Expression]) extends BinaryNode {
 
-  override def references = condition.map(_.references).getOrElse(Set.empty)
-
   override def output = {
     joinType match {
       case LeftSemi =>
@@ -106,8 +97,6 @@ case class Join(
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   def output = left.output
-
-  def references = Set.empty
 }
 
 case class InsertIntoTable(
@@ -118,7 +107,6 @@ case class InsertIntoTable(
   extends LogicalPlan {
   // The table being inserted into is a child for the purposes of transformations.
   override def children = table :: child :: Nil
-  override def references = Set.empty
   override def output = child.output
 
   override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
@@ -130,20 +118,17 @@ case class InsertIntoCreatedTable(
     databaseName: Option[String],
     tableName: String,
     child: LogicalPlan) extends UnaryNode {
-  override def references = Set.empty
   override def output = child.output
 }
 
 case class WriteToFile(
     path: String,
     child: LogicalPlan) extends UnaryNode {
-  override def references = Set.empty
   override def output = child.output
 }
 
 case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = order.flatMap(_.references).toSet
 }
 
 case class Aggregate(
@@ -152,19 +137,20 @@ case class Aggregate(
     child: LogicalPlan)
   extends UnaryNode {
 
+  /** The set of all AttributeReferences required for this aggregation. */
+  def references =
+    AttributeSet(
+      groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references))
+
   override def output = aggregateExpressions.map(_.toAttribute)
-  override def references =
-    (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
 }
 
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = limitExpr.references
 }
 
 case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output = child.output.map(_.withQualifiers(alias :: Nil))
-  override def references = Set.empty
 }
 
 /**
@@ -191,20 +177,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
         a.qualifiers)
     case other => other
   }
-
-  override def references = Set.empty
 }
 
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
     extends UnaryNode {
 
   override def output = child.output
-  override def references = Set.empty
 }
 
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = child.outputSet
 }
 
 case object NoRelation extends LeafNode {
@@ -213,5 +195,4 @@ case object NoRelation extends LeafNode {
 
 case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   override def output = left.output
-  override def references = Set.empty
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 7146fbd..72b0c5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode {
 
 case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
     extends RedistributeData {
-
-  def references = sortExpressions.flatMap(_.references).toSet
 }
 
 case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
     extends RedistributeData {
-
-  def references = partitionExpressions.flatMap(_.references).toSet
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 4bb022c..ccb0df1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
 
+  // TODO: This is not really valid...
   def clustering = ordering.map(_.child).toSet
 }
 
@@ -139,7 +140,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   with Partitioning {
 
   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
 
@@ -179,7 +179,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
   with Partitioning {
 
   override def children = ordering
-  override def references = ordering.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6344874..2962025 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType}
 
 case class Dummy(optKey: Option[Expression]) extends Expression {
   def children = optKey.toSeq
-  def references = Set.empty[Attribute]
   def nullable = true
   def dataType = NullType
   override lazy val resolved = true

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8a9f4de..6f0eed3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -344,8 +344,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
         prunePushedDownFilters: Seq[Expression] => Seq[Expression],
         scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {
 
-      val projectSet = projectList.flatMap(_.references).toSet
-      val filterSet = filterPredicates.flatMap(_.references).toSet
+      val projectSet = AttributeSet(projectList.flatMap(_.references))
+      val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
       val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And)
 
       // Right now we still use a projection even if the only evaluation is applying an alias
@@ -354,7 +354,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
       // TODO: Decouple final output schema from expression evaluation so this copy can be
       // avoided safely.
 
-      if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) {
+      if (AttributeSet(projectList.map(_.toAttribute)) == projectSet &&
+          filterSet.subsetOf(projectSet)) {
         // When it is possible to just use column pruning to get the right projection and
         // when the columns of this projection are enough to evaluate all filter conditions,
         // just do a scan followed by a filter, with no extra project.

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index e63b490..24e88ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -79,8 +79,6 @@ private[sql] case class InMemoryRelation(
 
   override def children = Seq.empty
 
-  override def references = Set.empty
-
   override def newInstance() = {
     new InMemoryRelation(
       output.map(_.newInstance),

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21cbbc9..7d33ea5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ
   extends LogicalPlan with MultiInstanceRelation {
 
   def output = alreadyPlanned.output
-  override def references = Set.empty
   override def children = Nil
 
-  override final def newInstance: this.type = {
+  override final def newInstance(): this.type = {
     SparkLogicalPlan(
       alreadyPlanned match {
         case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f31df05..5b896c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -58,8 +58,6 @@ package object debug {
   }
 
   private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
-    def references = Set.empty
-
     def output = child.output
 
     implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index b92091b..aef6ebf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -49,7 +49,6 @@ private[spark] case class PythonUDF(
   override def toString = s"PythonUDF#$name(${children.mkString(",")})"
 
   def nullable: Boolean = true
-  def references: Set[Attribute] = children.flatMap(_.references).toSet
 
   override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.")
 }
@@ -113,7 +112,6 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
 case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode {
   val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)()
 
-  def references = Set.empty
   def output = child.output :+ resultAttribute
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 389ace7..10fa831 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -79,9 +79,9 @@ private[hive] trait HiveStrategies {
              hiveContext.convertMetastoreParquet =>
 
         // Filter out all predicates that only deal with partition keys
-        val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+        val partitionsKeys = AttributeSet(relation.partitionKeys)
         val (pruningPredicates, otherPredicates) = predicates.partition {
-          _.references.map(_.exprId).subsetOf(partitionKeyIds)
+          _.references.subsetOf(partitionsKeys)
         }
 
         // We are going to throw the predicates and projection back at the whole optimization
@@ -176,9 +176,9 @@ private[hive] trait HiveStrategies {
       case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
         // Filter out all predicates that only deal with partition keys, these are given to the
         // hive table scan operator to be used for partition pruning.
-        val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+        val partitionKeyIds = AttributeSet(relation.partitionKeys)
         val (pruningPredicates, otherPredicates) = predicates.partition {
-          _.references.map(_.exprId).subsetOf(partitionKeyIds)
+          _.references.subsetOf(partitionKeyIds)
         }
 
         pruneFilterProject(

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index c6497a1..7d1ad53 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
   type EvaluatedType = Any
 
   def nullable = true
-  def references = children.flatMap(_.references).toSet
 
   lazy val function = createFunction[UDFType]()
 
@@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf(
 
   def nullable: Boolean = true
 
-  def references: Set[Attribute] = children.map(_.references).flatten.toSet
-
   override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
 
   def newInstance() = new HiveUdafFunction(functionClassName, children, this)
@@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf(
     children: Seq[Expression])
   extends Generator with HiveInspectors with HiveFunctionFactory {
 
-  override def references = children.flatMap(_.references).toSet
-
   @transient
   protected lazy val function: GenericUDTF = createFunction()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4787a36/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 6b3ffd1..b6be6bc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 
-case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
 case class Nested(a: Int, B: Int)
+case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
 
 /**
  * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
@@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest {
       .registerTempTable("caseSensitivityTest")
 
     sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
+
+    println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution)
+
+    sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect()
+
+    // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a")
+
   }
 
   test("nested repeated resolution") {
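
Note on the change above: the patch replaces `Set[Attribute]` with `AttributeSet` so that membership and subset checks compare expression ids rather than relying on case class equality. The following is only a minimal, standalone sketch of that idea. `ExprId`, `Attr`, `AttrSet`, and `AttrSetDemo` are toy stand-ins invented for illustration; they are not Spark's `AttributeReference` or `AttributeSet`, and the real classes may differ in equality, hashing, and API.

```scala
// Toy stand-ins for illustration only; these are NOT Spark's real classes.
final case class ExprId(id: Long)

// Plain case class equality here includes `name`, so "a" and "A" differ
// even when they carry the same exprId.
final case class Attr(name: String, exprId: ExprId)

// Hypothetical simplified set that keys every comparison on exprId alone.
final class AttrSet private (private val byId: Map[ExprId, Attr]) {
  def contains(a: Attr): Boolean = byId.contains(a.exprId)
  def subsetOf(other: AttrSet): Boolean = byId.keySet.subsetOf(other.byId.keySet)
  def ++(other: AttrSet): AttrSet = new AttrSet(byId ++ other.byId)
  def size: Int = byId.size
}

object AttrSet {
  def apply(attrs: Iterable[Attr]): AttrSet =
    new AttrSet(attrs.map(a => a.exprId -> a).toMap)
}

object AttrSetDemo {
  // Two references to the same attribute that differ only cosmetically.
  val lower: Attr = Attr("a", ExprId(1L))
  val upper: Attr = Attr("A", ExprId(1L))

  val plain: Set[Attr] = Set(lower)       // compares by case class equality
  val byId: AttrSet = AttrSet(Seq(lower)) // compares by exprId only

  def main(args: Array[String]): Unit = {
    println(s"plain Set contains upper: ${plain.contains(upper)}")
    println(s"AttrSet contains upper:   ${byId.contains(upper)}")
  }
}
```

In this toy model, the plain `Set` treats the two references as different elements while the id-keyed set treats them as the same attribute, which is the kind of mismatch the `ColumnPruning` and `pruneFilterProject` changes in the diff are meant to avoid.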

