Repository: spark
Updated Branches:
  refs/heads/branch-2.0 dc1562e97 -> d98dd72e7


[SPARK-1239] Improve fetching of map output statuses

The main issue we are trying to solve is memory bloat on the Driver when 
tasks request the map output statuses.  With a large number of tasks you 
either need a huge amount of memory on the Driver or you have to repartition 
to a smaller number of partitions.  This makes it really difficult to run with 
more than, say, 50000 tasks.

The main issues that cause the memory bloat are:
1) There is no flow control on sending the map output status responses.  We 
serialize the map output statuses and then hand them off to netty to send.  
netty sends asynchronously and can't send them fast enough to keep up with 
incoming requests, so we end up with lots of copies of the serialized map 
output statuses sitting in memory.  This causes huge bloat when you have tens 
of thousands of tasks and the map output statuses are in the tens of MB.
2) When the initial reduce tasks start up, they all request the map output 
statuses from the Driver.  These requests are handled by multiple threads in 
parallel, so even though we check for a cached version, at the start, before 
anything has been cached, many of the initial requests can all end up 
serializing the exact same map output statuses.
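
The race in item 2 is a check-then-act problem: several threads miss the cache at the same time and each one serializes. A minimal sketch of one way to collapse that storm into a single serialization, using a per-shuffle lock and a second cache check, is shown below. It is an illustration only; `SerializeOnceSketch`, `expensiveSerialize`, and `serializeCalls` are hypothetical names and not part of Spark.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for the driver-side cache; not Spark's actual class.
object SerializeOnceSketch {
  private val cache = new ConcurrentHashMap[Int, Array[Byte]]()
  private val locks = new ConcurrentHashMap[Int, AnyRef]()

  // Counts how many times the expensive step really ran.
  val serializeCalls = new AtomicInteger(0)

  // Stand-in for serializing the map output statuses of one shuffle.
  private def expensiveSerialize(shuffleId: Int): Array[Byte] = {
    serializeCalls.incrementAndGet()
    Thread.sleep(50) // widen the race window so the effect is visible
    Array.fill[Byte](16)(shuffleId.toByte)
  }

  // One lock object per shuffle id, created on first use.
  private def lockFor(shuffleId: Int): AnyRef = {
    val fresh = new Object
    val prev = locks.putIfAbsent(shuffleId, fresh)
    if (prev != null) prev else fresh
  }

  def getSerialized(shuffleId: Int): Array[Byte] = {
    val cached = cache.get(shuffleId)
    if (cached != null) cached
    else lockFor(shuffleId).synchronized {
      // Second check: another thread may have filled the cache while we waited.
      val again = cache.get(shuffleId)
      if (again != null) again
      else {
        val bytes = expensiveSerialize(shuffleId)
        cache.put(shuffleId, bytes)
        bytes
      }
    }
  }
}
```

Locking per shuffle id rather than globally means requests for different shuffles do not wait on each other. The sketch ignores epoch changes and cache invalidation.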

This patch does a couple of things:
- When the map output status size is over a threshold (default 512K), it 
uses broadcast to send the map statuses.  This means we no longer send a 
large serialized map output status in every reply, and thus we don't have 
issues with memory bloat.  The message sizes are now in the 300-400 byte range 
and the map output statuses are broadcast.  If it is under the threshold, it 
is sent as before, except that the message now contains the DIRECT indicator.
- Synchronize the incoming requests so that one thread caches the serialized 
output and broadcasts the map output statuses, which can then be used by 
everyone else.  This ensures we don't create multiple broadcast variables when 
we don't need to.  To make this possible I added a second thread pool, which 
the Dispatcher hands the requests to, so that those threads can block without 
blocking the main dispatcher threads (which would prevent things like 
heartbeats from coming through).
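
The DIRECT/BROADCAST framing from the first bullet can be sketched as follows. This is a simplified illustration under stated assumptions: an in-memory map (`store`) stands in for Spark's broadcast mechanism, statuses are plain strings, and `TaggedPayloadSketch`, `encode`, `decode`, and `minSizeForIndirect` are hypothetical names. The first byte of every message says whether the rest is the payload itself or only a small handle to it.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.io.{ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable

object TaggedPayloadSketch {
  val Direct: Byte = 0   // body is the gzipped payload itself
  val Indirect: Byte = 1 // body is a gzipped handle to the payload

  // Hypothetical stand-in for a broadcast store: handle id -> full message.
  private val store = mutable.Map.empty[Long, Array[Byte]]
  private var nextId = 0L

  // Writes one tag byte followed by the gzipped, Java-serialized object.
  private def gzipObject(tag: Byte, obj: AnyRef): Array[Byte] = {
    val out = new ByteArrayOutputStream
    out.write(tag.toInt)
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    try objOut.writeObject(obj) finally objOut.close()
    out.toByteArray
  }

  // Reads the object back, skipping the tag byte at offset 0.
  private def gunzipObject(bytes: Array[Byte]): AnyRef = {
    val in = new ObjectInputStream(new GZIPInputStream(
      new ByteArrayInputStream(bytes, 1, bytes.length - 1)))
    try in.readObject() finally in.close()
  }

  def encode(statuses: Array[String], minSizeForIndirect: Int): Array[Byte] = {
    val direct = gzipObject(Direct, statuses)
    if (direct.length < minSizeForIndirect) direct
    else {
      // Park the large message in the store and send only its handle.
      val id = synchronized { nextId += 1; store(nextId) = direct; nextId }
      gzipObject(Indirect, java.lang.Long.valueOf(id))
    }
  }

  def decode(bytes: Array[Byte]): Array[String] = {
    require(bytes.nonEmpty, "empty message")
    bytes(0) match {
      case Direct =>
        gunzipObject(bytes).asInstanceOf[Array[String]]
      case Indirect =>
        val id = gunzipObject(bytes).asInstanceOf[java.lang.Long].longValue
        val stored = synchronized { store(id) }
        // The stored message still carries its own Direct tag; skip it too.
        gunzipObject(stored).asInstanceOf[Array[String]]
      case other =>
        throw new IllegalArgumentException("Unexpected byte tag = " + other)
    }
  }
}
```

Because the receiver dispatches on the tag, the sender can pick either path per message without any change on the receiving side.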

Note that some of the design and code was contributed by mridulm.

## How was this patch tested?

Unit tests and a lot of manual testing.
Ran with akka and netty rpc. Ran with both dynamic allocation on and off.

One of the large jobs I used to test this was a join of 15TB of data.  It had 
200,000 map tasks and 20,000 reduce tasks.  Executors ranged from 200 to 2000. 
This job ran successfully with 5GB of memory on the driver with these changes. 
Without these changes I was using 20GB and only had 500 reduce tasks.  The job 
has 50MB of serialized map output statuses and took roughly the same amount of 
time for the executors to get the map output statuses as before.

Ran a variety of other jobs, from large wordcounts to small ones not using 
broadcasts.

Author: Thomas Graves <[email protected]>

Closes #12113 from tgravescs/SPARK-1239.

(cherry picked from commit cc95f1ed5fdf2566bcefe8d10116eee544cf9184)
Signed-off-by: Davies Liu <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d98dd72e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d98dd72e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d98dd72e

Branch: refs/heads/branch-2.0
Commit: d98dd72e7baeb59eacec4fefd66397513a607b2f
Parents: dc1562e
Author: Thomas Graves <[email protected]>
Authored: Fri May 6 19:31:26 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Fri May 6 19:31:35 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/MapOutputTracker.scala     | 250 +++++++++++++++----
 .../main/scala/org/apache/spark/SparkEnv.scala  |   6 +-
 .../apache/spark/MapOutputTrackerSuite.scala    |  99 +++++---
 .../spark/scheduler/DAGSchedulerSuite.scala     |   7 +-
 .../storage/BlockManagerReplicationSuite.scala  |   4 +-
 .../spark/storage/BlockManagerSuite.scala       |   4 +-
 .../streaming/ReceivedBlockHandlerSuite.scala   |   4 +-
 7 files changed, 290 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3a5caa3..6bd9502 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -18,13 +18,15 @@
 package org.apache.spark
 
 import java.io._
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, 
ThreadPoolExecutor}
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.reflect.ClassTag
+import scala.util.control.NonFatal
 
+import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, 
RpcEnv}
 import org.apache.spark.scheduler.MapStatus
@@ -37,31 +39,18 @@ private[spark] case class GetMapOutputStatuses(shuffleId: 
Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
+private[spark] case class GetMapOutputMessage(shuffleId: Int, context: 
RpcCallContext)
+
 /** RpcEndpoint class for MapOutputTrackerMaster */
 private[spark] class MapOutputTrackerMasterEndpoint(
     override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: 
SparkConf)
   extends RpcEndpoint with Logging {
-  val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + 
" to " + hostPort)
-      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
-      val serializedSize = mapOutputStatuses.length
-      if (serializedSize > maxRpcMessageSize) {
-
-        val msg = s"Map output statuses were $serializedSize bytes which " +
-          s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)."
-
-        /* For SPARK-1244 we'll opt for just logging an error and then sending 
it to the sender.
-         * A bigger refactoring (SPARK-1239) will ultimately remove this 
entire code path. */
-        val exception = new SparkException(msg)
-        logError(msg, exception)
-        context.sendFailure(exception)
-      } else {
-        context.reply(mapOutputStatuses)
-      }
+      val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, 
context))
 
     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerMasterEndpoint stopped!")
@@ -270,12 +259,17 @@ private[spark] abstract class MapOutputTracker(conf: 
SparkConf) extends Logging
 /**
  * MapOutputTracker for the driver.
  */
-private[spark] class MapOutputTrackerMaster(conf: SparkConf)
+private[spark] class MapOutputTrackerMaster(conf: SparkConf,
+    broadcastManager: BroadcastManager, isLocal: Boolean)
   extends MapOutputTracker(conf) {
 
   /** Cache a serialized version of the output statuses for each shuffle to 
send them out faster */
   private var cacheEpoch = epoch
 
+  // The size at which we use Broadcast to send the map output statuses to the 
executors
+  private val minSizeForBroadcast =
+    conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", 
"512k").toInt
+
   /** Whether to compute locality preferences for reduce tasks */
   private val shuffleLocalityEnabled = 
conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
 
@@ -296,10 +290,86 @@ private[spark] class MapOutputTrackerMaster(conf: 
SparkConf)
   protected val mapStatuses = new ConcurrentHashMap[Int, 
Array[MapStatus]]().asScala
   private val cachedSerializedStatuses = new ConcurrentHashMap[Int, 
Array[Byte]]().asScala
 
+  private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
+
+  // Kept in sync with cachedSerializedStatuses explicitly
+  // This is required so that the Broadcast variable remains in scope until we 
remove
+  // the shuffleId explicitly or implicitly.
+  private val cachedSerializedBroadcast = new HashMap[Int, 
Broadcast[Array[Byte]]]()
+
+  // This is to prevent multiple serializations of the same shuffle - which 
happens when
+  // there is a request storm when shuffle start.
+  private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
+
+  // requests for map output statuses
+  private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
+
+  // Thread pool used for handling map output status requests. This is a 
separate thread pool
+  // to ensure we don't block the normal dispatcher threads.
+  private val threadpool: ThreadPoolExecutor = {
+    val numThreads = 
conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
+    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, 
"map-output-dispatcher")
+    for (i <- 0 until numThreads) {
+      pool.execute(new MessageLoop)
+    }
+    pool
+  }
+
+  // Make sure that that we aren't going to exceed the max RPC message size by 
making sure
+  // we use broadcast to send large map output statuses.
+  if (minSizeForBroadcast > maxRpcMessageSize) {
+    val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast 
($minSizeForBroadcast bytes) must " +
+      s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent 
sending an rpc " +
+      "message that is to large."
+    logError(msg)
+    throw new IllegalArgumentException(msg)
+  }
+
+  def post(message: GetMapOutputMessage): Unit = {
+    mapOutputRequests.offer(message)
+  }
+
+  /** Message loop used for dispatching messages. */
+  private class MessageLoop extends Runnable {
+    override def run(): Unit = {
+      try {
+        while (true) {
+          try {
+            val data = mapOutputRequests.take()
+             if (data == PoisonPill) {
+              // Put PoisonPill back so that other MessageLoops can see it.
+              mapOutputRequests.offer(PoisonPill)
+              return
+            }
+            val context = data.context
+            val shuffleId = data.shuffleId
+            val hostPort = context.senderAddress.hostPort
+            logDebug("Handling request to send map output locations for 
shuffle " + shuffleId +
+              " to " + hostPort)
+            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
+            context.reply(mapOutputStatuses)
+          } catch {
+            case NonFatal(e) => logError(e.getMessage, e)
+          }
+        }
+      } catch {
+        case ie: InterruptedException => // exit
+      }
+    }
+  }
+
+  /** A poison endpoint that indicates MessageLoop should exit its message 
loop. */
+  private val PoisonPill = new GetMapOutputMessage(-99, null)
+
+  // Exposed for testing
+  private[spark] def getNumCachedSerializedBroadcast = 
cachedSerializedBroadcast.size
+
   def registerShuffle(shuffleId: Int, numMaps: Int) {
     if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " 
registered twice")
     }
+    // add in advance
+    shuffleIdLocks.putIfAbsent(shuffleId, new Object())
   }
 
   def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@@ -337,6 +407,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   override def unregisterShuffle(shuffleId: Int) {
     mapStatuses.remove(shuffleId)
     cachedSerializedStatuses.remove(shuffleId)
+    cachedSerializedBroadcast.remove(shuffleId).foreach(v => 
removeBroadcast(v))
+    shuffleIdLocks.remove(shuffleId)
   }
 
   /** Check if the given shuffle is being tracked */
@@ -428,40 +500,89 @@ private[spark] class MapOutputTrackerMaster(conf: 
SparkConf)
     }
   }
 
+  private def removeBroadcast(bcast: Broadcast[_]): Unit = {
+    if (null != bcast) {
+      broadcastManager.unbroadcast(bcast.id,
+        removeFromDriver = true, blocking = false)
+    }
+  }
+
+  private def clearCachedBroadcast(): Unit = {
+    for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
+    cachedSerializedBroadcast.clear()
+  }
+
   def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
     var statuses: Array[MapStatus] = null
+    var retBytes: Array[Byte] = null
     var epochGotten: Long = -1
-    epochLock.synchronized {
-      if (epoch > cacheEpoch) {
-        cachedSerializedStatuses.clear()
-        cacheEpoch = epoch
-      }
-      cachedSerializedStatuses.get(shuffleId) match {
-        case Some(bytes) =>
-          return bytes
-        case None =>
-          statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
-          epochGotten = epoch
+
+    // Check to see if we have a cached version, returns true if it does
+    // and has side effect of setting retBytes.  If not returns false
+    // with side effect of setting statuses
+    def checkCachedStatuses(): Boolean = {
+      epochLock.synchronized {
+        if (epoch > cacheEpoch) {
+          cachedSerializedStatuses.clear()
+          clearCachedBroadcast()
+          cacheEpoch = epoch
+        }
+        cachedSerializedStatuses.get(shuffleId) match {
+          case Some(bytes) =>
+            retBytes = bytes
+            true
+          case None =>
+            logDebug("cached status not found for : " + shuffleId)
+            statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
+            epochGotten = epoch
+            false
+        }
       }
     }
-    // If we got here, we failed to find the serialized locations in the 
cache, so we pulled
-    // out a snapshot of the locations as "statuses"; let's serialize and 
return that
-    val bytes = MapOutputTracker.serializeMapStatuses(statuses)
-    logInfo("Size of output statuses for shuffle %d is %d 
bytes".format(shuffleId, bytes.length))
-    // Add them into the table only if the epoch hasn't changed while we were 
working
-    epochLock.synchronized {
-      if (epoch == epochGotten) {
-        cachedSerializedStatuses(shuffleId) = bytes
+
+    if (checkCachedStatuses()) return retBytes
+    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
+    if (null == shuffleIdLock) {
+      val newLock = new Object()
+      // in general, this condition should be false - but good to be paranoid
+      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
+      shuffleIdLock = if (null != prevLock) prevLock else newLock
+    }
+    // synchronize so we only serialize/broadcast it once since multiple 
threads call
+    // in parallel
+    shuffleIdLock.synchronized {
+      // double check to make sure someone else didn't serialize and cache the 
same
+      // mapstatus while we were waiting on the synchronize
+      if (checkCachedStatuses()) return retBytes
+
+      // If we got here, we failed to find the serialized locations in the 
cache, so we pulled
+      // out a snapshot of the locations as "statuses"; let's serialize and 
return that
+      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, 
broadcastManager,
+        isLocal, minSizeForBroadcast)
+      logInfo("Size of output statuses for shuffle %d is %d 
bytes".format(shuffleId, bytes.length))
+      // Add them into the table only if the epoch hasn't changed while we 
were working
+      epochLock.synchronized {
+        if (epoch == epochGotten) {
+          cachedSerializedStatuses(shuffleId) = bytes
+          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
+        } else {
+          logInfo("Epoch changed, not caching!")
+          removeBroadcast(bcast)
+        }
       }
+      bytes
     }
-    bytes
   }
 
   override def stop() {
+    mapOutputRequests.offer(PoisonPill)
+    threadpool.shutdown()
     sendTracker(StopMapOutputTracker)
     mapStatuses.clear()
     trackerEndpoint = null
     cachedSerializedStatuses.clear()
+    clearCachedBroadcast()
+    shuffleIdLocks.clear()
   }
 }
 
@@ -477,12 +598,16 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
 private[spark] object MapOutputTracker extends Logging {
 
   val ENDPOINT_NAME = "MapOutputTracker"
+  private val DIRECT = 0
+  private val BROADCAST = 1
 
   // Serialize an array of map output locations into an efficient byte format 
so that we can send
   // it to reduce tasks. We do this by compressing the serialized bytes using 
GZIP. They will
   // generally be pretty compressible because many map outputs will be on the 
same hostname.
-  def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+  def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: 
BroadcastManager,
+      isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], 
Broadcast[Array[Byte]]) = {
     val out = new ByteArrayOutputStream
+    out.write(DIRECT)
     val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
     Utils.tryWithSafeFinally {
       // Since statuses can be modified in parallel, sync on it
@@ -492,16 +617,51 @@ private[spark] object MapOutputTracker extends Logging {
     } {
       objOut.close()
     }
-    out.toByteArray
+    val arr = out.toByteArray
+    if (arr.length >= minBroadcastSize) {
+      // Use broadcast instead.
+      // Important arr(0) is the tag == DIRECT, ignore that while 
deserializing !
+      val bcast = broadcastManager.newBroadcast(arr, isLocal)
+      // toByteArray creates copy, so we can reuse out
+      out.reset()
+      out.write(BROADCAST)
+      val oos = new ObjectOutputStream(new GZIPOutputStream(out))
+      oos.writeObject(bcast)
+      oos.close()
+      val outArr = out.toByteArray
+      logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size 
= " + arr.length)
+      (outArr, bcast)
+    } else {
+      (arr, null)
+    }
   }
 
   // Opposite of serializeMapStatuses.
   def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
-    val objIn = new ObjectInputStream(new GZIPInputStream(new 
ByteArrayInputStream(bytes)))
-    Utils.tryWithSafeFinally {
-      objIn.readObject().asInstanceOf[Array[MapStatus]]
-    } {
-      objIn.close()
+    assert (bytes.length > 0)
+
+    def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
+      val objIn = new ObjectInputStream(new GZIPInputStream(
+        new ByteArrayInputStream(arr, off, len)))
+      Utils.tryWithSafeFinally {
+        objIn.readObject()
+      } {
+        objIn.close()
+      }
+    }
+
+    bytes(0) match {
+      case DIRECT =>
+        deserializeObject(bytes, 1, bytes.length - 
1).asInstanceOf[Array[MapStatus]]
+      case BROADCAST =>
+        // deserialize the Broadcast, pull .value array out of it, and then 
deserialize that
+        val bcast = deserializeObject(bytes, 1, bytes.length - 1).
+          asInstanceOf[Broadcast[Array[Byte]]]
+        logInfo("Broadcast mapstatuses size = " + bytes.length +
+          ", actual size = " + bcast.value.length)
+        // Important - ignore the DIRECT tag ! Start from offset 1
+        deserializeObject(bcast.value, 1, bcast.value.length - 
1).asInstanceOf[Array[MapStatus]]
+      case _ => throw new IllegalArgumentException("Unexpected byte tag = " + 
bytes(0))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 27497e2..4bf8890 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -284,8 +284,10 @@ object SparkEnv extends Logging {
       }
     }
 
+    val broadcastManager = new BroadcastManager(isDriver, conf, 
securityManager)
+
     val mapOutputTracker = if (isDriver) {
-      new MapOutputTrackerMaster(conf)
+      new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
     } else {
       new MapOutputTrackerWorker(conf)
     }
@@ -325,8 +327,6 @@ object SparkEnv extends Logging {
       serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
       blockTransferService, securityManager, numUsableCores)
 
-    val broadcastManager = new BroadcastManager(isDriver, conf, 
securityManager)
-
     val metricsSystem = if (isDriver) {
       // Don't start metrics system right now for Driver.
       // We need to wait for the task scheduler to give us an app ID.

http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala 
b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ddf4876..c6aebc1 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.mockito.Matchers.{any, isA}
 import org.mockito.Mockito._
 
+import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
 import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
 import org.apache.spark.shuffle.FetchFailedException
@@ -30,6 +31,12 @@ import org.apache.spark.storage.{BlockManagerId, 
ShuffleBlockId}
 class MapOutputTrackerSuite extends SparkFunSuite {
   private val conf = new SparkConf
 
+  private def newTrackerMaster(sparkConf: SparkConf = conf) = {
+    val broadcastManager = new BroadcastManager(true, sparkConf,
+      new SecurityManager(sparkConf))
+    new MapOutputTrackerMaster(sparkConf, broadcastManager, true)
+  }
+
   def createRpcEnv(name: String, host: String = "localhost", port: Int = 0,
       securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = {
     RpcEnv.create(name, host, port, conf, securityManager)
@@ -37,7 +44,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
 
   test("master start and stop") {
     val rpcEnv = createRpcEnv("test")
-    val tracker = new MapOutputTrackerMaster(conf)
+    val tracker = newTrackerMaster()
     tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
     tracker.stop()
@@ -46,7 +53,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
 
   test("master register shuffle and fetch") {
     val rpcEnv = createRpcEnv("test")
-    val tracker = new MapOutputTrackerMaster(conf)
+    val tracker = newTrackerMaster()
     tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
     tracker.registerShuffle(10, 2)
@@ -62,13 +69,14 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 
0, 0), size1000))),
           (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 
1, 0), size10000))))
         .toSet)
+    assert(0 == tracker.getNumCachedSerializedBroadcast)
     tracker.stop()
     rpcEnv.shutdown()
   }
 
   test("master register and unregister shuffle") {
     val rpcEnv = createRpcEnv("test")
-    val tracker = new MapOutputTrackerMaster(conf)
+    val tracker = newTrackerMaster()
     tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
     tracker.registerShuffle(10, 2)
@@ -80,6 +88,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       Array(compressedSize10000, compressedSize1000)))
     assert(tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
+    assert(0 == tracker.getNumCachedSerializedBroadcast)
     tracker.unregisterShuffle(10)
     assert(!tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
@@ -90,7 +99,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
 
   test("master register shuffle and unregister map output and fetch") {
     val rpcEnv = createRpcEnv("test")
-    val tracker = new MapOutputTrackerMaster(conf)
+    val tracker = newTrackerMaster()
     tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
     tracker.registerShuffle(10, 2)
@@ -101,6 +110,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 
1000),
         Array(compressedSize10000, compressedSize1000, compressedSize1000)))
 
+    assert(0 == tracker.getNumCachedSerializedBroadcast)
     // As if we had two simultaneous fetch failures
     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -118,7 +128,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val hostname = "localhost"
     val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
 
-    val masterTracker = new MapOutputTrackerMaster(conf)
+    val masterTracker = newTrackerMaster()
     masterTracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
 
@@ -139,6 +149,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 
0, 0), size1000)))))
+    assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 
1000))
     masterTracker.incrementEpoch()
@@ -147,6 +158,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
 
     // failure should be cached
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 
0) }
+    assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
     masterTracker.stop()
     slaveTracker.stop()
@@ -158,8 +170,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val newConf = new SparkConf
     newConf.set("spark.rpc.message.maxSize", "1")
     newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+    newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "1048576")
 
-    val masterTracker = new MapOutputTrackerMaster(conf)
+    val masterTracker = newTrackerMaster(newConf)
     val rpcEnv = createRpcEnv("spark")
     val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, 
masterTracker, newConf)
     rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
@@ -172,45 +185,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val rpcCallContext = mock(classOf[RpcCallContext])
     when(rpcCallContext.senderAddress).thenReturn(senderAddress)
     masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10))
-    verify(rpcCallContext).reply(any())
-    verify(rpcCallContext, never()).sendFailure(any())
+    // Default size for broadcast in this testsuite is set to -1 so should not 
cause broadcast
+    // to be used.
+    verify(rpcCallContext, timeout(30000)).reply(any())
+    assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
 //    masterTracker.stop() // this throws an exception
     rpcEnv.shutdown()
   }
 
-  test("remote fetch exceeds max RPC message size") {
+  test("min broadcast size exceeds max RPC message size") {
     val newConf = new SparkConf
     newConf.set("spark.rpc.message.maxSize", "1")
     newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+    newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", 
Int.MaxValue.toString)
 
-    val masterTracker = new MapOutputTrackerMaster(conf)
-    val rpcEnv = createRpcEnv("test")
-    val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, 
masterTracker, newConf)
-    rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
-
-    // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint 
should throw exception.
-    // Note that the size is hand-selected here because map output statuses 
are compressed before
-    // being sent.
-    masterTracker.registerShuffle(20, 100)
-    (0 until 100).foreach { i =>
-      masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
-        BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
-    }
-    val senderAddress = RpcAddress("localhost", 12345)
-    val rpcCallContext = mock(classOf[RpcCallContext])
-    when(rpcCallContext.senderAddress).thenReturn(senderAddress)
-    masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
-    verify(rpcCallContext, never()).reply(any())
-    verify(rpcCallContext).sendFailure(isA(classOf[SparkException]))
-
-//    masterTracker.stop() // this throws an exception
-    rpcEnv.shutdown()
+    intercept[IllegalArgumentException] { newTrackerMaster(newConf) }
   }
 
   test("getLocationsWithLargestOutputs with multiple outputs in same machine") 
{
     val rpcEnv = createRpcEnv("test")
-    val tracker = new MapOutputTrackerMaster(conf)
+    val tracker = newTrackerMaster()
     tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
     // Setup 3 map tasks
@@ -242,4 +237,44 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     tracker.stop()
     rpcEnv.shutdown()
   }
+
+  test("remote fetch using broadcast") {
+    val newConf = new SparkConf
+    newConf.set("spark.rpc.message.maxSize", "1")
+    newConf.set("spark.rpc.askTimeout", "1") // Fail fast
+    newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 
KB << 1MB framesize
+
+    // needs TorrentBroadcast so need a SparkContext
+    val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf)
+    try {
+      val masterTracker = 
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+      val rpcEnv = sc.env.rpcEnv
+      val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, 
masterTracker, newConf)
+      rpcEnv.stop(masterTracker.trackerEndpoint)
+      rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+      // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should 
throw exception.
+      // Note that the size is hand-selected here because map output statuses 
are compressed before
+      // being sent.
+      masterTracker.registerShuffle(20, 100)
+      (0 until 100).foreach { i =>
+        masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+          BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+      }
+      val senderAddress = RpcAddress("localhost", 12345)
+      val rpcCallContext = mock(classOf[RpcCallContext])
+      when(rpcCallContext.senderAddress).thenReturn(senderAddress)
+      masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
+      // should succeed since majority of data is broadcast and actual 
serialized
+      // message size is small
+      verify(rpcCallContext, timeout(30000)).reply(any())
+      assert(1 == masterTracker.getNumCachedSerializedBroadcast)
+      masterTracker.unregisterShuffle(20)
+      assert(0 == masterTracker.getNumCachedSerializedBroadcast)
+
+    } finally {
+      LocalSparkContext.stop(sc)
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 844c780..e3ed079 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -156,6 +157,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
   }
 
   var mapOutputTracker: MapOutputTrackerMaster = null
+  var broadcastManager: BroadcastManager = null
+  var securityMgr: SecurityManager = null
   var scheduler: DAGScheduler = null
   var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
 
@@ -207,7 +210,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     cancelledStages.clear()
     cacheLocations.clear()
     results.clear()
-    mapOutputTracker = new MapOutputTrackerMaster(conf)
+    securityMgr = new SecurityManager(conf)
+    broadcastManager = new BroadcastManager(true, conf, securityMgr)
+    mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
     scheduler = new DAGScheduler(
       sc,
       taskScheduler,

http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index d14728c..31687e6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.memory.UnifiedMemoryManager
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.netty.NettyBlockTransferService
@@ -43,7 +44,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
   private var rpcEnv: RpcEnv = null
   private var master: BlockManagerMaster = null
   private val securityMgr = new SecurityManager(conf)
-  private val mapOutputTracker = new MapOutputTrackerMaster(conf)
+  private val bcastManager = new BroadcastManager(true, conf, securityMgr)
+  private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true)
   private val shuffleManager = new SortShuffleManager(conf)
 
   // List of block manager created during an unit test, so that all of the them can be stopped

http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index db1efaf..a258030 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.Timeouts._
 
 import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.memory.UnifiedMemoryManager
 import org.apache.spark.network.{BlockDataManager, BlockTransferService}
@@ -59,7 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   var rpcEnv: RpcEnv = null
   var master: BlockManagerMaster = null
   val securityMgr = new SecurityManager(new SparkConf(false))
-  val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false))
+  val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr)
+  val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true)
   val shuffleManager = new SortShuffleManager(new SparkConf(false))
 
   // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test

http://git-wip-us.apache.org/repos/asf/spark/blob/d98dd72e/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 4be4882..e974279 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.StaticMemoryManager
 import org.apache.spark.network.netty.NettyBlockTransferService
@@ -57,7 +58,8 @@ class ReceivedBlockHandlerSuite
   val hadoopConf = new Configuration()
   val streamId = 1
   val securityMgr = new SecurityManager(conf)
-  val mapOutputTracker = new MapOutputTrackerMaster(conf)
+  val broadcastManager = new BroadcastManager(true, conf, securityMgr)
+  val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
   val shuffleManager = new SortShuffleManager(conf)
   val serializer = new KryoSerializer(conf)
   var serializerManager = new SerializerManager(serializer, conf)

