This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 884f6f71172 [SPARK-45544][CORE] Integrate SSL support into TransportContext
884f6f71172 is described below
commit 884f6f71172156ccc7d95ed022c8fb8baadc3c0a
Author: Hasnain Lakhani <[email protected]>
AuthorDate: Sun Oct 29 20:58:18 2023 -0500
[SPARK-45544][CORE] Integrate SSL support into TransportContext
### What changes were proposed in this pull request?
This integrates SSL support into TransportContext and related modules so
that the RPC SSL functionality can work when properly configured.
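As a rough sketch of what "properly configured" means here (the `spark.ssl.rpc.*` key names are assumptions inferred from the `sslRpc*()` accessors on `TransportConf` in the diff below; `TransportContext`, `TransportConf`, `MapConfigProvider`, and `NoOpRpcHandler` all appear in the diff itself):
```java
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class SslRpcSketch {
  public static void main(String[] args) throws Exception {
    Map<String, String> config = new HashMap<>();
    // Assumed key names, inferred from TransportConf's sslRpc*() accessors.
    config.put("spark.ssl.rpc.enabled", "true");
    config.put("spark.ssl.rpc.keyStore", "/path/to/keystore.jks");
    config.put("spark.ssl.rpc.keyStorePassword", "keystore-password");
    config.put("spark.ssl.rpc.trustStore", "/path/to/truststore.jks");
    config.put("spark.ssl.rpc.trustStorePassword", "truststore-password");
    TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(config));

    // The TransportContext constructor builds an SSLFactory when RPC SSL is
    // enabled and fails fast if the configured key files are missing, so the
    // paths above must point at real stores.
    try (TransportContext context = new TransportContext(conf, new NoOpRpcHandler())) {
      System.out.println("RPC SSL enabled: " + context.sslEncryptionEnabled());
    }
  }
}
```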
### Why are the changes needed?
This is needed in order to support SSL for RPC connections.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
Ran the following tests:
```
build/sbt -Pyarn
> project network-common
> testOnly
> project network-shuffle
> testOnly
> project core
> testOnly *Ssl*
> project yarn
> testOnly org.apache.spark.network.yarn.SslYarnShuffleServiceWithRocksDBBackendSuite
```
I verified traffic was encrypted with TLS using two mechanisms:
* Enabled trace-level logging for Netty and JDK SSL and saw logs confirming
TLS handshakes were happening (see the sketch after this list).
* Ran Wireshark on my machine and snooped on the traffic while sending
queries shuffling a fixed string. Without encryption, I could find that
string in the network traffic; with encryption enabled, the string did not
show up, and Wireshark confirmed a TLS handshake was taking place.
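For reference, a minimal sketch of those two debug switches (the JVM flag is the standard JDK `javax.net.debug` option; the log4j2 property names are assumed, and the logger simply targets Netty's `io.netty.handler.ssl` package):
```
# JVM flag for JDK TLS handshake logging:
#   -Djavax.net.debug=ssl:handshake
# log4j2.properties entries (names assumed) to surface Netty's SSL logging:
logger.nettyssl.name = io.netty.handler.ssl
logger.nettyssl.level = trace
```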
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43541 from hasnain-db/spark-tls-final.
Authored-by: Hasnain Lakhani <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../org/apache/spark/network/TransportContext.java | 70 ++++++++++++++++++++--
.../network/client/TransportClientFactory.java | 26 +++++++-
.../spark/network/server/TransportServer.java | 2 +-
.../apache/spark/network/util/TransportConf.java | 8 ---
.../spark/network/ChunkFetchIntegrationSuite.java | 6 +-
.../network/SslChunkFetchIntegrationSuite.java | 22 ++++---
.../client/SslTransportClientFactorySuite.java | 29 +++++----
.../client/TransportClientFactorySuite.java | 8 +--
.../network/shuffle/ShuffleTransportContext.java | 10 ++--
.../shuffle/ExternalShuffleIntegrationSuite.java | 29 +++++----
.../shuffle/ExternalShuffleSecuritySuite.java | 14 ++++-
.../shuffle/ShuffleTransportContextSuite.java | 33 +++++-----
.../SslExternalShuffleIntegrationSuite.java | 44 ++++++++++++++
.../shuffle/SslExternalShuffleSecuritySuite.java | 35 +++++++----
.../shuffle/SslShuffleTransportContextSuite.java | 28 +++++----
.../network/yarn/SslYarnShuffleServiceSuite.scala | 2 +-
16 files changed, 265 insertions(+), 101 deletions(-)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index 51d074a4ddb..90ca4f4c46a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -23,13 +23,17 @@ import io.netty.handler.codec.MessageToMessageDecoder;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.List;
+import javax.annotation.Nullable;
import com.codahale.metrics.Counter;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
+import io.netty.handler.codec.MessageToMessageEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,6 +41,8 @@ import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.SslMessageEncoder;
import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.MessageEncoder;
import org.apache.spark.network.server.ChunkFetchRequestHandler;
@@ -45,6 +51,7 @@ import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.ssl.SSLFactory;
import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.NettyLogger;
@@ -72,6 +79,8 @@ public class TransportContext implements Closeable {
private final TransportConf conf;
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;
+ // Non-null if SSL is enabled, null otherwise.
+ @Nullable private final SSLFactory sslFactory;
// Number of registered connections to the shuffle service
private Counter registeredConnections = new Counter();
@@ -87,7 +96,8 @@ public class TransportContext implements Closeable {
* RPC to load it and cause to load the non-exist matcher class again. JVM will report
* `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
*/
- private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+ private static final MessageToMessageEncoder<Message> ENCODER = MessageEncoder.INSTANCE;
+ private static final MessageToMessageEncoder<Message> SSL_ENCODER = SslMessageEncoder.INSTANCE;
private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
// Separate thread pool for handling ChunkFetchRequest. This helps to enable throttling
@@ -125,6 +135,7 @@ public class TransportContext implements Closeable {
this.conf = conf;
this.rpcHandler = rpcHandler;
this.closeIdleConnections = closeIdleConnections;
+ this.sslFactory = createSslFactory();
if (conf.getModuleName() != null &&
conf.getModuleName().equalsIgnoreCase("shuffle") &&
@@ -171,8 +182,12 @@ public class TransportContext implements Closeable {
return createServer(0, new ArrayList<>());
}
- public TransportChannelHandler initializePipeline(SocketChannel channel) {
- return initializePipeline(channel, rpcHandler);
+ public TransportChannelHandler initializePipeline(SocketChannel channel, boolean isClient) {
+ return initializePipeline(channel, rpcHandler, isClient);
+ }
+
+ public boolean sslEncryptionEnabled() {
+ return this.sslFactory != null;
}
/**
@@ -189,15 +204,30 @@ public class TransportContext implements Closeable {
*/
public TransportChannelHandler initializePipeline(
SocketChannel channel,
- RpcHandler channelRpcHandler) {
+ RpcHandler channelRpcHandler,
+ boolean isClient) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
ChannelPipeline pipeline = channel.pipeline();
if (nettyLogger.getLoggingHandler() != null) {
pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
}
+
+ if (sslEncryptionEnabled()) {
+ SslHandler sslHandler;
+ try {
+ sslHandler = new SslHandler(sslFactory.createSSLEngine(isClient, channel.alloc()));
+ } catch (Exception e) {
+ throw new IllegalStateException("Error creating Netty SslHandler", e);
+ }
+ pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+ // Cannot use zero-copy with HTTPS, so we add in our ChunkedWriteHandler just before the
+ // MessageEncoder
+ pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
+ }
+
pipeline
- .addLast("encoder", ENCODER)
+ .addLast("encoder", sslEncryptionEnabled()? SSL_ENCODER : ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", getDecoder())
.addLast("idleStateHandler",
@@ -223,6 +253,33 @@ public class TransportContext implements Closeable {
return DECODER;
}
+ private SSLFactory createSslFactory() {
+ if (conf.sslRpcEnabled()) {
+ if (conf.sslRpcEnabledAndKeysAreValid()) {
+ return new SSLFactory.Builder()
+ .openSslEnabled(conf.sslRpcOpenSslEnabled())
+ .requestedProtocol(conf.sslRpcProtocol())
+ .requestedCiphers(conf.sslRpcRequestedCiphers())
+ .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword())
+ .privateKey(conf.sslRpcPrivateKey())
+ .keyPassword(conf.sslRpcKeyPassword())
+ .certChain(conf.sslRpcCertChain())
+ .trustStore(
+ conf.sslRpcTrustStore(),
+ conf.sslRpcTrustStorePassword(),
+ conf.sslRpcTrustStoreReloadingEnabled(), conf.sslRpctrustStoreReloadIntervalMs())
+ .build();
+ } else {
+ logger.error("RPC SSL encryption enabled but keys not found!" +
+ "Please ensure the configured keys are present.");
+ throw new IllegalArgumentException("RPC SSL encryption enabled but keys not found!");
+ }
+ } else {
+ return null;
+ }
+ }
+
/**
* Creates the server- and client-side handler which is used to handle both
RequestMessages and
* ResponseMessages. The channel is expected to have been successfully
created, though certain
@@ -255,5 +312,8 @@ public class TransportContext implements Closeable {
if (chunkFetchWorkers != null) {
chunkFetchWorkers.shutdownGracefully();
}
+ if (sslFactory != null) {
+ sslFactory.destroy();
+ }
}
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 4c1efd69206..fd48020caac 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -39,6 +39,9 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -268,7 +271,7 @@ public class TransportClientFactory implements Closeable {
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
- TransportChannelHandler clientHandler = context.initializePipeline(ch);
+ TransportChannelHandler clientHandler = context.initializePipeline(ch, true);
clientRef.set(clientHandler.getClient());
channelRef.set(ch);
}
@@ -293,6 +296,27 @@ public class TransportClientFactory implements Closeable {
} else if (cf.cause() != null) {
throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}
+ if (context.sslEncryptionEnabled()) {
+ final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
+ Future<Channel> future = sslHandler.handshakeFuture().addListener(
+ new GenericFutureListener<Future<Channel>>() {
+ @Override
+ public void operationComplete(final Future<Channel> handshakeFuture) {
+ if (handshakeFuture.isSuccess()) {
+ logger.debug("{} successfully completed TLS handshake to ", address);
+ } else {
+ logger.info(
+ "failed to complete TLS handshake to " + address,
handshakeFuture.cause());
+ cf.channel().close();
+ }
+ }
+ });
+ if (!future.await(conf.connectionTimeoutMs())) {
+ cf.channel().close();
+ throw new IOException(
+ String.format("Failed to connect to %s within connection timeout",
address));
+ }
+ }
TransportClient client = clientRef.get();
Channel channel = channelRef.get();
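The handshake gating added in the hunk above reduces to a stock Netty pattern; here is a standalone sketch using only public Netty APIs (class and variable names are illustrative, not taken from the patch):
```java
import java.io.IOException;

import io.netty.channel.Channel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;

final class HandshakeGate {
  // Waits for the TLS handshake on an already-connected channel, closing the
  // channel and failing if it does not complete successfully within timeoutMs.
  static void awaitHandshake(Channel channel, long timeoutMs)
      throws IOException, InterruptedException {
    SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
    Future<Channel> handshake = sslHandler.handshakeFuture();
    if (!handshake.await(timeoutMs) || !handshake.isSuccess()) {
      channel.close();
      throw new IOException("TLS handshake did not complete within " + timeoutMs + " ms");
    }
  }
}
```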
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 5b5b3f9d901..6f2e4b8a502 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -140,7 +140,7 @@ public class TransportServer implements Closeable {
for (TransportServerBootstrap bootstrap : bootstraps) {
rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
}
- context.initializePipeline(ch, rpcHandler);
+ context.initializePipeline(ch, rpcHandler, false);
}
});
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 3ebb38e310f..eb85d2bb561 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -401,14 +401,6 @@ public class TransportConf {
}
}
- /**
- * If we can dangerously fallback to unencrypted connections if RPC over SSL is enabled
- * but the key files are not present
- */
- public boolean sslRpcDangerouslyFallbackIfKeysNotPresent() {
- return conf.getBoolean("spark.ssl.rpc.dangerouslyFallbackIfKeysNotPresent", false);
- }
-
/**
* Flag indicating whether to share the pooled ByteBuf allocators between the different Netty
* channels. If enabled then only two pooled ByteBuf allocators are created: one where caching
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 2026d3b9524..576a106934f 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -65,8 +65,13 @@ public class ChunkFetchIntegrationSuite {
static ManagedBuffer bufferChunk;
static ManagedBuffer fileChunk;
+ // This is split out so it can be invoked in a subclass with a different config
@BeforeAll
public static void setUp() throws Exception {
+ doSetUpWithConfig(new TransportConf("shuffle", MapConfigProvider.EMPTY));
+ }
+
+ public static void doSetUpWithConfig(final TransportConf conf) throws Exception {
int bufSize = 100000;
final ByteBuffer buf = ByteBuffer.allocate(bufSize);
for (int i = 0; i < bufSize; i ++) {
@@ -88,7 +93,6 @@ public class ChunkFetchIntegrationSuite {
Closeables.close(fp, shouldSuppressIOException);
}
- final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
streamManager = new StreamManager() {
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java
similarity index 59%
copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java
index 322d6bfdb7c..783ffd4b8c1 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java
@@ -14,21 +14,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.network;
-package org.apache.spark.network.yarn
+import org.junit.jupiter.api.BeforeAll;
-import org.apache.spark.network.ssl.SslSampleConfigs
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
-class SslYarnShuffleServiceWithRocksDBBackendSuite
- extends YarnShuffleServiceWithRocksDBBackendSuite {
- /**
- * Override to add "spark.ssl.rpc.*" configuration parameters...
- */
- override def beforeEach(): Unit = {
- super.beforeEach()
- // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
- SslSampleConfigs.createDefaultConfigMap().entrySet().
- forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslChunkFetchIntegrationSuite extends ChunkFetchIntegrationSuite {
+
+ @BeforeAll
+ public static void setUp() throws Exception {
+ doSetUpWithConfig(new TransportConf(
+ "shuffle",
SslSampleConfigs.createDefaultConfigProviderForRpcNamespace()));
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java
similarity index 51%
copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java
index 322d6bfdb7c..79b76b633f9 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java
@@ -15,20 +15,25 @@
* limitations under the License.
*/
-package org.apache.spark.network.yarn
+package org.apache.spark.network.client;
-import org.apache.spark.network.ssl.SslSampleConfigs
+import org.junit.jupiter.api.BeforeEach;
-class SslYarnShuffleServiceWithRocksDBBackendSuite
- extends YarnShuffleServiceWithRocksDBBackendSuite {
+import org.apache.spark.network.ssl.SslSampleConfigs;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.TransportContext;
- /**
- * Override to add "spark.ssl.rpc.*" configuration parameters...
- */
- override def beforeEach(): Unit = {
- super.beforeEach()
- // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
- SslSampleConfigs.createDefaultConfigMap().entrySet().
- forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslTransportClientFactorySuite extends TransportClientFactorySuite {
+
+ @BeforeEach
+ public void setUp() {
+ conf = new TransportConf(
+ "shuffle",
SslSampleConfigs.createDefaultConfigProviderForRpcNamespace());
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ context = new TransportContext(conf, rpcHandler);
+ server1 = context.createServer();
+ server2 = context.createServer();
}
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
index 49a2d570d96..b57f0be920c 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
@@ -44,10 +44,10 @@ import org.apache.spark.network.util.TransportConf;
import static org.junit.jupiter.api.Assertions.*;
public class TransportClientFactorySuite {
- private TransportConf conf;
- private TransportContext context;
- private TransportServer server1;
- private TransportServer server2;
+ protected TransportConf conf;
+ protected TransportContext context;
+ protected TransportServer server1;
+ protected TransportServer server2;
@BeforeEach
public void setUp() {
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
index e0971d49510..feaaa570b73 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
@@ -22,6 +22,7 @@ import java.util.List;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
@@ -81,16 +82,16 @@ public class ShuffleTransportContext extends TransportContext {
}
@Override
- public TransportChannelHandler initializePipeline(SocketChannel channel) {
- TransportChannelHandler ch = super.initializePipeline(channel);
+ public TransportChannelHandler initializePipeline(SocketChannel channel, boolean isClient) {
+ TransportChannelHandler ch = super.initializePipeline(channel, isClient);
addHandlerToPipeline(channel, ch);
return ch;
}
@Override
public TransportChannelHandler initializePipeline(SocketChannel channel,
- RpcHandler channelRpcHandler) {
- TransportChannelHandler ch = super.initializePipeline(channel, channelRpcHandler);
+ RpcHandler channelRpcHandler, boolean isClient) {
+ TransportChannelHandler ch = super.initializePipeline(channel, channelRpcHandler, isClient);
addHandlerToPipeline(channel, ch);
return ch;
}
@@ -112,6 +113,7 @@ public class ShuffleTransportContext extends TransportContext {
return finalizeWorkers == null ? super.getDecoder() : SHUFFLE_DECODER;
}
+ @ChannelHandler.Sharable
static class ShuffleMessageDecoder extends MessageToMessageDecoder<ByteBuf> {
private final MessageDecoder delegate;
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index b5ffa30f62d..73cb133f17e 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -32,7 +32,6 @@ import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
-import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;
@@ -57,11 +56,11 @@ public class ExternalShuffleIntegrationSuite {
private static final String APP_ID = "app-id";
private static final String SORT_MANAGER =
"org.apache.spark.shuffle.sort.SortShuffleManager";
- private static final int RDD_ID = 1;
- private static final int SPLIT_INDEX_VALID_BLOCK = 0;
+ protected static final int RDD_ID = 1;
+ protected static final int SPLIT_INDEX_VALID_BLOCK = 0;
private static final int SPLIT_INDEX_MISSING_FILE = 1;
- private static final int SPLIT_INDEX_CORRUPT_LENGTH = 2;
- private static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3;
+ protected static final int SPLIT_INDEX_CORRUPT_LENGTH = 2;
+ protected static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3;
private static final int SPLIT_INDEX_MISSING_BLOCK_TO_RM = 4;
// Executor 0 is sort-based
@@ -86,8 +85,20 @@ public class ExternalShuffleIntegrationSuite {
new byte[54321],
};
+ private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) {
+ HashMap<String, String> config = new HashMap<>();
+ config.put("spark.shuffle.io.maxRetries", String.valueOf(maxRetries));
+ config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, String.valueOf(rddEnabled));
+ return new TransportConf("shuffle", new MapConfigProvider(config));
+ }
+
+ // This is split out so it can be invoked in a subclass with a different config
@BeforeAll
public static void beforeAll() throws IOException {
+ doBeforeAllWithConfig(createTransportConf(0, true));
+ }
+
+ public static void doBeforeAllWithConfig(TransportConf transportConf) throws IOException {
Random rand = new Random();
for (byte[] block : exec0Blocks) {
@@ -105,10 +116,7 @@ public class ExternalShuffleIntegrationSuite {
dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK, exec0RddBlockValid);
dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK_TO_RM, exec0RddBlockToRemove);
- HashMap<String, String> config = new HashMap<>();
- config.put("spark.shuffle.io.maxRetries", "0");
- config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, "true");
- conf = new TransportConf("shuffle", new MapConfigProvider(config));
+ conf = transportConf;
handler = new ExternalBlockHandler(
new OneForOneStreamManager(),
new ExternalShuffleBlockResolver(conf, null) {
@@ -319,8 +327,7 @@ public class ExternalShuffleIntegrationSuite {
@Test
public void testFetchNoServer() throws Exception {
- TransportConf clientConf = new TransportConf("shuffle",
- new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0")));
+ TransportConf clientConf = createTransportConf(0, false);
registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
FetchResult execFetch = fetchBlocks("exec-0",
new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port
*/);
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index b8beec303ae..76f82800c50 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -39,10 +39,19 @@ import org.apache.spark.network.util.TransportConf;
public class ExternalShuffleSecuritySuite {
- TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
+ TransportConf conf = createTransportConf(false);
TransportServer server;
TransportContext transportContext;
+ protected TransportConf createTransportConf(boolean encrypt) {
+ if (encrypt) {
+ return new TransportConf("shuffle", new MapConfigProvider(
+ ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
+ } else {
+ return new TransportConf("shuffle", MapConfigProvider.EMPTY);
+ }
+ }
+
@BeforeEach
public void beforeEach() throws IOException {
transportContext = new TransportContext(conf, new ExternalBlockHandler(conf, null));
@@ -92,8 +101,7 @@ public class ExternalShuffleSecuritySuite {
throws IOException, InterruptedException {
TransportConf testConf = conf;
if (encrypt) {
- testConf = new TransportConf("shuffle", new MapConfigProvider(
- ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
+ testConf = createTransportConf(encrypt);
}
try (ExternalBlockStoreClient client =
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
index 5484e8131a8..de164474766 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
@@ -60,13 +60,16 @@ public class ShuffleTransportContextSuite {
blockHandler = mock(ExternalBlockHandler.class);
}
- ShuffleTransportContext createShuffleTransportContext(boolean separateFinalizeThread)
- throws IOException {
+ protected TransportConf createTransportConf(boolean separateFinalizeThread) {
Map<String, String> configs = new HashMap<>();
configs.put("spark.shuffle.server.finalizeShuffleMergeThreadsPercent",
- separateFinalizeThread ? "1" : "0");
- TransportConf transportConf = new TransportConf("shuffle",
- new MapConfigProvider(configs));
+ separateFinalizeThread ? "1" : "0");
+ return new TransportConf("shuffle", new MapConfigProvider(configs));
+ }
+
+ ShuffleTransportContext createShuffleTransportContext(boolean separateFinalizeThread)
+ throws IOException {
+ TransportConf transportConf = createTransportConf(separateFinalizeThread);
return new ShuffleTransportContext(transportConf, blockHandler, true);
}
@@ -90,15 +93,17 @@ public class ShuffleTransportContextSuite {
public void testInitializePipeline() throws IOException {
// SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured
for (boolean enabled : new boolean[]{true, false}) {
- ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
- SocketChannel channel = new NioSocketChannel();
- RpcHandler rpcHandler = mock(RpcHandler.class);
- ctx.initializePipeline(channel, rpcHandler);
- String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
- if (enabled) {
- Assertions.assertNotNull(channel.pipeline().get(handlerName));
- } else {
- Assertions.assertNull(channel.pipeline().get(handlerName));
+ for (boolean client: new boolean[]{true, false}) {
+ ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
+ SocketChannel channel = new NioSocketChannel();
+ RpcHandler rpcHandler = mock(RpcHandler.class);
+ ctx.initializePipeline(channel, rpcHandler, client);
+ String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
+ if (enabled) {
+ Assertions.assertNotNull(channel.pipeline().get(handlerName));
+ } else {
+ Assertions.assertNull(channel.pipeline().get(handlerName));
+ }
}
}
}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java
new file mode 100644
index 00000000000..3591ccad150
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import org.junit.jupiter.api.BeforeAll;
+
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
+
+public class SslExternalShuffleIntegrationSuite extends ExternalShuffleIntegrationSuite {
+
+ private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) {
+ HashMap<String, String> config = new HashMap<>();
+ config.put("spark.shuffle.io.maxRetries", String.valueOf(maxRetries));
+ config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, String.valueOf(rddEnabled));
+ return new TransportConf(
+ "shuffle",
+ SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(config)
+ );
+ }
+
+ @BeforeAll
+ public static void beforeAll() throws IOException {
+ doBeforeAllWithConfig(createTransportConf(0, true));
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java
similarity index 50%
copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java
index 322d6bfdb7c..061d63dbcd7 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java
@@ -15,20 +15,31 @@
* limitations under the License.
*/
-package org.apache.spark.network.yarn
+package org.apache.spark.network.shuffle;
-import org.apache.spark.network.ssl.SslSampleConfigs
+import com.google.common.collect.ImmutableMap;
-class SslYarnShuffleServiceWithRocksDBBackendSuite
- extends YarnShuffleServiceWithRocksDBBackendSuite {
+import org.apache.spark.network.ssl.SslSampleConfigs;
+import org.apache.spark.network.util.TransportConf;
- /**
- * Override to add "spark.ssl.rpc.*" configuration parameters...
- */
- override def beforeEach(): Unit = {
- super.beforeEach()
- // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
- SslSampleConfigs.createDefaultConfigMap().entrySet().
- forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslExternalShuffleSecuritySuite extends ExternalShuffleSecuritySuite {
+
+ @Override
+ protected TransportConf createTransportConf(boolean encrypt) {
+ if (encrypt) {
+ return new TransportConf(
+ "shuffle",
+ SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(
+ ImmutableMap.of(
+ "spark.authenticate.enableSaslEncryption",
+ "true")
+ )
+ );
+ } else {
+ return new TransportConf(
+ "shuffle",
+ SslSampleConfigs.createDefaultConfigProviderForRpcNamespace()
+ );
+ }
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java
similarity index 55%
copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java
index 322d6bfdb7c..51463bbad55 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java
@@ -15,20 +15,24 @@
* limitations under the License.
*/
-package org.apache.spark.network.yarn
+package org.apache.spark.network.shuffle;
-import org.apache.spark.network.ssl.SslSampleConfigs
+import com.google.common.collect.ImmutableMap;
-class SslYarnShuffleServiceWithRocksDBBackendSuite
- extends YarnShuffleServiceWithRocksDBBackendSuite {
+import org.apache.spark.network.ssl.SslSampleConfigs;
+import org.apache.spark.network.util.TransportConf;
- /**
- * Override to add "spark.ssl.rpc.*" configuration parameters...
- */
- override def beforeEach(): Unit = {
- super.beforeEach()
- // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
- SslSampleConfigs.createDefaultConfigMap().entrySet().
- forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslShuffleTransportContextSuite extends ShuffleTransportContextSuite {
+
+ @Override
+ protected TransportConf createTransportConf(boolean separateFinalizeThread) {
+ return new TransportConf(
+ "shuffle",
+ SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(
+ ImmutableMap.of(
+ "spark.shuffle.server.finalizeShuffleMergeThreadsPercent",
+ separateFinalizeThread ? "1" : "0")
+ )
+ );
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
index 322d6bfdb7c..06b91faf44a 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
@@ -28,7 +28,7 @@ class SslYarnShuffleServiceWithRocksDBBackendSuite
override def beforeEach(): Unit = {
super.beforeEach()
// Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
- SslSampleConfigs.createDefaultConfigMap().entrySet().
+ SslSampleConfigs.createDefaultConfigMapForRpcNamespace().entrySet().
forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
}
}