Repository: spark
Updated Branches:
  refs/heads/master 615cc858c -> cf2e0ae72


[SPARK-11096] Post-hoc review Netty based RPC implementation - round 2

A few more changes:

1. Renamed IDVerifier -> RpcEndpointVerifier
2. Renamed NettyRpcAddress -> RpcEndpointAddress
3. Simplified NettyRpcHandler a bit by removing the connection count tracking.
This is OK because I now force spark.shuffle.io.numConnectionsPerPeer to 1 for RPC.
4. Reduced the default of spark.rpc.connect.threads from 256 to 64. It would be
great to eventually remove this extra thread pool.
5. Minor cleanup & documentation.

Author: Reynold Xin <[email protected]>

Closes #9112 from rxin/SPARK-11096.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cf2e0ae7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cf2e0ae7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cf2e0ae7

Branch: refs/heads/master
Commit: cf2e0ae7205443f052463e8cb9334ae2b6df2d0e
Parents: 615cc85
Author: Reynold Xin <[email protected]>
Authored: Wed Oct 14 12:41:02 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Wed Oct 14 12:41:02 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/rpc/RpcEnv.scala     |   9 --
 .../org/apache/spark/rpc/netty/Dispatcher.scala |   7 +-
 .../org/apache/spark/rpc/netty/IDVerifier.scala |  39 -------
 .../spark/rpc/netty/NettyRpcAddress.scala       |  56 ---------
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    | 114 ++++++++-----------
 .../spark/rpc/netty/RpcEndpointAddress.scala    |  60 ++++++++++
 .../spark/rpc/netty/RpcEndpointVerifier.scala   |  40 +++++++
 .../spark/rpc/netty/NettyRpcAddressSuite.scala  |   2 +-
 .../spark/rpc/netty/NettyRpcHandlerSuite.scala  |   3 -
 9 files changed, 152 insertions(+), 178 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala 
b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index ef491a0..2c4a8b9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -94,15 +94,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
   }
 
   /**
-   * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` 
and `endpointName`
-   * asynchronously.
-   */
-  def asyncSetupEndpointRef(
-      systemName: String, address: RpcAddress, endpointName: String): 
Future[RpcEndpointRef] = {
-    asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
-  }
-
-  /**
    * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` 
and `endpointName`.
    * This is a blocking action.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 398e9ea..f1a8273 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc._
 import org.apache.spark.util.ThreadUtils
 
+/**
+ * A message dispatcher, responsible for routing RPC messages to the 
appropriate endpoint(s).
+ */
 private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
 
   private class EndpointData(
@@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) 
extends Logging {
   private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
 
   // Track the receivers whose inboxes may contain messages.
-  private val receivers = new LinkedBlockingQueue[EndpointData]()
+  private val receivers = new LinkedBlockingQueue[EndpointData]
 
   /**
    * True if the dispatcher has been stopped. Once stopped, all messages 
posted will be bounced
@@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) 
extends Logging {
   private var stopped = false
 
   def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): 
NettyRpcEndpointRef = {
-    val addr = new NettyRpcAddress(nettyEnv.address.host, 
nettyEnv.address.port, name)
+    val addr = new RpcEndpointAddress(nettyEnv.address.host, 
nettyEnv.address.port, name)
     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
     synchronized {
       if (stopped) {

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
deleted file mode 100644
index fa9a3eb..0000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.rpc.netty
-
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
-
-/**
- * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
- */
-private[netty] case class ID(name: String)
-
-/**
- * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] 
exists in this [[RpcEnv]]
- */
-private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: 
Dispatcher)
-  extends RpcEndpoint {
-
-  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
-    case ID(name) => context.reply(dispatcher.verify(name))
-  }
-}
-
-private[netty] object IDVerifier {
-  val NAME = "id-verifier"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
deleted file mode 100644
index 1876b25..0000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import java.net.URI
-
-import org.apache.spark.SparkException
-import org.apache.spark.rpc.RpcAddress
-
-private[netty] case class NettyRpcAddress(host: String, port: Int, name: 
String) {
-
-  def toRpcAddress: RpcAddress = RpcAddress(host, port)
-
-  override val toString = s"spark://$name@$host:$port"
-}
-
-private[netty] object NettyRpcAddress {
-
-  def apply(sparkUrl: String): NettyRpcAddress = {
-    try {
-      val uri = new URI(sparkUrl)
-      val host = uri.getHost
-      val port = uri.getPort
-      val name = uri.getUserInfo
-      if (uri.getScheme != "spark" ||
-        host == null ||
-        port < 0 ||
-        name == null ||
-        (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath 
returns "" instead of null
-        uri.getFragment != null ||
-        uri.getQuery != null) {
-        throw new SparkException("Invalid Spark URL: " + sparkUrl)
-      }
-      NettyRpcAddress(host, port, name)
-    } catch {
-      case e: java.net.URISyntaxException =>
-        throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 89b6df7..a2b28c5 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
 import java.util.concurrent._
 import javax.annotation.concurrent.GuardedBy
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag
@@ -45,8 +44,10 @@ private[netty] class NettyRpcEnv(
     host: String,
     securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
 
-  private val transportConf =
-    SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 
0))
+  // Override numConnectionsPerPeer to 1 for RPC.
+  private val transportConf = SparkTransportConf.fromSparkConf(
+    conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
+    conf.getInt("spark.rpc.io.threads", 0))
 
   private val dispatcher: Dispatcher = new Dispatcher(this)
 
@@ -54,14 +55,14 @@ private[netty] class NettyRpcEnv(
     new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
 
   private val clientFactory = {
-    val bootstraps: Seq[TransportClientBootstrap] =
+    val bootstraps: java.util.List[TransportClientBootstrap] =
       if (securityManager.isAuthenticationEnabled()) {
-        Seq(new SaslClientBootstrap(transportConf, "", securityManager,
+        java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", 
securityManager,
           securityManager.isSaslEncryptionEnabled()))
       } else {
-        Nil
+        java.util.Collections.emptyList[TransportClientBootstrap]
       }
-    transportContext.createClientFactory(bootstraps.asJava)
+    transportContext.createClientFactory(bootstraps)
   }
 
   val timeoutScheduler = 
ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
@@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv(
   // TODO: a non-blocking TransportClientFactory.createClient in future
   private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
     "netty-rpc-connection",
-    conf.getInt("spark.rpc.connect.threads", 256))
+    conf.getInt("spark.rpc.connect.threads", 64))
 
   @volatile private var server: TransportServer = _
 
@@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv(
         java.util.Collections.emptyList()
       }
     server = transportContext.createServer(port, bootstraps)
-    dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, 
dispatcher))
+    dispatcher.registerRpcEndpoint(
+      RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
   }
 
   override lazy val address: RpcAddress = {
@@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv(
   }
 
   def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
-    val addr = NettyRpcAddress(uri)
+    val addr = RpcEndpointAddress(uri)
     val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
-    val idVerifierRef =
-      new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, 
IDVerifier.NAME), this)
-    idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
+    val verifier = new NettyRpcEndpointRef(
+      conf, RpcEndpointAddress(addr.host, addr.port, 
RpcEndpointVerifier.NAME), this)
+    
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap
 { find =>
       if (find) {
         Future.successful(endpointRef)
       } else {
@@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv(
   private[netty] def send(message: RequestMessage): Unit = {
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
+      // Message to a local RPC endpoint.
       val promise = Promise[Any]()
       dispatcher.postLocalMessage(message, promise)
       promise.future.onComplete {
         case Success(response) =>
           val ack = response.asInstanceOf[Ack]
-          logDebug(s"Receive ack from ${ack.sender}")
+          logTrace(s"Received ack from ${ack.sender}")
         case Failure(e) =>
           logError(s"Exception when sending $message", e)
       }(ThreadUtils.sameThread)
     } else {
+      // Message to a remote RPC endpoint.
       try {
         // `createClient` will block if it cannot find a known connection, so 
we should run it in
         // clientConnectionExecutor
@@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv(
           }
         })
       } catch {
-        case e: RejectedExecutionException => {
+        case e: RejectedExecutionException =>
           if (!promise.tryFailure(e)) {
             logWarning(s"Ignore failure", e)
           }
-        }
       }
     }
     promise.future
@@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv(
   }
 
   override def uriOf(systemName: String, address: RpcAddress, endpointName: 
String): String =
-    new NettyRpcAddress(address.host, address.port, endpointName).toString
+    new RpcEndpointAddress(address.host, address.port, endpointName).toString
 
   override def shutdown(): Unit = {
     cleanup()
@@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: 
SparkConf)
 
   @transient @volatile private var nettyEnv: NettyRpcEnv = _
 
-  @transient @volatile private var _address: NettyRpcAddress = _
+  @transient @volatile private var _address: RpcEndpointAddress = _
 
-  def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
+  def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: 
NettyRpcEnv) {
     this(conf)
     this._address = _address
     this.nettyEnv = nettyEnv
@@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: 
SparkConf)
 
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject()
-    _address = in.readObject().asInstanceOf[NettyRpcAddress]
+    _address = in.readObject().asInstanceOf[RpcEndpointAddress]
     nettyEnv = NettyRpcEnv.currentEnv.value
   }
 
@@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler(
   private type RemoteEnvAddress = RpcAddress
 
   // Store all client addresses and their NettyRpcEnv addresses.
+  // TODO: Is this even necessary?
   @GuardedBy("this")
   private val remoteAddresses = new mutable.HashMap[ClientAddress, 
RemoteEnvAddress]()
 
-  // Store the connections from other NettyRpcEnv addresses. We need to keep 
track of the connection
-  // count because `TransportClientFactory.createClient` will create multiple 
connections
-  // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and 
randomly select a connection
-  // to send the message. See `TransportClientFactory.createClient` for more 
details.
-  @GuardedBy("this")
-  private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, 
Int]()
-
   override def receive(
       client: TransportClient, message: Array[Byte], callback: 
RpcResponseCallback): Unit = {
     val requestMessage = nettyEnv.deserialize[RequestMessage](message)
-    val addr = 
client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    val addr = 
client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
     val remoteEnvAddress = requestMessage.senderAddress
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-    val broadcastMessage: Option[RemoteProcessConnected] =
-      synchronized {
-        // If the first connection to a remote RpcEnv is found, we should 
broadcast "Associated"
-        if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
-          // clientAddr connects at the first time
-          val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
-          // Increase the connection number of remoteEnvAddress
-          remoteConnectionCount.put(remoteEnvAddress, count + 1)
-          if (count == 0) {
-            // This is the first connection, so fire "Associated"
-            Some(RemoteProcessConnected(remoteEnvAddress))
-          } else {
-            None
-          }
-        } else {
-          None
-        }
+
+    // TODO: Can we add connection callback (channel registered) to the 
underlying framework?
+    // A variable to track whether we should dispatch the 
RemoteProcessConnected message.
+    var dispatchRemoteProcessConnected = false
+    synchronized {
+      if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
+        // clientAddr connects at the first time, fire "RemoteProcessConnected"
+        dispatchRemoteProcessConnected = true
       }
-    broadcastMessage.foreach(dispatcher.postToAll)
+    }
+    if (dispatchRemoteProcessConnected) {
+      dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
+    }
     dispatcher.postRemoteMessage(requestMessage, callback)
   }
 
   override def getStreamManager: StreamManager = new OneForOneStreamManager
 
   override def exceptionCaught(cause: Throwable, client: TransportClient): 
Unit = {
-    val addr = 
client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    val addr = 
client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
       val broadcastMessage =
@@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler(
   }
 
   override def connectionTerminated(client: TransportClient): Unit = {
-    val addr = 
client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    val addr = 
client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-      val broadcastMessage =
-        synchronized {
-          // If the last connection to a remote RpcEnv is terminated, we 
should broadcast
-          // "Disassociated"
-          remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
-            remoteAddresses -= clientAddr
-            val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
-            assert(count != 0, "remoteAddresses and remoteConnectionCount are 
not consistent")
-            if (count - 1 == 0) {
-              // We lost all clients, so clean up and fire "Disassociated"
-              remoteConnectionCount.remove(remoteEnvAddress)
-              Some(RemoteProcessDisconnected(remoteEnvAddress))
-            } else {
-              // Decrease the connection number of remoteEnvAddress
-              remoteConnectionCount.put(remoteEnvAddress, count - 1)
-              None
-            }
-          }
+      val messageOpt: Option[RemoteProcessDisconnected] =
+      synchronized {
+        remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
+          remoteAddresses -= clientAddr
+          Some(RemoteProcessDisconnected(remoteEnvAddress))
         }
-      broadcastMessage.foreach(dispatcher.postToAll)
+      }
+      messageOpt.foreach(dispatcher.postToAll)
     } else {
       // If the channel is closed before connecting, its remoteAddress will be 
null. In this case,
       // we can ignore it since we don't fire "Associated".
       // See java.net.Socket.getRemoteSocketAddress
     }
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
new file mode 100644
index 0000000..87b6236
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.SparkException
+import org.apache.spark.rpc.RpcAddress
+
+/**
+ * An address identifier for an RPC endpoint.
+ *
+ * @param host host name of the remote process.
+ * @param port the port the remote RPC environment binds to.
+ * @param name name of the remote endpoint.
+ */
+private[netty] case class RpcEndpointAddress(host: String, port: Int, name: 
String) {
+
+  def toRpcAddress: RpcAddress = RpcAddress(host, port)
+
+  override val toString = s"spark://$name@$host:$port"
+}
+
+private[netty] object RpcEndpointAddress {
+
+  def apply(sparkUrl: String): RpcEndpointAddress = {
+    try {
+      val uri = new java.net.URI(sparkUrl)
+      val host = uri.getHost
+      val port = uri.getPort
+      val name = uri.getUserInfo
+      if (uri.getScheme != "spark" ||
+          host == null ||
+          port < 0 ||
+          name == null ||
+          (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath 
returns "" instead of null
+          uri.getFragment != null ||
+          uri.getQuery != null) {
+        throw new SparkException("Invalid Spark URL: " + sparkUrl)
+      }
+      RpcEndpointAddress(host, port, name)
+    } catch {
+      case e: java.net.URISyntaxException =>
+        throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
new file mode 100644
index 0000000..99f20da2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
+
+/**
+ * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] 
exists.
+ *
+ * This is used when setting up a remote endpoint reference.
+ */
+private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, 
dispatcher: Dispatcher)
+  extends RpcEndpoint {
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
+    case RpcEndpointVerifier.CheckExistence(name) => 
context.reply(dispatcher.verify(name))
+  }
+}
+
+private[netty] object RpcEndpointVerifier {
+  val NAME = "endpoint-verifier"
+
+  /** A message used to ask the remote [[RpcEndpointVerifier]] if an 
[[RpcEndpoint]] exists. */
+  case class CheckExistence(name: String)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala 
b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
index a5d43d3..973a07a 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite
 class NettyRpcAddressSuite extends SparkFunSuite {
 
   test("toString") {
-    val addr = NettyRpcAddress("localhost", 12345, "test")
+    val addr = RpcEndpointAddress("localhost", 12345, "test")
     assert(addr.toString === "spark://test@localhost:12345")
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala 
b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f24f78b..5430e4c 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -42,9 +42,6 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
     when(channel.remoteAddress()).thenReturn(new 
InetSocketAddress("localhost", 40000))
     nettyRpcHandler.receive(client, null, null)
 
-    when(channel.remoteAddress()).thenReturn(new 
InetSocketAddress("localhost", 40001))
-    nettyRpcHandler.receive(client, null, null)
-
     verify(dispatcher, 
times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to