Repository: spark Updated Branches: refs/heads/master 615cc858c -> cf2e0ae72
[SPARK-11096] Post-hoc review Netty based RPC implementation - round 2 A few more changes: 1. Renamed IDVerifier -> RpcEndpointVerifier 2. Renamed NettyRpcAddress -> RpcEndpointAddress 3. Simplified NettyRpcHandler a bit by removing the connection count tracking. This is OK because I now force spark.shuffle.io.numConnectionsPerPeer to 1. 4. Reduced spark.rpc.connect.threads to 64. It would be great to eventually remove this extra thread pool. 5. Minor cleanup & documentation. Author: Reynold Xin <[email protected]> Closes #9112 from rxin/SPARK-11096. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cf2e0ae7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cf2e0ae7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cf2e0ae7 Branch: refs/heads/master Commit: cf2e0ae7205443f052463e8cb9334ae2b6df2d0e Parents: 615cc85 Author: Reynold Xin <[email protected]> Authored: Wed Oct 14 12:41:02 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Wed Oct 14 12:41:02 2015 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/rpc/RpcEnv.scala | 9 -- .../org/apache/spark/rpc/netty/Dispatcher.scala | 7 +- .../org/apache/spark/rpc/netty/IDVerifier.scala | 39 ------- .../spark/rpc/netty/NettyRpcAddress.scala | 56 --------- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 114 ++++++++----------- .../spark/rpc/netty/RpcEndpointAddress.scala | 60 ++++++++++ .../spark/rpc/netty/RpcEndpointVerifier.scala | 40 +++++++ .../spark/rpc/netty/NettyRpcAddressSuite.scala | 2 +- .../spark/rpc/netty/NettyRpcHandlerSuite.scala | 3 - 9 files changed, 152 insertions(+), 178 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala ---------------------------------------------------------------------- diff --git 
a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index ef491a0..2c4a8b9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -94,15 +94,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` - * asynchronously. - */ - def asyncSetupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { - asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) - } - - /** * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. * This is a blocking action. */ http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 398e9ea..f1a8273 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils +/** + * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). + */ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private class EndpointData( @@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. 
- private val receivers = new LinkedBlockingQueue[EndpointData]() + private val receivers = new LinkedBlockingQueue[EndpointData] /** * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced @@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private var stopped = false def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { - val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) synchronized { if (stopped) { http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala deleted file mode 100644 index fa9a3eb..0000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.rpc.netty - -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} - -/** - * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists - */ -private[netty] case class ID(name: String) - -/** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] - */ -private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) - extends RpcEndpoint { - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case ID(name) => context.reply(dispatcher.verify(name)) - } -} - -private[netty] object IDVerifier { - val NAME = "id-verifier" -} http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala deleted file mode 100644 index 1876b25..0000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import java.net.URI - -import org.apache.spark.SparkException -import org.apache.spark.rpc.RpcAddress - -private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) { - - def toRpcAddress: RpcAddress = RpcAddress(host, port) - - override val toString = s"spark://$name@$host:$port" -} - -private[netty] object NettyRpcAddress { - - def apply(sparkUrl: String): NettyRpcAddress = { - try { - val uri = new URI(sparkUrl) - val host = uri.getHost - val port = uri.getPort - val name = uri.getUserInfo - if (uri.getScheme != "spark" || - host == null || - port < 0 || - name == null || - (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null - uri.getFragment != null || - uri.getQuery != null) { - throw new SparkException("Invalid Spark URL: " + sparkUrl) - } - NettyRpcAddress(host, port, name) - } catch { - case e: java.net.URISyntaxException => - throw new SparkException("Invalid Spark URL: " + sparkUrl, e) - } - } - -} http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 89b6df7..a2b28c5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,7 +22,6 @@ import 
java.nio.ByteBuffer import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag @@ -45,8 +44,10 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = - SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) + // Override numConnectionsPerPeer to 1 for RPC. + private val transportConf = SparkTransportConf.fromSparkConf( + conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) @@ -54,14 +55,14 @@ private[netty] class NettyRpcEnv( new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) private val clientFactory = { - val bootstraps: Seq[TransportClientBootstrap] = + val bootstraps: java.util.List[TransportClientBootstrap] = if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslClientBootstrap(transportConf, "", securityManager, + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, securityManager.isSaslEncryptionEnabled())) } else { - Nil + java.util.Collections.emptyList[TransportClientBootstrap] } - transportContext.createClientFactory(bootstraps.asJava) + transportContext.createClientFactory(bootstraps) } val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") @@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv( // TODO: a non-blocking TransportClientFactory.createClient in future private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection", - conf.getInt("spark.rpc.connect.threads", 256)) + conf.getInt("spark.rpc.connect.threads", 64)) @volatile private var server: TransportServer = _ @@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv( 
java.util.Collections.emptyList() } server = transportContext.createServer(port, bootstraps) - dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) + dispatcher.registerRpcEndpoint( + RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } override lazy val address: RpcAddress = { @@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv( } def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - val addr = NettyRpcAddress(uri) + val addr = RpcEndpointAddress(uri) val endpointRef = new NettyRpcEndpointRef(conf, addr, this) - val idVerifierRef = - new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this) - idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find => + val verifier = new NettyRpcEndpointRef( + conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this) + verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => if (find) { Future.successful(endpointRef) } else { @@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv( private[netty] def send(message: RequestMessage): Unit = { val remoteAddr = message.receiver.address if (remoteAddr == address) { + // Message to a local RPC endpoint. val promise = Promise[Any]() dispatcher.postLocalMessage(message, promise) promise.future.onComplete { case Success(response) => val ack = response.asInstanceOf[Ack] - logDebug(s"Receive ack from ${ack.sender}") + logTrace(s"Received ack from ${ack.sender}") case Failure(e) => logError(s"Exception when sending $message", e) }(ThreadUtils.sameThread) } else { + // Message to a remote RPC endpoint. 
try { // `createClient` will block if it cannot find a known connection, so we should run it in // clientConnectionExecutor @@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv( } }) } catch { - case e: RejectedExecutionException => { + case e: RejectedExecutionException => if (!promise.tryFailure(e)) { logWarning(s"Ignore failure", e) } - } } } promise.future @@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv( } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new NettyRpcAddress(address.host, address.port, endpointName).toString + new RpcEndpointAddress(address.host, address.port, endpointName).toString override def shutdown(): Unit = { cleanup() @@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) @transient @volatile private var nettyEnv: NettyRpcEnv = _ - @transient @volatile private var _address: NettyRpcAddress = _ + @transient @volatile private var _address: RpcEndpointAddress = _ - def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) { + def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) { this(conf) this._address = _address this.nettyEnv = nettyEnv @@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - _address = in.readObject().asInstanceOf[NettyRpcAddress] + _address = in.readObject().asInstanceOf[RpcEndpointAddress] nettyEnv = NettyRpcEnv.currentEnv.value } @@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler( private type RemoteEnvAddress = RpcAddress // Store all client addresses and their NettyRpcEnv addresses. + // TODO: Is this even necessary? @GuardedBy("this") private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() - // Store the connections from other NettyRpcEnv addresses. 
We need to keep track of the connection - // count because `TransportClientFactory.createClient` will create multiple connections - // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection - // to send the message. See `TransportClientFactory.createClient` for more details. - @GuardedBy("this") - private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() - override def receive( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { val requestMessage = nettyEnv.deserialize[RequestMessage](message) - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage: Option[RemoteProcessConnected] = - synchronized { - // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" - if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { - // clientAddr connects at the first time - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - // Increase the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count + 1) - if (count == 0) { - // This is the first connection, so fire "Associated" - Some(RemoteProcessConnected(remoteEnvAddress)) - } else { - None - } - } else { - None - } + + // TODO: Can we add connection callback (channel registered) to the underlying framework? + // A variable to track whether we should dispatch the RemoteProcessConnected message. 
+ var dispatchRemoteProcessConnected = false + synchronized { + if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { + // clientAddr connects at the first time, fire "RemoteProcessConnected" + dispatchRemoteProcessConnected = true } - broadcastMessage.foreach(dispatcher.postToAll) + } + if (dispatchRemoteProcessConnected) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } dispatcher.postRemoteMessage(requestMessage, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) val broadcastMessage = @@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler( } override def connectionTerminated(client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - // If the last connection to a remote RpcEnv is terminated, we should broadcast - // "Disassociated" - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") - if (count - 1 == 0) { - // We lost all clients, so clean up and fire "Disassociated" - remoteConnectionCount.remove(remoteEnvAddress) - Some(RemoteProcessDisconnected(remoteEnvAddress)) - } else { - // Decrease the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count - 1) - None - } - } + val messageOpt: 
Option[RemoteProcessDisconnected] = + synchronized { + remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => + remoteAddresses -= clientAddr + Some(RemoteProcessDisconnected(remoteEnvAddress)) } - broadcastMessage.foreach(dispatcher.postToAll) + } + messageOpt.foreach(dispatcher.postToAll) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". // See java.net.Socket.getRemoteSocketAddress } } - } http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala new file mode 100644 index 0000000..87b6236 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.SparkException +import org.apache.spark.rpc.RpcAddress + +/** + * An address identifier for an RPC endpoint. + * + * @param host host name of the remote process. + * @param port the port the remote RPC environment binds to. + * @param name name of the remote endpoint. + */ +private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) { + + def toRpcAddress: RpcAddress = RpcAddress(host, port) + + override val toString = s"spark://$name@$host:$port" +} + +private[netty] object RpcEndpointAddress { + + def apply(sparkUrl: String): RpcEndpointAddress = { + try { + val uri = new java.net.URI(sparkUrl) + val host = uri.getHost + val port = uri.getPort + val name = uri.getUserInfo + if (uri.getScheme != "spark" || + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { + throw new SparkException("Invalid Spark URL: " + sparkUrl) + } + RpcEndpointAddress(host, port, name) + } catch { + case e: java.net.URISyntaxException => + throw new SparkException("Invalid Spark URL: " + sparkUrl, e) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala new file mode 100644 index 0000000..99f20da2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} + +/** + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * + * This is used when setting up a remote endpoint reference. + */ +private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) + extends RpcEndpoint { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name)) + } +} + +private[netty] object RpcEndpointVerifier { + val NAME = "endpoint-verifier" + + /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. 
*/ + case class CheckExistence(name: String) +} http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index a5d43d3..973a07a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite class NettyRpcAddressSuite extends SparkFunSuite { test("toString") { - val addr = NettyRpcAddress("localhost", 12345, "test") + val addr = RpcEndpointAddress("localhost", 12345, "test") assert(addr.toString === "spark://test@localhost:12345") } http://git-wip-us.apache.org/repos/asf/spark/blob/cf2e0ae7/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f24f78b..5430e4c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -42,9 +42,6 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001)) - nettyRpcHandler.receive(client, null, null) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email 
protected] For additional commands, e-mail: [email protected]
