Repository: spark
Updated Branches:
  refs/heads/master 1132e472e -> 9d824fed8


[SQL] SPARK-1800 Add broadcast hash join operator & associated hints.

This PR is based on Michael's [PR 734](https://github.com/apache/spark/pull/734) and includes a number of cleanups.

In addition, this PR:
- makes `SparkLogicalPlan` take a `tableName: String`, which facilitates testing;
- moves the join-related tests into a single file (`JoinSuite`).

Author: Zongheng Yang <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes #1163 from concretevitamin/auto-broadcast-hash-join and squashes the following commits:

d0f4991 [Zongheng Yang] Fix bug in broadcast hash join & add test to cover it.
af080d7 [Zongheng Yang] Fix in joinIterators()'s next().
440d277 [Zongheng Yang] Fixes to imports; add back requiredChildDistribution (lost when merging)
208d5f6 [Zongheng Yang] Make LeftSemiJoinHash mix in HashJoin.
ad6c7cc [Zongheng Yang] Minor cleanups.
814b3bf [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
a8a093e [Zongheng Yang] Minor cleanups.
6fd8443 [Zongheng Yang] Cut down size estimation related stuff.
a4267be [Zongheng Yang] Add test for broadcast hash join and related necessary refactorings:
0e64b08 [Zongheng Yang] Scalastyle fix.
91461c2 [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
7c7158b [Zongheng Yang] Prototype of auto conversion to broadcast hash join.
0ad122f [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
3e5d77c [Zongheng Yang] WIP: giant and messy WIP.
a92ed0c [Michael Armbrust] Formatting.
76ca434 [Michael Armbrust] A simple strategy that broadcasts tables only when they are found in a configuration hint.
cf6b381 [Michael Armbrust] Split out generic logic for hash joins and create two concrete physical operators: BroadcastHashJoin and ShuffledHashJoin.
a8420ca [Michael Armbrust] Copy records in executeCollect to avoid issues with mutable rows.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d824fed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d824fed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d824fed

Branch: refs/heads/master
Commit: 9d824fed8c62dd6c87b4c855c2fea930c01b58f4
Parents: 1132e47
Author: Zongheng Yang <[email protected]>
Authored: Wed Jun 25 18:06:33 2014 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Wed Jun 25 18:06:33 2014 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Projection.scala   |   8 +-
 .../catalyst/plans/logical/BaseRelation.scala   |   1 -
 .../scala/org/apache/spark/sql/SQLConf.scala    |  17 ++
 .../scala/org/apache/spark/sql/SQLContext.scala |   6 +-
 .../apache/spark/sql/execution/SparkPlan.scala  |  15 +-
 .../spark/sql/execution/SparkStrategies.scala   |  54 ++++-
 .../spark/sql/execution/basicOperators.scala    |   1 -
 .../org/apache/spark/sql/execution/joins.scala  | 219 +++++++++++--------
 .../spark/sql/parquet/ParquetRelation.scala     |   5 +-
 .../org/apache/spark/sql/DslQuerySuite.scala    |  99 ---------
 .../scala/org/apache/spark/sql/JoinSuite.scala  | 173 +++++++++++++++
 .../scala/org/apache/spark/sql/QueryTest.scala  |   4 +-
 .../spark/sql/execution/PlannerSuite.scala      |  17 --
 .../org/apache/spark/sql/hive/HiveContext.scala |   2 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |   7 +-
 15 files changed, 395 insertions(+), 233 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index a9e976c..2c71d2c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -45,8 +45,10 @@ class Projection(expressions: Seq[Expression]) extends (Row 
=> Row) {
  * that schema.
  *
  * In contrast to a normal projection, a MutableProjection reuses the same 
underlying row object
- * each time an input row is added.  This significatly reduces the cost of 
calcuating the
- * projection, but means that it is not safe
+ * each time an input row is added.  This significantly reduces the cost of 
calculating the
+ * projection, but means that it is not safe to hold on to a reference to a 
[[Row]] after `next()`
+ * has been called on the [[Iterator]] that produced it. Instead, the user 
must call `Row.copy()`
+ * and hold on to the returned [[Row]] before calling `next()`.
  */
 case class MutableProjection(expressions: Seq[Expression]) extends (Row => 
Row) {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
@@ -67,7 +69,7 @@ case class MutableProjection(expressions: Seq[Expression]) 
extends (Row => Row)
 }
 
 /**
- * A mutable wrapper that makes two rows appear appear as a single 
concatenated row.  Designed to
+ * A mutable wrapper that makes two rows appear as a single concatenated row.  
Designed to
  * be instantiated once per thread and reused.
  */
 class JoinedRow extends Row {

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
index 7c61678..582334a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
@@ -21,5 +21,4 @@ abstract class BaseRelation extends LeafNode {
   self: Product =>
 
   def tableName: String
-  def isPartitioned: Boolean = false
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index b378252..2fe7f94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -29,9 +29,26 @@ import scala.collection.JavaConverters._
  */
 trait SQLConf {
 
+  /** ************************ Spark SQL Params/Hints ******************* */
+  // TODO: refactor so that these hints accessors don't pollute the name space 
of SQLContext?
+
   /** Number of partitions to use for shuffle operators. */
   private[spark] def numShufflePartitions: Int = 
get("spark.sql.shuffle.partitions", "200").toInt
 
+  /**
+   * Upper bound on the sizes (in bytes) of the tables qualified for the auto 
conversion to
+   * a broadcast value during the physical executions of join operations.  
Setting this to 0
+   * effectively disables auto conversion.
+   * Hive setting: hive.auto.convert.join.noconditionaltask.size.
+   */
+  private[spark] def autoConvertJoinSize: Int =
+    get("spark.sql.auto.convert.join.size", "10000").toInt
+
+  /** A comma-separated list of table names marked to be broadcasted during 
joins. */
+  private[spark] def joinBroadcastTables: String = 
get("spark.sql.join.broadcastTables", "")
+
+  /** ********************** SQLConf functionality methods ************ */
+
   @transient
   private val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7195f97..7edb548 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -170,7 +170,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group userf
    */
   def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
-    catalog.registerTable(None, tableName, rdd.logicalPlan)
+    val name = tableName
+    val newPlan = rdd.logicalPlan transform {
+      case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = 
name)
+    }
+    catalog.registerTable(None, tableName, newPlan)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 07967fe..27dc091 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.{Logging, Row}
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
 
 /**
  * :: DeveloperApi ::
@@ -66,19 +66,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with 
Logging {
  * linking.
  */
 @DeveloperApi
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
-  extends logical.LogicalPlan with MultiInstanceRelation {
+case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = 
"SparkLogicalPlan")
+  extends BaseRelation with MultiInstanceRelation {
 
   def output = alreadyPlanned.output
-  def references = Set.empty
-  def children = Nil
+  override def references = Set.empty
+  override def children = Nil
 
   override final def newInstance: this.type = {
     SparkLogicalPlan(
       alreadyPlanned match {
         case ExistingRdd(output, rdd) => 
ExistingRdd(output.map(_.newInstance), rdd)
         case _ => sys.error("Multiple instance of the same relation detected.")
-      }).asInstanceOf[this.type]
+      }, tableName)
+      .asInstanceOf[this.type]
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index bd8ae4c..3cd2996 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -21,10 +21,10 @@ import org.apache.spark.sql.{SQLContext, execution}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.parquet._
 import org.apache.spark.sql.columnar.{InMemoryRelation, 
InMemoryColumnarTableScan}
+import org.apache.spark.sql.parquet._
 
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
@@ -45,14 +45,52 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
     }
   }
 
+  /**
+   * Uses the HashFilteredJoin pattern to find joins where at least some of 
the predicates can be
+   * evaluated by matching hash keys.
+   */
   object HashJoin extends Strategy with PredicateHelper {
+    private[this] def broadcastHashJoin(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        left: LogicalPlan,
+        right: LogicalPlan,
+        condition: Option[Expression],
+        side: BuildSide) = {
+      val broadcastHashJoin = execution.BroadcastHashJoin(
+        leftKeys, rightKeys, side, planLater(left), 
planLater(right))(sqlContext)
+      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) 
:: Nil
+    }
+
+    def broadcastTables: Seq[String] = 
sqlContext.joinBroadcastTables.split(",").toBuffer
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // Find inner joins where at least some predicates can be evaluated by 
matching hash keys
-      // using the HashFilteredJoin pattern.
+      case HashFilteredJoin(
+              Inner,
+              leftKeys,
+              rightKeys,
+              condition,
+              left,
+              right @ PhysicalOperation(_, _, b: BaseRelation))
+        if broadcastTables.contains(b.tableName) =>
+          broadcastHashJoin(leftKeys, rightKeys, left, right, condition, 
BuildRight)
+
+      case HashFilteredJoin(
+              Inner,
+              leftKeys,
+              rightKeys,
+              condition,
+              left @ PhysicalOperation(_, _, b: BaseRelation),
+              right)
+        if broadcastTables.contains(b.tableName) =>
+          broadcastHashJoin(leftKeys, rightKeys, left, right, condition, 
BuildLeft)
+
       case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, 
right) =>
         val hashJoin =
-          execution.HashJoin(leftKeys, rightKeys, BuildRight, planLater(left), 
planLater(right))
+          execution.ShuffledHashJoin(
+            leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+
       case _ => Nil
     }
   }
@@ -62,10 +100,10 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case logical.Aggregate(groupingExpressions, aggregateExpressions, child) 
=>
         // Collect all aggregate expressions.
         val allAggregates =
-          aggregateExpressions.flatMap(_ collect { case a: AggregateExpression 
=> a})
+          aggregateExpressions.flatMap(_ collect { case a: AggregateExpression 
=> a })
         // Collect all aggregate expressions that can be computed partially.
         val partialAggregates =
-          aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => 
p})
+          aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => 
p })
 
         // Only do partial aggregation if supported by all aggregate 
expressions.
         if (allAggregates.size == partialAggregates.size) {
@@ -242,7 +280,7 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.ExistingRdd(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
         execution.Exchange(HashPartitioning(expressions, numPartitions), 
planLater(child)) :: Nil
-      case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
+      case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b40d4e3..a278f1c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -205,4 +205,3 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode 
{
   override def execute() = rdd
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 84bdde3..32c5f26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.sql.execution
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent._
+import scala.concurrent.duration._
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
 
 @DeveloperApi
 sealed abstract class BuildSide
@@ -34,28 +37,19 @@ case object BuildLeft extends BuildSide
 @DeveloperApi
 case object BuildRight extends BuildSide
 
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class HashJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    buildSide: BuildSide,
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
-
-  override def outputPartitioning: Partitioning = left.outputPartitioning
+trait HashJoin {
+  val leftKeys: Seq[Expression]
+  val rightKeys: Seq[Expression]
+  val buildSide: BuildSide
+  val left: SparkPlan
+  val right: SparkPlan
 
-  override def requiredChildDistribution =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
-  val (buildPlan, streamedPlan) = buildSide match {
+  lazy val (buildPlan, streamedPlan) = buildSide match {
     case BuildLeft => (left, right)
     case BuildRight => (right, left)
   }
 
-  val (buildKeys, streamedKeys) = buildSide match {
+  lazy val (buildKeys, streamedKeys) = buildSide match {
     case BuildLeft => (leftKeys, rightKeys)
     case BuildRight => (rightKeys, leftKeys)
   }
@@ -66,73 +60,74 @@ case class HashJoin(
   @transient lazy val streamSideKeyGenerator =
     () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-  def execute() = {
-
-    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, 
streamIter) =>
-      // TODO: Use Spark's HashMap implementation.
-      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
-      var currentRow: Row = null
-
-      // Create a mapping of buildKeys -> rows
-      while (buildIter.hasNext) {
-        currentRow = buildIter.next()
-        val rowKey = buildSideKeyGenerator(currentRow)
-        if(!rowKey.anyNull) {
-          val existingMatchList = hashTable.get(rowKey)
-          val matchList = if (existingMatchList == null) {
-            val newMatchList = new ArrayBuffer[Row]()
-            hashTable.put(rowKey, newMatchList)
-            newMatchList
-          } else {
-            existingMatchList
-          }
-          matchList += currentRow.copy()
+  def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): 
Iterator[Row] = {
+    // TODO: Use Spark's HashMap implementation.
+
+    val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+    var currentRow: Row = null
+
+    // Create a mapping of buildKeys -> rows
+    while (buildIter.hasNext) {
+      currentRow = buildIter.next()
+      val rowKey = buildSideKeyGenerator(currentRow)
+      if(!rowKey.anyNull) {
+        val existingMatchList = hashTable.get(rowKey)
+        val matchList = if (existingMatchList == null) {
+          val newMatchList = new ArrayBuffer[Row]()
+          hashTable.put(rowKey, newMatchList)
+          newMatchList
+        } else {
+          existingMatchList
         }
+        matchList += currentRow.copy()
       }
+    }
 
-      new Iterator[Row] {
-        private[this] var currentStreamedRow: Row = _
-        private[this] var currentHashMatches: ArrayBuffer[Row] = _
-        private[this] var currentMatchPosition: Int = -1
+    new Iterator[Row] {
+      private[this] var currentStreamedRow: Row = _
+      private[this] var currentHashMatches: ArrayBuffer[Row] = _
+      private[this] var currentMatchPosition: Int = -1
 
-        // Mutable per row objects.
-        private[this] val joinRow = new JoinedRow
+      // Mutable per row objects.
+      private[this] val joinRow = new JoinedRow
 
-        private[this] val joinKeys = streamSideKeyGenerator()
+      private[this] val joinKeys = streamSideKeyGenerator()
 
-        override final def hasNext: Boolean =
-          (currentMatchPosition != -1 && currentMatchPosition < 
currentHashMatches.size) ||
+      override final def hasNext: Boolean =
+        (currentMatchPosition != -1 && currentMatchPosition < 
currentHashMatches.size) ||
           (streamIter.hasNext && fetchNext())
 
-        override final def next() = {
-          val ret = joinRow(currentStreamedRow, 
currentHashMatches(currentMatchPosition))
-          currentMatchPosition += 1
-          ret
+      override final def next() = {
+        val ret = buildSide match {
+          case BuildRight => joinRow(currentStreamedRow, 
currentHashMatches(currentMatchPosition))
+          case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), 
currentStreamedRow)
         }
+        currentMatchPosition += 1
+        ret
+      }
 
-        /**
-         * Searches the streamed iterator for the next row that has at least 
one match in hashtable.
-         *
-         * @return true if the search is successful, and false the streamed 
iterator runs out of
-         *         tuples.
-         */
-        private final def fetchNext(): Boolean = {
-          currentHashMatches = null
-          currentMatchPosition = -1
-
-          while (currentHashMatches == null && streamIter.hasNext) {
-            currentStreamedRow = streamIter.next()
-            if (!joinKeys(currentStreamedRow).anyNull) {
-              currentHashMatches = hashTable.get(joinKeys.currentValue)
-            }
+      /**
+       * Searches the streamed iterator for the next row that has at least one 
match in hashtable.
+       *
+       * @return true if the search is successful, and false if the streamed 
iterator runs out of
+       *         tuples.
+       */
+      private final def fetchNext(): Boolean = {
+        currentHashMatches = null
+        currentMatchPosition = -1
+
+        while (currentHashMatches == null && streamIter.hasNext) {
+          currentStreamedRow = streamIter.next()
+          if (!joinKeys(currentStreamedRow).anyNull) {
+            currentHashMatches = hashTable.get(joinKeys.currentValue)
           }
+        }
 
-          if (currentHashMatches == null) {
-            false
-          } else {
-            currentMatchPosition = 0
-            true
-          }
+        if (currentHashMatches == null) {
+          false
+        } else {
+          currentMatchPosition = 0
+          true
         }
       }
     }
@@ -141,32 +136,49 @@ case class HashJoin(
 
 /**
  * :: DeveloperApi ::
- * Build the right table's join keys into a HashSet, and iteratively go 
through the left
- * table, to find the if join keys are in the Hash set.
+ * Performs an inner hash join of two child relations by first shuffling the 
data using the join
+ * keys.
  */
 @DeveloperApi
-case class LeftSemiJoinHash(
+case class ShuffledHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
+    right: SparkPlan) extends BinaryNode with HashJoin {
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  val (buildPlan, streamedPlan) = (right, left)
-  val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+  def execute() = {
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) {
+      (buildIter, streamIter) => joinIterators(buildIter, streamIter)
+    }
+  }
+}
 
-  def output = left.output
+/**
+ * :: DeveloperApi ::
+ * Build the right table's join keys into a HashSet, and iteratively go 
through the left
+ * table, to find the if join keys are in the Hash set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with HashJoin {
 
-  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, 
buildPlan.output)
-  @transient lazy val streamSideKeyGenerator =
-    () => new MutableProjection(streamedKeys, streamedPlan.output)
+  val buildSide = BuildRight
 
-  def execute() = {
+  override def requiredChildDistribution =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  override def output = left.output
 
+  def execute() = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, 
streamIter) =>
       val hashSet = new java.util.HashSet[Row]()
       var currentRow: Row = null
@@ -191,6 +203,43 @@ case class LeftSemiJoinHash(
   }
 }
 
+
+/**
+ * :: DeveloperApi ::
+ * Performs an inner hash join of two child relations.  When the output RDD of 
this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the 
values for the
+ * broadcasted relation.  This data is then placed in a Spark broadcast 
variable.  The streamed
+ * relation is not shuffled.
+ */
+@DeveloperApi
+case class BroadcastHashJoin(
+     leftKeys: Seq[Expression],
+     rightKeys: Seq[Expression],
+     buildSide: BuildSide,
+     left: SparkPlan,
+     right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode 
with HashJoin {
+
+  override def otherCopyArgs = sqlContext :: Nil
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution =
+    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+
+  @transient
+  lazy val broadcastFuture = future {
+    sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
+  }
+
+  def execute() = {
+    val broadcastRelation = Await.result(broadcastFuture, 5.minute)
+
+    streamedPlan.execute().mapPartitions { streamedIter =>
+      joinIterators(broadcastRelation.value.iterator, streamedIter)
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * Using BroadcastNestedLoopJoin to calculate left semi join result when 
there's no join keys
@@ -220,7 +269,6 @@ case class LeftSemiJoinBNL(
         .map(c => BindReferences.bindReference(c, left.output ++ right.output))
         .getOrElse(Literal(true)))
 
-
   def execute() = {
     val broadcastedRelation =
       
sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
@@ -284,7 +332,6 @@ case class BroadcastNestedLoopJoin(
         .map(c => BindReferences.bindReference(c, left.output ++ right.output))
         .getOrElse(Literal(true)))
 
-
   def execute() = {
     val broadcastedRelation =
       
sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 96c131a..9c4771d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -44,8 +44,9 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
  * @param path The path to the Parquet file.
  */
 private[sql] case class ParquetRelation(
-    val path: String,
-    @transient val conf: Option[Configuration] = None) extends LeafNode with 
MultiInstanceRelation {
+    path: String,
+    @transient conf: Option[Configuration] = None) extends LeafNode with 
MultiInstanceRelation {
+
   self: Product =>
 
   /** Schema derived from ParquetFile */

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index fb599e1..e4a64a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.test._
 
 /* Implicits */
@@ -149,102 +148,4 @@ class DslQuerySuite extends QueryTest {
   test("zero count") {
     assert(emptyTableData.count() === 0)
   }
-
-  test("inner join where, one match per row") {
-    checkAnswer(
-      upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
-      Seq(
-        (1, "A", 1, "a"),
-        (2, "B", 2, "b"),
-        (3, "C", 3, "c"),
-        (4, "D", 4, "d")
-      ))
-  }
-
-  test("inner join ON, one match per row") {
-    checkAnswer(
-      upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
-      Seq(
-        (1, "A", 1, "a"),
-        (2, "B", 2, "b"),
-        (3, "C", 3, "c"),
-        (4, "D", 4, "d")
-      ))
-  }
-
-  test("inner join, where, multiple matches") {
-    val x = testData2.where('a === 1).as('x)
-    val y = testData2.where('a === 1).as('y)
-    checkAnswer(
-      x.join(y).where("x.a".attr === "y.a".attr),
-      (1,1,1,1) ::
-      (1,1,1,2) ::
-      (1,2,1,1) ::
-      (1,2,1,2) :: Nil
-    )
-  }
-
-  test("inner join, no matches") {
-    val x = testData2.where('a === 1).as('x)
-    val y = testData2.where('a === 2).as('y)
-    checkAnswer(
-      x.join(y).where("x.a".attr === "y.a".attr),
-      Nil)
-  }
-
-  test("big inner join, 4 matches per row") {
-    val bigData = 
testData.unionAll(testData).unionAll(testData).unionAll(testData)
-    val bigDataX = bigData.as('x)
-    val bigDataY = bigData.as('y)
-
-    checkAnswer(
-      bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
-      testData.flatMap(
-        row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
-  }
-
-  test("cartisian product join") {
-    checkAnswer(
-      testData3.join(testData3),
-      (1, null, 1, null) ::
-      (1, null, 2, 2) ::
-      (2, 2, 1, null) ::
-      (2, 2, 2, 2) :: Nil)
-  }
-
-  test("left outer join") {
-    checkAnswer(
-      upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
-      (1, "A", 1, "a") ::
-      (2, "B", 2, "b") ::
-      (3, "C", 3, "c") ::
-      (4, "D", 4, "d") ::
-      (5, "E", null, null) ::
-      (6, "F", null, null) :: Nil)
-  }
-
-  test("right outer join") {
-    checkAnswer(
-      lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
-      (1, "a", 1, "A") ::
-      (2, "b", 2, "B") ::
-      (3, "c", 3, "C") ::
-      (4, "d", 4, "D") ::
-      (null, null, 5, "E") ::
-      (null, null, 6, "F") :: Nil)
-  }
-
-  test("full outer join") {
-    val left = upperCaseData.where('N <= 4).as('left)
-    val right = upperCaseData.where('N >= 3).as('right)
-
-    checkAnswer(
-      left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
-      (1, "A", null, null) ::
-      (2, "B", null, null) ::
-      (3, "C", 3, "C") ::
-      (4, "D", 4, "D") ::
-      (null, null, 5, "E") ::
-      (null, null, 6, "F") :: Nil)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
new file mode 100644
index 0000000..3d7d5ee
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext._
+
+/**
+ * Tests for join planning and join results. Checks that equi-joins are planned as hash joins
+ * (broadcast hash joins when a table is named in `spark.sql.join.broadcastTables`), and that
+ * inner, outer and cartesian-product joins over the [[TestData]] tables return expected rows.
+ */
+class JoinSuite extends QueryTest {
+
+  // Ensures tables are loaded.
+  TestData
+
+  test("equi-join is hash-join") {
+    val x = testData2.as('x)
+    val y = testData2.as('y)
+    val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
+    val planned = planner.HashJoin(join)
+    assert(planned.size === 1)
+  }
+
+  test("plans broadcast hash join, given hints") {
+
+    /**
+     * Runs `SELECT * FROM leftTable JOIN rightTable ON key = a` with the broadcast hint set to
+     * the table on `buildSide`, and asserts that exactly one [[BroadcastHashJoin]] with that
+     * build side is planned and that the join returns the expected rows. The previous hint is
+     * restored afterwards so it does not leak into other tests sharing [[TestSQLContext]].
+     */
+    def mkTest(buildSide: BuildSide, leftTable: String, rightTable: String) = {
+      val broadcastTablesKey = "spark.sql.join.broadcastTables"
+      // Saved so the hint can be restored; "" is assumed to mean "no hint" (TODO confirm).
+      val previousHint = TestSQLContext.get(broadcastTablesKey, "")
+      val broadcastTable = if (buildSide == BuildRight) rightTable else leftTable
+      TestSQLContext.set(broadcastTablesKey, broadcastTable)
+      try {
+        val rdd = sql(s"""SELECT * FROM $leftTable JOIN $rightTable ON key = a""")
+        // Using `sparkPlan` because for relevant patterns in HashJoin to be
+        // matched, other strategies need to be applied.
+        val physical = rdd.queryExecution.sparkPlan
+        val bhj = physical.collect { case j: BroadcastHashJoin if j.buildSide == buildSide => j }
+
+        assert(bhj.size === 1, "planner does not pick up hint to generate broadcast hash join")
+        checkAnswer(
+          rdd,
+          Seq(
+            (1, "1", 1, 1),
+            (1, "1", 1, 2),
+            (2, "2", 2, 1),
+            (2, "2", 2, 2),
+            (3, "3", 3, 1),
+            (3, "3", 3, 2)
+          ))
+      } finally {
+        TestSQLContext.set(broadcastTablesKey, previousHint)
+      }
+    }
+
+    mkTest(BuildRight, "testData", "testData2")
+    mkTest(BuildLeft, "testData", "testData2")
+  }
+
+  test("multiple-key equi-join is hash-join") {
+    val x = testData2.as('x)
+    val y = testData2.as('y)
+    val join = x.join(y, Inner,
+      Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
+    val planned = planner.HashJoin(join)
+    assert(planned.size === 1)
+  }
+
+  test("inner join where, one match per row") {
+    checkAnswer(
+      upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
+      Seq(
+        (1, "A", 1, "a"),
+        (2, "B", 2, "b"),
+        (3, "C", 3, "c"),
+        (4, "D", 4, "d")
+      ))
+  }
+
+  test("inner join ON, one match per row") {
+    checkAnswer(
+      upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
+      Seq(
+        (1, "A", 1, "a"),
+        (2, "B", 2, "b"),
+        (3, "C", 3, "c"),
+        (4, "D", 4, "d")
+      ))
+  }
+
+  test("inner join, where, multiple matches") {
+    val x = testData2.where('a === 1).as('x)
+    val y = testData2.where('a === 1).as('y)
+    checkAnswer(
+      x.join(y).where("x.a".attr === "y.a".attr),
+      (1,1,1,1) ::
+      (1,1,1,2) ::
+      (1,2,1,1) ::
+      (1,2,1,2) :: Nil
+    )
+  }
+
+  test("inner join, no matches") {
+    val x = testData2.where('a === 1).as('x)
+    val y = testData2.where('a === 2).as('y)
+    checkAnswer(
+      x.join(y).where("x.a".attr === "y.a".attr),
+      Nil)
+  }
+
+  test("big inner join, 4 matches per row") {
+    val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
+    val bigDataX = bigData.as('x)
+    val bigDataY = bigData.as('y)
+
+    checkAnswer(
+      bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
+      testData.flatMap(
+        row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+  }
+
+  test("cartisian product join") {
+    checkAnswer(
+      testData3.join(testData3),
+      (1, null, 1, null) ::
+      (1, null, 2, 2) ::
+      (2, 2, 1, null) ::
+      (2, 2, 2, 2) :: Nil)
+  }
+
+  test("left outer join") {
+    checkAnswer(
+      upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
+      (1, "A", 1, "a") ::
+      (2, "B", 2, "b") ::
+      (3, "C", 3, "c") ::
+      (4, "D", 4, "d") ::
+      (5, "E", null, null) ::
+      (6, "F", null, null) :: Nil)
+  }
+
+  test("right outer join") {
+    checkAnswer(
+      lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
+      (1, "a", 1, "A") ::
+      (2, "b", 2, "B") ::
+      (3, "c", 3, "C") ::
+      (4, "d", 4, "D") ::
+      (null, null, 5, "E") ::
+      (null, null, 6, "F") :: Nil)
+  }
+
+  test("full outer join") {
+    val left = upperCaseData.where('N <= 4).as('left)
+    val right = upperCaseData.where('N >= 3).as('right)
+
+    checkAnswer(
+      left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+      (1, "A", null, null) ::
+      (2, "B", null, null) ::
+      (3, "C", 3, "C") ::
+      (4, "D", 4, "D") ::
+      (null, null, 5, "E") ::
+      (null, null, 6, "F") :: Nil)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index ef84ead..8e1e197 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -35,7 +35,7 @@ class QueryTest extends PlanTest {
       case singleItem => Seq(Seq(singleItem))
     }
 
-    val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty
+    val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
     def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
     val sparkAnswer = try rdd.collect().toSeq catch {
       case e: Exception =>
@@ -48,7 +48,7 @@ class QueryTest extends PlanTest {
           """.stripMargin)
     }
 
-    if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+    if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
       fail(s"""
         |Results do not match for query:
         |${rdd.logicalPlan}

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df6b118..215618e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -57,21 +57,4 @@ class PlannerSuite extends FunSuite {
     val planned = PartialAggregation(query)
     assert(planned.isEmpty)
   }
-
-  test("equi-join is hash-join") {
-    val x = testData2.as('x)
-    val y = testData2.as('y)
-    val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
-    val planned = planner.HashJoin(join)
-    assert(planned.size === 1)
-  }
-
-  test("multiple-key equi-join is hash-join") {
-    val x = testData2.as('x)
-    val y = testData2.as('y)
-    val join = x.join(y, Inner,
-      Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
-    val planned = planner.HashJoin(join)
-    assert(planned.size === 1)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 7695242..7aedfcd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -258,7 +258,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
         struct.zip(fields).map {
           case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
         }.mkString("{", ",", "}")
-      case (seq: Seq[_], ArrayType(typ))=>
+      case (seq: Seq[_], ArrayType(typ)) =>
         seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
       case (map: Map[_,_], MapType(kType, vType)) =>
         map.map {

http://git-wip-us.apache.org/repos/asf/spark/blob/9d824fed/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index faa30c9..90eacf4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -34,9 +34,8 @@ import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.execution.SparkLogicalPlan
-import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable}
-import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
+import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.sql.hive.execution.HiveTableScan
 
 /* Implicit conversions */
 import scala.collection.JavaConversions._
@@ -259,8 +258,6 @@ private[hive] case class MetastoreRelation
     new Partition(hiveQlTable, p)
   }
 
-  override def isPartitioned = hiveQlTable.isPartitioned
-
   val tableDesc = new TableDesc(
     Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]],
     hiveQlTable.getInputFormatClass,

Reply via email to