Repository: spark
Updated Branches:
  refs/heads/master 198d181df -> 21825529e


[SPARK-9247] [SQL] Use BytesToBytesMap for broadcast join

This PR introduces BytesToBytesMap to UnsafeHashedRelation and uses it on the 
executors for better performance.

It serializes all the keys and values from the Java HashMap and puts them into 
a BytesToBytesMap while deserializing. All the values for the same key are 
stored contiguously for better memory locality.

This PR also addresses the comments on #7480 and does some cleanup.

Author: Davies Liu <[email protected]>

Closes #7592 from davies/unsafe_map2 and squashes the following commits:

42c578a [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
unsafe_map2
fd09528 [Davies Liu] remove thread local cache and update docs
1c5ad8d [Davies Liu] fix test
5eb1b5a [Davies Liu] address comments in #7480
46f1f22 [Davies Liu] fix style
fc221e0 [Davies Liu] use BytesToBytesMap for broadcast join


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21825529
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21825529
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21825529

Branch: refs/heads/master
Commit: 21825529eae66293ec5d8638911303fa54944dd5
Parents: 198d181
Author: Davies Liu <[email protected]>
Authored: Tue Jul 28 15:56:19 2015 -0700
Committer: Davies Liu <[email protected]>
Committed: Tue Jul 28 15:56:19 2015 -0700

----------------------------------------------------------------------
 .../sql/execution/joins/BroadcastHashJoin.scala |   2 +-
 .../joins/BroadcastHashOuterJoin.scala          |   2 +-
 .../joins/BroadcastLeftSemiJoinHash.scala       |   6 +-
 .../joins/BroadcastNestedLoopJoin.scala         |  36 ++--
 .../spark/sql/execution/joins/HashJoin.scala    |  35 ++--
 .../sql/execution/joins/HashOuterJoin.scala     |  34 ++--
 .../sql/execution/joins/HashSemiJoin.scala      |  14 +-
 .../sql/execution/joins/HashedRelation.scala    | 166 ++++++++++++++-----
 .../sql/execution/joins/LeftSemiJoinHash.scala  |   2 +-
 .../sql/execution/joins/ShuffledHashJoin.scala  |   2 +-
 .../execution/joins/ShuffledHashOuterJoin.scala |   8 +-
 .../execution/joins/HashedRelationSuite.scala   |  28 ++--
 12 files changed, 214 insertions(+), 121 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index abaa4a6..624efc1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert 
data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    val hashed = buildHashRelation(input.iterator)
+    val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, 
input.size)
     sparkContext.broadcast(hashed)
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index c9d1a88..77e7fe7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin(
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert 
data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    val hashed = buildHashRelation(input.iterator)
+    val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
     sparkContext.broadcast(hashed)
   }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index f71c0ce..a605939 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash(
     condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val buildIter = right.execute().map(_.copy()).collect().toIterator
+    val input = right.execute().map(_.copy()).collect()
 
     if (condition.isEmpty) {
-      val hashSet = buildKeyHashSet(buildIter)
+      val hashSet = buildKeyHashSet(input.toIterator)
       val broadcastedRelation = sparkContext.broadcast(hashSet)
 
       left.execute().mapPartitions { streamIter =>
         hashSemiJoin(streamIter, broadcastedRelation.value)
       }
     } else {
-      val hashRelation = buildHashRelation(buildIter)
+      val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, 
input.size)
       val broadcastedRelation = sparkContext.broadcast(hashRelation)
 
       left.execute().mapPartitions { streamIter =>

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 7006369..83b726a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin(
   override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || 
right.outputsUnsafeRows
   override def canProcessUnsafeRows: Boolean = true
 
-  @transient private[this] lazy val resultProjection: Projection = {
+  @transient private[this] lazy val resultProjection: InternalRow => 
InternalRow = {
     if (outputsUnsafeRows) {
       UnsafeProjection.create(schema)
     } else {
-      new Projection {
-        override def apply(r: InternalRow): InternalRow = r
-      }
+      identity[InternalRow]
     }
   }
 
@@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin(
         var streamRowMatched = false
 
         while (i < broadcastedRelation.value.size) {
-          // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
           buildSide match {
             case BuildRight if boundCondition(joinedRow(streamedRow, 
broadcastedRow)) =>
@@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin(
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
       var i = 0
       val rel = broadcastedRelation.value
-      while (i < rel.length) {
-        if (!allIncludedBroadcastTuples.contains(i)) {
-          (joinType, buildSide) match {
-            case (RightOuter | FullOuter, BuildRight) =>
-              buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
-            case (LeftOuter | FullOuter, BuildLeft) =>
-              buf += resultProjection(new JoinedRow(rel(i), rightNulls))
-            case _ =>
+      (joinType, buildSide) match {
+        case (RightOuter | FullOuter, BuildRight) =>
+          val joinedRow = new JoinedRow
+          joinedRow.withLeft(leftNulls)
+          while (i < rel.length) {
+            if (!allIncludedBroadcastTuples.contains(i)) {
+              buf += resultProjection(joinedRow.withRight(rel(i))).copy()
+            }
+            i += 1
           }
-        }
-        i += 1
+        case (LeftOuter | FullOuter, BuildLeft) =>
+          val joinedRow = new JoinedRow
+          joinedRow.withRight(rightNulls)
+          while (i < rel.length) {
+            if (!allIncludedBroadcastTuples.contains(i)) {
+              buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
+            }
+            i += 1
+          }
+        case _ =>
       }
       buf.toSeq
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 46ab5b0..6b3d165 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.util.collection.CompactBuffer
 
 
 trait HashJoin {
@@ -44,16 +43,24 @@ trait HashJoin {
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
-  protected[this] def supportUnsafe: Boolean = {
+  protected[this] def isUnsafeMode: Boolean = {
     (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
       && UnsafeProjection.canSupport(self.schema))
   }
 
-  override def outputsUnsafeRows: Boolean = supportUnsafe
-  override def canProcessUnsafeRows: Boolean = supportUnsafe
+  override def outputsUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+  @transient protected lazy val buildSideKeyGenerator: Projection =
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildPlan.output)
+    } else {
+      newMutableProjection(buildKeys, buildPlan.output)()
+    }
 
   @transient protected lazy val streamSideKeyGenerator: Projection =
-    if (supportUnsafe) {
+    if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
       newMutableProjection(streamedKeys, streamedPlan.output)()
@@ -65,18 +72,16 @@ trait HashJoin {
   {
     new Iterator[InternalRow] {
       private[this] var currentStreamedRow: InternalRow = _
-      private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
+      private[this] var currentHashMatches: Seq[InternalRow] = _
       private[this] var currentMatchPosition: Int = -1
 
       // Mutable per row objects.
       private[this] val joinRow = new JoinedRow
-      private[this] val resultProjection: Projection = {
-        if (supportUnsafe) {
+      private[this] val resultProjection: (InternalRow) => InternalRow = {
+        if (isUnsafeMode) {
           UnsafeProjection.create(self.schema)
         } else {
-          new Projection {
-            override def apply(r: InternalRow): InternalRow = r
-          }
+          identity[InternalRow]
         }
       }
 
@@ -122,12 +127,4 @@ trait HashJoin {
       }
     }
   }
-
-  protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): 
HashedRelation = {
-    if (supportUnsafe) {
-      UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
-    } else {
-      HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6bf2f82..7e671e7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -75,30 +75,36 @@ trait HashOuterJoin {
         s"HashOuterJoin should not take $x as the JoinType")
   }
 
-  protected[this] def supportUnsafe: Boolean = {
+  protected[this] def isUnsafeMode: Boolean = {
     (self.codegenEnabled && joinType != FullOuter
       && UnsafeProjection.canSupport(buildKeys)
       && UnsafeProjection.canSupport(self.schema))
   }
 
-  override def outputsUnsafeRows: Boolean = supportUnsafe
-  override def canProcessUnsafeRows: Boolean = supportUnsafe
+  override def outputsUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessSafeRows: Boolean = !isUnsafeMode
 
-  protected[this] def streamedKeyGenerator(): Projection = {
-    if (supportUnsafe) {
+  @transient protected lazy val buildKeyGenerator: Projection =
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildPlan.output)
+    } else {
+      newMutableProjection(buildKeys, buildPlan.output)()
+    }
+
+  @transient protected[this] lazy val streamedKeyGenerator: Projection = {
+    if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
       newProjection(streamedKeys, streamedPlan.output)
     }
   }
 
-  @transient private[this] lazy val resultProjection: Projection = {
-    if (supportUnsafe) {
+  @transient private[this] lazy val resultProjection: InternalRow => 
InternalRow = {
+    if (isUnsafeMode) {
       UnsafeProjection.create(self.schema)
     } else {
-      new Projection {
-        override def apply(r: InternalRow): InternalRow = r
-      }
+      identity[InternalRow]
     }
   }
 
@@ -230,12 +236,4 @@ trait HashOuterJoin {
 
     hashTable
   }
-
-  protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): 
HashedRelation = {
-    if (supportUnsafe) {
-      UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
-    } else {
-      HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 7f49264..97fde8f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -35,11 +35,13 @@ trait HashSemiJoin {
   protected[this] def supportUnsafe: Boolean = {
     (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
       && UnsafeProjection.canSupport(rightKeys)
-      && UnsafeProjection.canSupport(left.schema))
+      && UnsafeProjection.canSupport(left.schema)
+      && UnsafeProjection.canSupport(right.schema))
   }
 
-  override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
+  override def outputsUnsafeRows: Boolean = supportUnsafe
   override def canProcessUnsafeRows: Boolean = supportUnsafe
+  override def canProcessSafeRows: Boolean = !supportUnsafe
 
   @transient protected lazy val leftKeyGenerator: Projection =
     if (supportUnsafe) {
@@ -87,14 +89,6 @@ trait HashSemiJoin {
     })
   }
 
-  protected def buildHashRelation(buildIter: Iterator[InternalRow]): 
HashedRelation = {
-    if (supportUnsafe) {
-      UnsafeHashedRelation(buildIter, rightKeys, right)
-    } else {
-      HashedRelation(buildIter, newProjection(rightKeys, right.output))
-    }
-  }
-
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation): Iterator[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8d5731a..9c058f1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.sql.execution.joins
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, 
TaskMemoryManager}
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -32,7 +35,7 @@ import org.apache.spark.util.collection.CompactBuffer
  * object.
  */
 private[joins] sealed trait HashedRelation {
-  def get(key: InternalRow): CompactBuffer[InternalRow]
+  def get(key: InternalRow): Seq[InternalRow]
 
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
@@ -59,9 +62,9 @@ private[joins] final class GeneralHashedRelation(
     private var hashTable: JavaHashMap[InternalRow, 
CompactBuffer[InternalRow]])
   extends HashedRelation with Externalizable {
 
-  def this() = this(null) // Needed for serialization
+  private def this() = this(null) // Needed for serialization
 
-  override def get(key: InternalRow): CompactBuffer[InternalRow] = 
hashTable.get(key)
+  override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)
 
   override def writeExternal(out: ObjectOutput): Unit = {
     writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -81,9 +84,9 @@ private[joins]
 final class UniqueKeyHashedRelation(private var hashTable: 
JavaHashMap[InternalRow, InternalRow])
   extends HashedRelation with Externalizable {
 
-  def this() = this(null) // Needed for serialization
+  private def this() = this(null) // Needed for serialization
 
-  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+  override def get(key: InternalRow): Seq[InternalRow] = {
     val v = hashTable.get(key)
     if (v eq null) null else CompactBuffer(v)
   }
@@ -109,6 +112,10 @@ private[joins] object HashedRelation {
       keyGenerator: Projection,
       sizeEstimate: Int = 64): HashedRelation = {
 
+    if (keyGenerator.isInstanceOf[UnsafeProjection]) {
+      return UnsafeHashedRelation(input, 
keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+    }
+
     // TODO: Use Spark's HashMap implementation.
     val hashTable = new JavaHashMap[InternalRow, 
CompactBuffer[InternalRow]](sizeEstimate)
     var currentRow: InternalRow = null
@@ -149,31 +156,133 @@ private[joins] object HashedRelation {
   }
 }
 
-
 /**
- * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that 
maps the key into a
- * sequence of values.
+ * A HashedRelation for UnsafeRow, which is backed by HashMap or 
BytesToBytesMap that maps the key
+ * into a sequence of values.
+ *
+ * When it's created, it uses HashMap. After it's serialized and deserialized, 
it switch to use
+ * BytesToBytesMap for better memory performance (multiple values for the same 
are stored as a
+ * continuous byte array.
  *
- * TODO(davies): use BytesToBytesMap
+ * It's serialized in the following format:
+ *  [number of keys]
+ *  [size of key] [size of all values in bytes] [key bytes] [bytes for all 
values]
+ *  ...
+ *
+ * All the values are serialized as following:
+ *   [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
+ *   ...
  */
 private[joins] final class UnsafeHashedRelation(
     private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
   extends HashedRelation with Externalizable {
 
-  def this() = this(null)  // Needed for serialization
+  private[joins] def this() = this(null)  // Needed for serialization
+
+  // Use BytesToBytesMap in executor for better performance (it's created when 
deserialization)
+  @transient private[this] var binaryMap: BytesToBytesMap = _
 
-  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+  override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
-    // Thanks to type eraser
-    hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+
+    if (binaryMap != null) {
+      // Used in Broadcast join
+      val loc = binaryMap.lookup(unsafeKey.getBaseObject, 
unsafeKey.getBaseOffset,
+        unsafeKey.getSizeInBytes)
+      if (loc.isDefined) {
+        val buffer = CompactBuffer[UnsafeRow]()
+
+        val base = loc.getValueAddress.getBaseObject
+        var offset = loc.getValueAddress.getBaseOffset
+        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+        while (offset < last) {
+          val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
+          val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
+          offset += 8
+
+          val row = new UnsafeRow
+          row.pointTo(base, offset, numFields, sizeInBytes)
+          buffer += row
+          offset += sizeInBytes
+        }
+        buffer
+      } else {
+        null
+      }
+
+    } else {
+      // Use the JavaHashMap in Local mode or ShuffleHashJoin
+      hashTable.get(unsafeKey)
+    }
   }
 
   override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+    out.writeInt(hashTable.size())
+
+    val iter = hashTable.entrySet().iterator()
+    while (iter.hasNext) {
+      val entry = iter.next()
+      val key = entry.getKey
+      val values = entry.getValue
+
+      // write all the values as single byte array
+      var totalSize = 0L
+      var i = 0
+      while (i < values.size) {
+        totalSize += values(i).getSizeInBytes + 4 + 4
+        i += 1
+      }
+      assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+      // [key size] [values size] [key bytes] [values bytes]
+      out.writeInt(key.getSizeInBytes)
+      out.writeInt(totalSize.toInt)
+      out.write(key.getBytes)
+      i = 0
+      while (i < values.size) {
+        // [num of fields] [num of bytes] [row bytes]
+        // write the integer in native order, so they can be read by 
UNSAFE.getInt()
+        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+          out.writeInt(values(i).numFields())
+          out.writeInt(values(i).getSizeInBytes)
+        } else {
+          out.writeInt(Integer.reverseBytes(values(i).numFields()))
+          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+        }
+        out.write(values(i).getBytes)
+        i += 1
+      }
+    }
   }
 
   override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+    val nKeys = in.readInt()
+    // This is used in Broadcast, shared by multiple tasks, so we use on-heap 
memory
+    val memoryManager = new TaskMemoryManager(new 
ExecutorMemoryManager(MemoryAllocator.HEAP))
+    binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash 
collision
+
+    var i = 0
+    var keyBuffer = new Array[Byte](1024)
+    var valuesBuffer = new Array[Byte](1024)
+    while (i < nKeys) {
+      val keySize = in.readInt()
+      val valuesSize = in.readInt()
+      if (keySize > keyBuffer.size) {
+        keyBuffer = new Array[Byte](keySize)
+      }
+      in.readFully(keyBuffer, 0, keySize)
+      if (valuesSize > valuesBuffer.size) {
+        valuesBuffer = new Array[Byte](valuesSize)
+      }
+      in.readFully(valuesBuffer, 0, valuesSize)
+
+      // put it into binary map
+      val loc = binaryMap.lookup(keyBuffer, 
PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
+      assert(!loc.isDefined, "Duplicated key found!")
+      loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+        valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+      i += 1
+    }
   }
 }
 
@@ -181,33 +290,14 @@ private[joins] object UnsafeHashedRelation {
 
   def apply(
       input: Iterator[InternalRow],
-      buildKeys: Seq[Expression],
-      buildPlan: SparkPlan,
-      sizeEstimate: Int = 64): HashedRelation = {
-    val boundedKeys = buildKeys.map(BindReferences.bindReference(_, 
buildPlan.output))
-    apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
-  }
-
-  // Used for tests
-  def apply(
-      input: Iterator[InternalRow],
-      buildKeys: Seq[Expression],
-      rowSchema: StructType,
+      keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {
 
-    // TODO: Use BytesToBytesMap.
     val hashTable = new JavaHashMap[UnsafeRow, 
CompactBuffer[UnsafeRow]](sizeEstimate)
-    val toUnsafe = UnsafeProjection.create(rowSchema)
-    val keyGenerator = UnsafeProjection.create(buildKeys)
 
     // Create a mapping of buildKeys -> rows
     while (input.hasNext) {
-      val currentRow = input.next()
-      val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
-        currentRow.asInstanceOf[UnsafeRow]
-      } else {
-        toUnsafe(currentRow)
-      }
+      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
       val rowKey = keyGenerator(unsafeRow)
       if (!rowKey.anyNull) {
         val existingMatchList = hashTable.get(rowKey)

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 874712a..26a6641 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -46,7 +46,7 @@ case class LeftSemiJoinHash(
         val hashSet = buildKeyHashSet(buildIter)
         hashSemiJoin(streamIter, hashSet)
       } else {
-        val hashRelation = buildHashRelation(buildIter)
+        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
         hashSemiJoin(streamIter, hashRelation)
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 948d0cc..5439e10 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,7 +45,7 @@ case class ShuffledHashJoin(
 
   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, 
streamIter) =>
-      val hashed = buildHashRelation(buildIter)
+      val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
       hashJoin(streamIter, hashed)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index f54f1ed..d29b593 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -50,8 +50,8 @@ case class ShuffledHashOuterJoin(
       // TODO this probably can be replaced by external sort (sort merged 
join?)
       joinType match {
         case LeftOuter =>
-          val hashed = buildHashRelation(rightIter)
-          val keyGenerator = streamedKeyGenerator()
+          val hashed = HashedRelation(rightIter, buildKeyGenerator)
+          val keyGenerator = streamedKeyGenerator
           leftIter.flatMap( currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withLeft(currentRow)
@@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin(
           })
 
         case RightOuter =>
-          val hashed = buildHashRelation(leftIter)
-          val keyGenerator = streamedKeyGenerator()
+          val hashed = HashedRelation(leftIter, buildKeyGenerator)
+          val keyGenerator = streamedKeyGenerator
           rightIter.flatMap ( currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withRight(currentRow)

http://git-wip-us.apache.org/repos/asf/spark/blob/21825529/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9dd2220..8b1a9b2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.execution.joins
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, 
ObjectInputStream, ObjectOutputStream}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite {
   }
 
   test("UnsafeHashedRelation") {
+    val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), 
InternalRow(2))
+    val toUnsafe = UnsafeProjection.create(schema)
+    val unsafeData = data.map(toUnsafe(_).copy()).toArray
+
     val buildKey = Seq(BoundReference(0, IntegerType, false))
-    val schema = StructType(StructField("a", IntegerType, true) :: Nil)
-    val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
+    val keyGenerator = UnsafeProjection.create(buildKey)
+    val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
     assert(hashed.isInstanceOf[UnsafeHashedRelation])
 
-    val toUnsafeKey = UnsafeProjection.create(schema)
-    val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
     assert(hashed.get(unsafeData(0)) === 
CompactBuffer[InternalRow](unsafeData(0)))
     assert(hashed.get(unsafeData(1)) === 
CompactBuffer[InternalRow](unsafeData(1)))
-    assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
+    assert(hashed.get(toUnsafe(InternalRow(10))) === null)
 
     val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
     data2 += unsafeData(2).copy()
     assert(hashed.get(unsafeData(2)) === data2)
 
-    val hashed2 = 
SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
-      .asInstanceOf[UnsafeHashedRelation]
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val hashed2 = new UnsafeHashedRelation()
+    hashed2.readExternal(in)
     assert(hashed2.get(unsafeData(0)) === 
CompactBuffer[InternalRow](unsafeData(0)))
     assert(hashed2.get(unsafeData(1)) === 
CompactBuffer[InternalRow](unsafeData(1)))
-    assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
+    assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
     assert(hashed2.get(unsafeData(2)) === data2)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to