This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2878cd84d5d [SPARK-41415] SASL Request Retries
2878cd84d5d is described below

commit 2878cd84d5d5edcaa9bbc642c12d8c7f6f009328
Author: Aravind Patnam <[email protected]>
AuthorDate: Sat Jan 14 23:58:56 2023 -0600

    [SPARK-41415] SASL Request Retries
    
    ### What changes were proposed in this pull request?
    
    Add the ability to retry SASL requests. A metric to track SASL retries will be added soon.
    
    ### Why are the changes needed?
    We are seeing increased SASL timeouts internally, and this change would mitigate them. We already have this feature enabled for our 2.3 jobs, and we have seen failures decrease significantly.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests, and tested on a cluster to ensure the retries are triggered correctly.
    
    Closes #38959 from akpatnam25/SPARK-41415.
    
    Authored-by: Aravind Patnam <[email protected]>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../spark/network/sasl/SaslClientBootstrap.java    |  14 ++-
 .../spark/network/sasl/SaslTimeoutException.java   |  35 ++++++
 .../apache/spark/network/util/TransportConf.java   |   7 ++
 .../network/shuffle/RetryingBlockTransferor.java   |  33 +++++-
 .../shuffle/RetryingBlockTransferorSuite.java      | 119 ++++++++++++++++++++-
 5 files changed, 200 insertions(+), 8 deletions(-)

diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 64781377229..69baaca8a26 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.sasl;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.TimeoutException;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 
@@ -65,9 +66,18 @@ public class SaslClientBootstrap implements 
TransportClientBootstrap {
         SaslMessage msg = new SaslMessage(appId, payload);
         ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) 
msg.body().size());
         msg.encode(buf);
+        ByteBuffer response;
         buf.writeBytes(msg.body().nioByteBuffer());
-
-        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), 
conf.authRTTimeoutMs());
+        try {
+          response = client.sendRpcSync(buf.nioBuffer(), 
conf.authRTTimeoutMs());
+        } catch (RuntimeException ex) {
+          // We know it is a Sasl timeout here if it is a TimeoutException.
+          if (ex.getCause() instanceof TimeoutException) {
+            throw new SaslTimeoutException(ex.getCause());
+          } else {
+            throw ex;
+          }
+        }
         payload = saslClient.response(JavaUtils.bufferToArray(response));
       }
 
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
new file mode 100644
index 00000000000..2533ae93f8d
--- /dev/null
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * An exception thrown if there is a SASL timeout.
+ */
+public class SaslTimeoutException extends RuntimeException {
+  public SaslTimeoutException(Throwable cause) {
+    super(cause);
+  }
+
+  public SaslTimeoutException(String message) {
+    super(message);
+  }
+
+  public SaslTimeoutException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index f2848c2d4c9..bbfb99168da 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -333,6 +333,13 @@ public class TransportConf {
     return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
   }
 
+  /** Whether to enable sasl retries or not. The number of retries is dictated 
by the config
+   * `spark.shuffle.io.maxRetries`.
+   */
+  public boolean enableSaslRetries() {
+    return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
+  }
+
   /**
    * Class name of the implementation of MergedShuffleFileManager that merges 
the blocks
    * pushed to it when push-based shuffle is enabled. By default, push-based 
shuffle is disabled at
diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
index 463edc770d2..4515e3a5c28 100644
--- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
+++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java
@@ -24,12 +24,14 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -85,6 +87,8 @@ public class RetryingBlockTransferor {
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;

+  private int saslRetryCount;  // portion of retryCount spent on SASL timeout retries
+
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
    * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +103,9 @@ public class RetryingBlockTransferor {
    */
   private RetryingBlockTransferListener currentListener;

+  /** Whether sasl retries are enabled. */
+  private final boolean enableSaslRetries;
+
   private final ErrorHandler errorHandler;

   public RetryingBlockTransferor(
@@ -115,6 +122,8 @@ public class RetryingBlockTransferor {
     Collections.addAll(outstandingBlocksIds, blockIds);
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
+    this.enableSaslRetries = conf.enableSaslRetries();
+    this.saslRetryCount = 0;
   }

   public RetryingBlockTransferor(
@@ -187,13 +196,29 @@ public class RetryingBlockTransferor {

   /**
    * Returns true if we should retry due a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already. SASL-timeout retries are forgiven after a non-SASL outcome.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
       || e.getCause() instanceof IOException;
+    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    if (!isSaslTimeout && saslRetryCount > 0) {
+      retryCount -= saslRetryCount;  // forgive SASL retries only; other retries still count
+      saslRetryCount = 0;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
-    return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
+    boolean shouldRetry = (isSaslTimeout || isIOException) &&
+        hasRemainingRetries && errorHandler.shouldRetryError(e);
+    if (shouldRetry && isSaslTimeout) {
+      this.saslRetryCount += 1;
+    }
+    return shouldRetry;
+  }
+
+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
   }

   /**
@@ -211,6 +236,10 @@ public class RetryingBlockTransferor {
         if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
           outstandingBlocksIds.remove(blockId);
           shouldForwardSuccess = true;
+          if (saslRetryCount > 0) {
+            retryCount -= saslRetryCount;  // success: forgive only SASL timeout retries
+            saslRetryCount = 0;
+          }
         }
       }

diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
index 985a7a36428..a33a471fb7a 100644
--- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
+++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java
@@ -20,13 +20,18 @@ package org.apache.spark.network.shuffle;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeoutException;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
+
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.stubbing.Answer;
 import org.mockito.stubbing.Stubber;
@@ -38,6 +43,7 @@ import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.sasl.SaslTimeoutException;
 import static 
org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;
 
 /**
@@ -49,6 +55,16 @@ public class RetryingBlockTransferorSuite {
   private final ManagedBuffer block0 = new 
NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
   private final ManagedBuffer block1 = new 
NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
   private final ManagedBuffer block2 = new 
NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+  private static Map<String, String> configMap;
+  private static RetryingBlockTransferor _retryingBlockTransferor;
+
+  @Before
+  public void initMap() {
+    configMap = new HashMap<String, String>() {{
+      put("spark.shuffle.io.maxRetries", "2");
+      put("spark.shuffle.io.retryWait", "0");
+    }};
+  }
 
   @Test
   public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +246,101 @@ public class RetryingBlockTransferorSuite {
     verifyNoMoreInteractions(listener);
   }
 
+  @Test
+  public void testSaslTimeoutFailure() throws IOException, 
InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+        new SaslTimeoutException(timeoutException);
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+        ImmutableMap.<String, Object>builder()
+            .put("b0", saslTimeoutException)
+            .build(),
+        ImmutableMap.<String, Object>builder()
+            .put("b0", block0)
+            .build()
+    );
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", 
saslTimeoutException);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testRetryOnSaslTimeout() throws IOException, 
InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+        // SaslTimeout will cause a retry. Since b0 fails, we will retry both.
+        ImmutableMap.<String, Object>builder()
+            .put("b0", new SaslTimeoutException(new TimeoutException()))
+            .build(),
+        ImmutableMap.<String, Object>builder()
+            .put("b0", block0)
+            .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
+  }
+
+  @Test
+  public void testRepeatedSaslRetryFailures() throws IOException, 
InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+        new SaslTimeoutException(timeoutException);
+    List<ImmutableMap<String, Object>> interactions = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      interactions.add(
+          ImmutableMap.<String, Object>builder()
+              .put("b0", saslTimeoutException)
+              .build()
+      );
+    }
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", 
saslTimeoutException);
+    verify(listener, times(3)).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 2);
+  }
+
+  @Test
+  public void testBlockTransferFailureAfterSasl() throws IOException, 
InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+        ImmutableMap.<String, Object>builder()
+            .put("b0", new SaslTimeoutException(new TimeoutException()))
+            .put("b1", new IOException())
+            .build(),
+        ImmutableMap.<String, Object>builder()
+            .put("b0", block0)
+            .put("b1", new IOException())
+            .build(),
+        ImmutableMap.<String, Object>builder()
+          .put("b1", block1)
+          .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    // This should be equal to 1 because after the SASL exception is retried,
+    // retryCount should be set back to 0. Then after that b1 encounters an
+    // exception that is retried.
+    assert(_retryingBlockTransferor.getRetryCount() == 1);
+  }
+
   /**
    * Performs a set of interactions in response to block requests from a 
RetryingBlockFetcher.
    * Each interaction is a Map from BlockId to either ManagedBuffer or 
Exception. This interaction
@@ -244,9 +355,7 @@ public class RetryingBlockTransferorSuite {
                                           BlockFetchingListener listener)
     throws IOException, InterruptedException {
 
-    MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
-      "spark.shuffle.io.maxRetries", "2",
-      "spark.shuffle.io.retryWait", "0"));
+    MapConfigProvider provider = new MapConfigProvider(configMap);
     TransportConf conf = new TransportConf("shuffle", provider);
     BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);
 
@@ -298,6 +407,8 @@ public class RetryingBlockTransferorSuite {
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, 
listener).start();
+    _retryingBlockTransferor =
+        new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, 
listener);
+    _retryingBlockTransferor.start();
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to