This is an automated email from the ASF dual-hosted git repository.

yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f65f845eed Rewrite FailureDetector interface and implementations to 
also work with the multi-stage engine (#15005)
f65f845eed is described below

commit f65f845eedca3e8502ec4add3e640de3a0d849ee
Author: Yash Mayya <yash.ma...@gmail.com>
AuthorDate: Fri Feb 14 12:51:36 2025 +0530

    Rewrite FailureDetector interface and implementations to also work with the 
multi-stage engine (#15005)
---
 .../broker/broker/helix/BaseBrokerStarter.java     |  20 +-
 .../requesthandler/GrpcBrokerRequestHandler.java   |  65 ++++--
 .../MultiStageBrokerRequestHandler.java            |  31 ++-
 .../SingleConnectionBrokerRequestHandler.java      |  40 ++--
 .../pinot/broker/routing/BrokerRoutingManager.java |   1 +
 .../ConnectionFailureDetectorTest.java             | 157 ---------------
 .../LiteralOnlyBrokerRequestTest.java              |   7 +-
 ...BaseExponentialBackoffRetryFailureDetector.java |  59 ++++--
 .../failuredetector/ConnectionFailureDetector.java |  19 +-
 .../common}/failuredetector/FailureDetector.java   |  43 ++--
 .../failuredetector/FailureDetectorFactory.java    |   2 +-
 .../failuredetector/NoOpFailureDetector.java       |  13 +-
 .../pinot/common/utils/grpc/GrpcQueryClient.java   |   4 +
 .../ConnectionFailureDetectorTest.java             | 224 +++++++++++++++++++++
 .../pinot/query/routing/QueryServerInstance.java   |  10 +-
 .../apache/pinot/query/routing/WorkerManager.java  |   6 +-
 .../pinot/query/QueryEnvironmentTestBase.java      |   2 +-
 .../query/service/dispatch/QueryDispatcher.java    |  42 +++-
 .../runtime/queries/QueryRunnerAccountingTest.java |   4 +-
 .../query/runtime/queries/QueryRunnerTest.java     |   4 +-
 .../runtime/queries/ResourceBasedQueriesTest.java  |   4 +-
 .../apache/pinot/spi/utils/CommonConstants.java    |   2 +-
 22 files changed, 458 insertions(+), 301 deletions(-)

diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
index e134d65b75..35e10361fc 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
@@ -58,6 +58,8 @@ import org.apache.pinot.common.config.NettyConfig;
 import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.common.cursors.AbstractResponseStore;
+import org.apache.pinot.common.failuredetector.FailureDetector;
+import org.apache.pinot.common.failuredetector.FailureDetectorFactory;
 import org.apache.pinot.common.function.FunctionRegistry;
 import org.apache.pinot.common.metadata.ZKMetadataProvider;
 import org.apache.pinot.common.metrics.BrokerGauge;
@@ -144,6 +146,7 @@ public abstract class BaseBrokerStarter implements 
ServiceStartable {
   protected HelixExternalViewBasedQueryQuotaManager _queryQuotaManager;
   protected MultiStageQueryThrottler _multiStageQueryThrottler;
   protected AbstractResponseStore _responseStore;
+  protected FailureDetector _failureDetector;
 
   @Override
   public void init(PinotConfiguration brokerConf)
@@ -319,6 +322,14 @@ public abstract class BaseBrokerStarter implements 
ServiceStartable {
     LOGGER.info("Initializing Broker Event Listener Factory");
     
BrokerQueryEventListenerFactory.init(_brokerConf.subset(Broker.EVENT_LISTENER_CONFIG_PREFIX));
 
+    // Initialize the failure detector that removes servers from the broker 
routing table if they are not healthy
+    _failureDetector = FailureDetectorFactory.getFailureDetector(_brokerConf, 
_brokerMetrics);
+    _failureDetector.registerHealthyServerNotifier(
+        instanceId -> _routingManager.includeServerToRouting(instanceId));
+    _failureDetector.registerUnhealthyServerNotifier(
+        instanceId -> _routingManager.excludeServerFromRouting(instanceId));
+    _failureDetector.start();
+
     // Create Broker request handler.
     String brokerId = _brokerConf.getProperty(Broker.CONFIG_OF_BROKER_ID, 
getDefaultBrokerId());
     String brokerRequestHandlerType =
@@ -326,7 +337,7 @@ public abstract class BaseBrokerStarter implements 
ServiceStartable {
     BaseSingleStageBrokerRequestHandler singleStageBrokerRequestHandler;
     if 
(brokerRequestHandlerType.equalsIgnoreCase(Broker.GRPC_BROKER_REQUEST_HANDLER_TYPE))
 {
       singleStageBrokerRequestHandler = new 
GrpcBrokerRequestHandler(_brokerConf, brokerId, _routingManager,
-          _accessControlFactory, _queryQuotaManager, tableCache);
+          _accessControlFactory, _queryQuotaManager, tableCache, 
_failureDetector);
     } else {
       // Default request handler type, i.e. netty
       NettyConfig nettyDefaults = NettyConfig.extractNettyConfig(_brokerConf, 
Broker.BROKER_NETTY_PREFIX);
@@ -337,7 +348,8 @@ public abstract class BaseBrokerStarter implements 
ServiceStartable {
       }
       singleStageBrokerRequestHandler =
           new SingleConnectionBrokerRequestHandler(_brokerConf, brokerId, 
_routingManager, _accessControlFactory,
-              _queryQuotaManager, tableCache, nettyDefaults, tlsDefaults, 
_serverRoutingStatsManager);
+              _queryQuotaManager, tableCache, nettyDefaults, tlsDefaults, 
_serverRoutingStatsManager,
+              _failureDetector);
     }
     MultiStageBrokerRequestHandler multiStageBrokerRequestHandler = null;
     QueryDispatcher queryDispatcher = null;
@@ -350,7 +362,7 @@ public abstract class BaseBrokerStarter implements 
ServiceStartable {
       queryDispatcher = createQueryDispatcher(_brokerConf);
       multiStageBrokerRequestHandler =
           new MultiStageBrokerRequestHandler(_brokerConf, brokerId, 
_routingManager, _accessControlFactory,
-              _queryQuotaManager, tableCache, _multiStageQueryThrottler);
+              _queryQuotaManager, tableCache, _multiStageQueryThrottler, 
_failureDetector);
     }
     TimeSeriesRequestHandler timeSeriesRequestHandler = null;
     if 
(StringUtils.isNotBlank(_brokerConf.getProperty(PinotTimeSeriesConfiguration.getEnabledLanguagesConfigKey())))
 {
@@ -613,6 +625,8 @@ public abstract class BaseBrokerStarter implements 
ServiceStartable {
     LOGGER.info("Stopping cluster change mediator");
     _clusterChangeMediator.stop();
 
+    _failureDetector.stop();
+
     // Delay shutdown of request handler so that the pending queries can be 
finished. The participant Helix manager has
     // been disconnected, so instance should disappear from ExternalView soon 
and stop getting new queries.
     long delayShutdownTimeMs =
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/GrpcBrokerRequestHandler.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/GrpcBrokerRequestHandler.java
index d6c2f3aacc..eb963b1ceb 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/GrpcBrokerRequestHandler.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/GrpcBrokerRequestHandler.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.broker.requesthandler;
 
+import io.grpc.ConnectivityState;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -31,6 +32,7 @@ import org.apache.pinot.broker.queryquota.QueryQuotaManager;
 import org.apache.pinot.broker.routing.BrokerRoutingManager;
 import org.apache.pinot.common.config.GrpcConfig;
 import org.apache.pinot.common.config.provider.TableCache;
+import org.apache.pinot.common.failuredetector.FailureDetector;
 import org.apache.pinot.common.proto.Server;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.response.broker.BrokerResponseNative;
@@ -43,6 +45,8 @@ import org.apache.pinot.core.transport.ServerRoutingInstance;
 import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.trace.RequestContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -50,15 +54,21 @@ import org.apache.pinot.spi.trace.RequestContext;
  */
 @ThreadSafe
 public class GrpcBrokerRequestHandler extends 
BaseSingleStageBrokerRequestHandler {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(GrpcBrokerRequestHandler.class);
+
   private final StreamingReduceService _streamingReduceService;
   private final PinotStreamingQueryClient _streamingQueryClient;
+  private final FailureDetector _failureDetector;
 
   // TODO: Support TLS
   public GrpcBrokerRequestHandler(PinotConfiguration config, String brokerId, 
BrokerRoutingManager routingManager,
-      AccessControlFactory accessControlFactory, QueryQuotaManager 
queryQuotaManager, TableCache tableCache) {
+      AccessControlFactory accessControlFactory, QueryQuotaManager 
queryQuotaManager, TableCache tableCache,
+      FailureDetector failureDetector) {
     super(config, brokerId, routingManager, accessControlFactory, 
queryQuotaManager, tableCache);
     _streamingReduceService = new StreamingReduceService(config);
     _streamingQueryClient = new 
PinotStreamingQueryClient(GrpcConfig.buildGrpcQueryConfig(config));
+    _failureDetector = failureDetector;
+    
_failureDetector.registerUnhealthyServerRetrier(this::retryUnhealthyServer);
   }
 
   @Override
@@ -81,7 +91,6 @@ public class GrpcBrokerRequestHandler extends 
BaseSingleStageBrokerRequestHandle
       @Nullable Map<ServerInstance, ServerRouteInfo> realtimeRoutingTable, 
long timeoutMs,
       ServerStats serverStats, RequestContext requestContext)
       throws Exception {
-    // TODO: Support failure detection
     // TODO: Add servers queried/responded stats
     assert offlineBrokerRequest != null || realtimeBrokerRequest != null;
     Map<ServerRoutingInstance, Iterator<Server.ServerResponse>> responseMap = 
new HashMap<>();
@@ -112,33 +121,39 @@ public class GrpcBrokerRequestHandler extends 
BaseSingleStageBrokerRequestHandle
       ServerInstance serverInstance = routingEntry.getKey();
       // TODO: support optional segments for GrpcQueryServer.
       List<String> segments = routingEntry.getValue().getSegments();
-      String serverHost = serverInstance.getHostname();
-      int port = serverInstance.getGrpcPort();
       // TODO: enable throttling on per host bases.
-      Iterator<Server.ServerResponse> streamingResponse = 
_streamingQueryClient.submit(serverHost, port,
-          new 
GrpcRequestBuilder().setRequestId(requestId).setBrokerId(_brokerId).setEnableTrace(trace)
-              
.setEnableStreaming(true).setBrokerRequest(brokerRequest).setSegments(segments).build());
-      responseMap.put(serverInstance.toServerRoutingInstance(tableType, 
ServerInstance.RoutingType.GRPC),
-          streamingResponse);
+      try {
+        Iterator<Server.ServerResponse> streamingResponse = 
_streamingQueryClient.submit(serverInstance,
+            new 
GrpcRequestBuilder().setRequestId(requestId).setBrokerId(_brokerId).setEnableTrace(trace)
+                
.setEnableStreaming(true).setBrokerRequest(brokerRequest).setSegments(segments).build());
+        responseMap.put(serverInstance.toServerRoutingInstance(tableType, 
ServerInstance.RoutingType.GRPC),
+            streamingResponse);
+      } catch (Exception e) {
+        LOGGER.warn("Failed to send request {} to server: {}", requestId, 
serverInstance.getInstanceId(), e);
+        _failureDetector.markServerUnhealthy(serverInstance.getInstanceId());
+      }
     }
   }
 
   public static class PinotStreamingQueryClient {
     private final Map<String, GrpcQueryClient> _grpcQueryClientMap = new 
ConcurrentHashMap<>();
+    private final Map<String, String> _instanceIdToHostnamePortMap = new 
ConcurrentHashMap<>();
     private final GrpcConfig _config;
 
     public PinotStreamingQueryClient(GrpcConfig config) {
       _config = config;
     }
 
-    public Iterator<Server.ServerResponse> submit(String host, int port, 
Server.ServerRequest serverRequest) {
-      GrpcQueryClient client = getOrCreateGrpcQueryClient(host, port);
+    public Iterator<Server.ServerResponse> submit(ServerInstance 
serverInstance, Server.ServerRequest serverRequest) {
+      GrpcQueryClient client = getOrCreateGrpcQueryClient(serverInstance);
       return client.submit(serverRequest);
     }
 
-    private GrpcQueryClient getOrCreateGrpcQueryClient(String host, int port) {
-      String key = String.format("%s_%d", host, port);
-      return _grpcQueryClientMap.computeIfAbsent(key, k -> new 
GrpcQueryClient(host, port, _config));
+    private GrpcQueryClient getOrCreateGrpcQueryClient(ServerInstance 
serverInstance) {
+      String hostnamePort = String.format("%s_%d", 
serverInstance.getHostname(), serverInstance.getGrpcPort());
+      _instanceIdToHostnamePortMap.put(serverInstance.getInstanceId(), 
hostnamePort);
+      return _grpcQueryClientMap.computeIfAbsent(hostnamePort,
+          k -> new GrpcQueryClient(serverInstance.getHostname(), 
serverInstance.getGrpcPort(), _config));
     }
 
     public void shutdown() {
@@ -147,4 +162,26 @@ public class GrpcBrokerRequestHandler extends 
BaseSingleStageBrokerRequestHandle
       }
     }
   }
+
+  /**
+   * Check if a server that was previously detected as unhealthy is now 
healthy.
+   */
+  private boolean retryUnhealthyServer(String instanceId) {
+    LOGGER.info("Checking gRPC connection to unhealthy server: {}", 
instanceId);
+    ServerInstance serverInstance = 
_routingManager.getEnabledServerInstanceMap().get(instanceId);
+    if (serverInstance == null) {
+      LOGGER.info("Failed to find enabled server: {} in routing manager, 
skipping the retry", instanceId);
+      return false;
+    }
+
+    String hostnamePort = 
_streamingQueryClient._instanceIdToHostnamePortMap.get(instanceId);
+    GrpcQueryClient client = 
_streamingQueryClient._grpcQueryClientMap.get(hostnamePort);
+
+    if (client == null) {
+      LOGGER.warn("No GrpcQueryClient found for server with instanceId: {}", 
instanceId);
+      return false;
+    }
+
+    return client.getChannel().getState(true) == ConnectivityState.READY;
+  }
 }
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 2a20f35e90..cbbdf8cb19 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -51,6 +51,7 @@ import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.common.exception.QueryException;
 import org.apache.pinot.common.exception.QueryInfoException;
+import org.apache.pinot.common.failuredetector.FailureDetector;
 import org.apache.pinot.common.metrics.BrokerMeter;
 import org.apache.pinot.common.metrics.BrokerQueryPhase;
 import org.apache.pinot.common.response.BrokerResponse;
@@ -67,6 +68,7 @@ import org.apache.pinot.common.utils.config.QueryOptionsUtils;
 import org.apache.pinot.common.utils.tls.TlsUtils;
 import org.apache.pinot.core.auth.Actions;
 import org.apache.pinot.core.auth.TargetType;
+import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.explain.AskingServerStageExplainer;
@@ -91,6 +93,10 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 
+/**
+ * This class serves as the broker entry-point for handling incoming 
multi-stage query requests and dispatching them
+ * to servers.
+ */
 public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(MultiStageBrokerRequestHandler.class);
 
@@ -104,17 +110,20 @@ public class MultiStageBrokerRequestHandler extends 
BaseBrokerRequestHandler {
 
   public MultiStageBrokerRequestHandler(PinotConfiguration config, String 
brokerId, BrokerRoutingManager routingManager,
       AccessControlFactory accessControlFactory, QueryQuotaManager 
queryQuotaManager, TableCache tableCache,
-      MultiStageQueryThrottler queryThrottler) {
+      MultiStageQueryThrottler queryThrottler, FailureDetector 
failureDetector) {
     super(config, brokerId, routingManager, accessControlFactory, 
queryQuotaManager, tableCache);
     String hostname = 
config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME);
     int port = 
Integer.parseInt(config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT));
-    _workerManager = new WorkerManager(hostname, port, _routingManager);
+    _workerManager = new WorkerManager(_brokerId, hostname, port, 
_routingManager);
     TlsConfig tlsConfig = config.getProperty(
         CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_TLS_ENABLED,
         CommonConstants.Helix.DEFAULT_MULTI_STAGE_ENGINE_TLS_ENABLED) ? 
TlsUtils.extractTlsConfig(config,
         CommonConstants.Broker.BROKER_TLS_PREFIX) : null;
-    _queryDispatcher = new QueryDispatcher(
-        new MailboxService(hostname, port, config, tlsConfig), tlsConfig, 
this.isQueryCancellationEnabled());
+
+    failureDetector.registerUnhealthyServerRetrier(this::retryUnhealthyServer);
+    _queryDispatcher =
+        new QueryDispatcher(new MailboxService(hostname, port, config, 
tlsConfig), tlsConfig, failureDetector,
+            this.isQueryCancellationEnabled());
     LOGGER.info("Initialized MultiStageBrokerRequestHandler on host: {}, port: 
{} with broker id: {}, timeout: {}ms, "
             + "query log max length: {}, query log max rate: {}, query 
cancellation enabled: {}", hostname, port,
         _brokerId, _brokerTimeoutMs, _queryLogger.getMaxQueryLengthToLog(), 
_queryLogger.getLogRateLimit(),
@@ -538,4 +547,18 @@ public class MultiStageBrokerRequestHandler extends 
BaseBrokerRequestHandler {
     return setOfStrings.stream().limit(limit)
         .collect(Collectors.joining(", ", "[", setOfStrings.size() > limit ? 
"...]" : "]"));
   }
+
+  /**
+   * Check if a server that was previously detected as unhealthy is now 
healthy.
+   */
+  public boolean retryUnhealthyServer(String instanceId) {
+    LOGGER.info("Checking gRPC connection to unhealthy server: {}", 
instanceId);
+    ServerInstance serverInstance = 
_routingManager.getEnabledServerInstanceMap().get(instanceId);
+    if (serverInstance == null) {
+      LOGGER.info("Failed to find enabled server: {} in routing manager, 
skipping the retry", instanceId);
+      return false;
+    }
+
+    return _queryDispatcher.checkConnectivityToInstance(instanceId);
+  }
 }
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/SingleConnectionBrokerRequestHandler.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/SingleConnectionBrokerRequestHandler.java
index 6e8ecff9a5..28143565a1 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/SingleConnectionBrokerRequestHandler.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/SingleConnectionBrokerRequestHandler.java
@@ -26,8 +26,6 @@ import java.util.concurrent.TimeUnit;
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.pinot.broker.broker.AccessControlFactory;
-import org.apache.pinot.broker.failuredetector.FailureDetector;
-import org.apache.pinot.broker.failuredetector.FailureDetectorFactory;
 import org.apache.pinot.broker.queryquota.QueryQuotaManager;
 import org.apache.pinot.broker.routing.BrokerRoutingManager;
 import org.apache.pinot.common.config.NettyConfig;
@@ -35,6 +33,7 @@ import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.exception.QueryException;
+import org.apache.pinot.common.failuredetector.FailureDetector;
 import org.apache.pinot.common.metrics.BrokerMeter;
 import org.apache.pinot.common.metrics.BrokerQueryPhase;
 import org.apache.pinot.common.request.BrokerRequest;
@@ -63,8 +62,7 @@ import org.slf4j.LoggerFactory;
  * connection per server to route the queries.
  */
 @ThreadSafe
-public class SingleConnectionBrokerRequestHandler extends 
BaseSingleStageBrokerRequestHandler
-    implements FailureDetector.Listener {
+public class SingleConnectionBrokerRequestHandler extends 
BaseSingleStageBrokerRequestHandler {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(SingleConnectionBrokerRequestHandler.class);
 
   private final BrokerReduceService _brokerReduceService;
@@ -74,24 +72,22 @@ public class SingleConnectionBrokerRequestHandler extends 
BaseSingleStageBrokerR
   public SingleConnectionBrokerRequestHandler(PinotConfiguration config, 
String brokerId,
       BrokerRoutingManager routingManager, AccessControlFactory 
accessControlFactory,
       QueryQuotaManager queryQuotaManager, TableCache tableCache, NettyConfig 
nettyConfig, TlsConfig tlsConfig,
-      ServerRoutingStatsManager serverRoutingStatsManager) {
+      ServerRoutingStatsManager serverRoutingStatsManager, FailureDetector 
failureDetector) {
     super(config, brokerId, routingManager, accessControlFactory, 
queryQuotaManager, tableCache);
     _brokerReduceService = new BrokerReduceService(_config);
     _queryRouter = new QueryRouter(_brokerId, _brokerMetrics, nettyConfig, 
tlsConfig, serverRoutingStatsManager);
-    _failureDetector = FailureDetectorFactory.getFailureDetector(config, 
_brokerMetrics);
+    _failureDetector = failureDetector;
+    
_failureDetector.registerUnhealthyServerRetrier(this::retryUnhealthyServer);
   }
 
   @Override
   public void start() {
     super.start();
-    _failureDetector.register(this);
-    _failureDetector.start();
   }
 
   @Override
   public void shutDown() {
     super.shutDown();
-    _failureDetector.stop();
     _queryRouter.shutDown();
     _brokerReduceService.shutDown();
   }
@@ -114,14 +110,15 @@ public class SingleConnectionBrokerRequestHandler extends 
BaseSingleStageBrokerR
     AsyncQueryResponse asyncQueryResponse =
         _queryRouter.submitQuery(requestId, rawTableName, 
offlineBrokerRequest, offlineRoutingTable,
             realtimeBrokerRequest, realtimeRoutingTable, timeoutMs);
-    _failureDetector.notifyQuerySubmitted(asyncQueryResponse);
     Map<ServerRoutingInstance, ServerResponse> finalResponses = 
asyncQueryResponse.getFinalResponses();
     if (asyncQueryResponse.getStatus() == QueryResponse.Status.TIMED_OUT) {
       BrokerMeter meter = 
QueryOptionsUtils.isSecondaryWorkload(serverBrokerRequest.getPinotQuery().getQueryOptions())
           ? BrokerMeter.SECONDARY_WORKLOAD_BROKER_RESPONSES_WITH_TIMEOUTS : 
BrokerMeter.BROKER_RESPONSES_WITH_TIMEOUTS;
       _brokerMetrics.addMeteredTableValue(rawTableName, meter, 1);
     }
-    _failureDetector.notifyQueryFinished(asyncQueryResponse);
+    if (asyncQueryResponse.getFailedServer() != null) {
+      
_failureDetector.markServerUnhealthy(asyncQueryResponse.getFailedServer().getInstanceId());
+    }
     _brokerMetrics.addPhaseTiming(rawTableName, 
BrokerQueryPhase.SCATTER_GATHER,
         System.nanoTime() - scatterGatherStartTimeNs);
     // TODO Use scatterGatherStats as serverStats
@@ -179,29 +176,22 @@ public class SingleConnectionBrokerRequestHandler extends 
BaseSingleStageBrokerR
     return brokerResponse;
   }
 
-  @Override
-  public void notifyUnhealthyServer(String instanceId, FailureDetector 
failureDetector) {
-    _routingManager.excludeServerFromRouting(instanceId);
-  }
-
-  @Override
-  public void retryUnhealthyServer(String instanceId, FailureDetector 
failureDetector) {
+  /**
+   * Check if a server that was previously detected as unhealthy is now 
healthy.
+   */
+  public boolean retryUnhealthyServer(String instanceId) {
     LOGGER.info("Retrying unhealthy server: {}", instanceId);
     ServerInstance serverInstance = 
_routingManager.getEnabledServerInstanceMap().get(instanceId);
     if (serverInstance == null) {
       LOGGER.info("Failed to find enabled server: {} in routing manager, 
skipping the retry", instanceId);
-      return;
+      return false;
     }
     if (_queryRouter.connect(serverInstance)) {
       LOGGER.info("Successfully connect to server: {}, marking it healthy", 
instanceId);
-      failureDetector.markServerHealthy(instanceId);
+      return true;
     } else {
       LOGGER.warn("Still cannot connect to server: {}, retry later", 
instanceId);
+      return false;
     }
   }
-
-  @Override
-  public void notifyHealthyServer(String instanceId, FailureDetector 
failureDetector) {
-    _routingManager.includeServerToRouting(instanceId);
-  }
 }
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
index 6690c2dc8b..c81b4a6e09 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
@@ -388,6 +388,7 @@ public class BrokerRoutingManager implements 
RoutingManager, ClusterChangeHandle
     LOGGER.info("Including server: {} to routing", instanceId);
     if (!_excludedServers.remove(instanceId)) {
       LOGGER.info("Server: {} is not previously excluded, skipping updating 
the routing", instanceId);
+      return;
     }
     if (!_enabledServerInstanceMap.containsKey(instanceId)) {
       LOGGER.info("Server: {} is not enabled, skipping updating the routing", 
instanceId);
diff --git 
a/pinot-broker/src/test/java/org/apache/pinot/broker/failuredetector/ConnectionFailureDetectorTest.java
 
b/pinot-broker/src/test/java/org/apache/pinot/broker/failuredetector/ConnectionFailureDetectorTest.java
deleted file mode 100644
index 859b834f0c..0000000000
--- 
a/pinot-broker/src/test/java/org/apache/pinot/broker/failuredetector/ConnectionFailureDetectorTest.java
+++ /dev/null
@@ -1,157 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.broker.failuredetector;
-
-import java.util.Collections;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.pinot.common.metrics.BrokerGauge;
-import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.common.metrics.MetricValueUtils;
-import org.apache.pinot.core.transport.QueryResponse;
-import org.apache.pinot.core.transport.ServerRoutingInstance;
-import org.apache.pinot.spi.config.table.TableType;
-import org.apache.pinot.spi.env.PinotConfiguration;
-import org.apache.pinot.spi.metrics.PinotMetricUtils;
-import org.apache.pinot.spi.utils.CommonConstants.Broker;
-import org.apache.pinot.util.TestUtils;
-import org.testng.annotations.AfterClass;
-import org.testng.annotations.BeforeClass;
-import org.testng.annotations.Test;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-import static org.testng.Assert.assertEquals;
-import static org.testng.Assert.assertTrue;
-
-
-public class ConnectionFailureDetectorTest {
-  private static final String INSTANCE_ID = "Server_localhost_1234";
-
-  private BrokerMetrics _brokerMetrics;
-  private FailureDetector _failureDetector;
-  private Listener _listener;
-
-  @BeforeClass
-  public void setUp() {
-    PinotConfiguration config = new PinotConfiguration();
-    config.setProperty(Broker.FailureDetector.CONFIG_OF_TYPE, 
Broker.FailureDetector.Type.CONNECTION.name());
-    
config.setProperty(Broker.FailureDetector.CONFIG_OF_RETRY_INITIAL_DELAY_MS, 
100);
-    config.setProperty(Broker.FailureDetector.CONFIG_OF_RETRY_DELAY_FACTOR, 1);
-    _brokerMetrics = new 
BrokerMetrics(PinotMetricUtils.getPinotMetricsRegistry());
-    _failureDetector = FailureDetectorFactory.getFailureDetector(config, 
_brokerMetrics);
-    assertTrue(_failureDetector instanceof ConnectionFailureDetector);
-    _listener = new Listener();
-    _failureDetector.register(_listener);
-    _failureDetector.start();
-  }
-
-  @Test
-  public void testConnectionFailure() {
-    QueryResponse queryResponse = mock(QueryResponse.class);
-    when(queryResponse.getFailedServer()).thenReturn(new 
ServerRoutingInstance("localhost", 1234, TableType.OFFLINE));
-
-    // No failure detection when submitting the query
-    _failureDetector.notifyQuerySubmitted(queryResponse);
-    verify(Collections.emptySet(), 0, 0);
-
-    // When query finishes, the failed server should be count as unhealthy and 
trigger a callback
-    _failureDetector.notifyQueryFinished(queryResponse);
-    verify(Collections.singleton(INSTANCE_ID), 1, 0);
-
-    // Mark server unhealthy again should have no effect
-    _failureDetector.markServerUnhealthy(INSTANCE_ID);
-    verify(Collections.singleton(INSTANCE_ID), 1, 0);
-
-    // Mark server healthy should remove it from the unhealthy servers and 
trigger a callback
-    _failureDetector.markServerHealthy(INSTANCE_ID);
-    verify(Collections.emptySet(), 1, 1);
-
-    _listener.reset();
-  }
-
-  @Test
-  public void testRetry() {
-    _failureDetector.markServerUnhealthy(INSTANCE_ID);
-    verify(Collections.singleton(INSTANCE_ID), 1, 0);
-
-    // Should get 10 retries in 1s, then remove the failed server from the 
unhealthy servers.
-    // Wait for up to 5s to avoid flakiness
-    TestUtils.waitForCondition(aVoid -> {
-      int numRetries = _listener._retryUnhealthyServerCalled.get();
-      if (numRetries < Broker.FailureDetector.DEFAULT_MAX_RETIRES) {
-        assertEquals(_failureDetector.getUnhealthyServers(), 
Collections.singleton(INSTANCE_ID));
-        assertEquals(MetricValueUtils.getGlobalGaugeValue(_brokerMetrics, 
BrokerGauge.UNHEALTHY_SERVERS), 1);
-        return false;
-      }
-      assertEquals(numRetries, Broker.FailureDetector.DEFAULT_MAX_RETIRES);
-      // There might be a small delay between the last retry and removing 
failed server from the unhealthy servers.
-      // Perform a check instead of an assertion.
-      return _failureDetector.getUnhealthyServers().isEmpty()
-          && MetricValueUtils.getGaugeValue(_brokerMetrics, 
BrokerGauge.UNHEALTHY_SERVERS.getGaugeName()) == 0
-          && _listener._notifyUnhealthyServerCalled.get() == 1 && 
_listener._notifyHealthyServerCalled.get() == 1;
-    }, 5_000L, "Failed to get 10 retires");
-
-    _listener.reset();
-  }
-
-  private void verify(Set<String> expectedUnhealthyServers, int 
expectedNotifyUnhealthyServerCalled,
-      int expectedNotifyHealthyServerCalled) {
-    assertEquals(_failureDetector.getUnhealthyServers(), 
expectedUnhealthyServers);
-    assertEquals(MetricValueUtils.getGlobalGaugeValue(_brokerMetrics, 
BrokerGauge.UNHEALTHY_SERVERS),
-        expectedUnhealthyServers.size());
-    assertEquals(_listener._notifyUnhealthyServerCalled.get(), 
expectedNotifyUnhealthyServerCalled);
-    assertEquals(_listener._notifyHealthyServerCalled.get(), 
expectedNotifyHealthyServerCalled);
-  }
-
-  @AfterClass
-  public void tearDown() {
-    _failureDetector.stop();
-  }
-
-  private static class Listener implements FailureDetector.Listener {
-    final AtomicInteger _notifyUnhealthyServerCalled = new AtomicInteger();
-    final AtomicInteger _retryUnhealthyServerCalled = new AtomicInteger();
-    final AtomicInteger _notifyHealthyServerCalled = new AtomicInteger();
-
-    @Override
-    public void notifyUnhealthyServer(String instanceId, FailureDetector 
failureDetector) {
-      assertEquals(instanceId, INSTANCE_ID);
-      _notifyUnhealthyServerCalled.getAndIncrement();
-    }
-
-    @Override
-    public void retryUnhealthyServer(String instanceId, FailureDetector 
failureDetector) {
-      assertEquals(instanceId, INSTANCE_ID);
-      _retryUnhealthyServerCalled.getAndIncrement();
-    }
-
-    @Override
-    public void notifyHealthyServer(String instanceId, FailureDetector 
failureDetector) {
-      assertEquals(instanceId, INSTANCE_ID);
-      _notifyHealthyServerCalled.getAndIncrement();
-    }
-
-    void reset() {
-      _notifyUnhealthyServerCalled.set(0);
-      _retryUnhealthyServerCalled.set(0);
-      _notifyHealthyServerCalled.set(0);
-    }
-  }
-}
diff --git 
a/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/LiteralOnlyBrokerRequestTest.java
 
b/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/LiteralOnlyBrokerRequestTest.java
index 0b68d2b843..a4bcf99bb9 100644
--- 
a/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/LiteralOnlyBrokerRequestTest.java
+++ 
b/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/LiteralOnlyBrokerRequestTest.java
@@ -23,6 +23,7 @@ import java.util.Random;
 import java.util.concurrent.TimeUnit;
 import org.apache.pinot.broker.broker.AccessControlFactory;
 import org.apache.pinot.broker.broker.AllowAllAccessControlFactory;
+import org.apache.pinot.common.failuredetector.FailureDetector;
 import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.common.response.BrokerResponse;
 import org.apache.pinot.common.response.broker.ResultTable;
@@ -169,7 +170,7 @@ public class LiteralOnlyBrokerRequestTest {
       throws Exception {
     SingleConnectionBrokerRequestHandler requestHandler =
         new SingleConnectionBrokerRequestHandler(new PinotConfiguration(), 
"testBrokerId", null, ACCESS_CONTROL_FACTORY,
-            null, null, null, null, mock(ServerRoutingStatsManager.class));
+            null, null, null, null, mock(ServerRoutingStatsManager.class), 
mock(FailureDetector.class));
 
     long randNum = RANDOM.nextLong();
     byte[] randBytes = new byte[12];
@@ -193,7 +194,7 @@ public class LiteralOnlyBrokerRequestTest {
       throws Exception {
     SingleConnectionBrokerRequestHandler requestHandler =
         new SingleConnectionBrokerRequestHandler(new PinotConfiguration(), 
"testBrokerId", null, ACCESS_CONTROL_FACTORY,
-            null, null, null, null, mock(ServerRoutingStatsManager.class));
+            null, null, null, null, mock(ServerRoutingStatsManager.class), 
mock(FailureDetector.class));
     long currentTsMin = System.currentTimeMillis();
     BrokerResponse brokerResponse = requestHandler.handleRequest(
         "SELECT now() AS currentTs, fromDateTime('2020-01-01 UTC', 'yyyy-MM-dd 
z') AS firstDayOf2020");
@@ -347,7 +348,7 @@ public class LiteralOnlyBrokerRequestTest {
       throws Exception {
     SingleConnectionBrokerRequestHandler requestHandler =
         new SingleConnectionBrokerRequestHandler(new PinotConfiguration(), 
"testBrokerId", null, ACCESS_CONTROL_FACTORY,
-            null, null, null, null, mock(ServerRoutingStatsManager.class));
+            null, null, null, null, mock(ServerRoutingStatsManager.class), 
mock(FailureDetector.class));
 
     // Test 1: select constant
     BrokerResponse brokerResponse = requestHandler.handleRequest("EXPLAIN PLAN 
FOR SELECT 1.5, 'test'");
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/BaseExponentialBackoffRetryFailureDetector.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/BaseExponentialBackoffRetryFailureDetector.java
similarity index 78%
rename from 
pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/BaseExponentialBackoffRetryFailureDetector.java
rename to 
pinot-common/src/main/java/org/apache/pinot/common/failuredetector/BaseExponentialBackoffRetryFailureDetector.java
index b3c7639cfd..088ddfc6f0 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/BaseExponentialBackoffRetryFailureDetector.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/BaseExponentialBackoffRetryFailureDetector.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.broker.failuredetector;
+package org.apache.pinot.common.failuredetector;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -25,6 +25,8 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.DelayQueue;
 import java.util.concurrent.Delayed;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.pinot.common.metrics.BrokerGauge;
 import org.apache.pinot.common.metrics.BrokerMetrics;
@@ -43,10 +45,12 @@ public abstract class 
BaseExponentialBackoffRetryFailureDetector implements Fail
   private static final Logger LOGGER = 
LoggerFactory.getLogger(BaseExponentialBackoffRetryFailureDetector.class);
 
   protected final String _name = getClass().getSimpleName();
-  protected final List<Listener> _listeners = new ArrayList<>();
   protected final ConcurrentHashMap<String, RetryInfo> 
_unhealthyServerRetryInfoMap = new ConcurrentHashMap<>();
   protected final DelayQueue<RetryInfo> _retryInfoDelayQueue = new 
DelayQueue<>();
 
+  protected final List<Function<String, Boolean>> _unhealthyServerRetriers = 
new ArrayList<>();
+  // Default to no-op notifiers so that marking a server healthy/unhealthy before a notifier has been registered does
+  // not fail with a NullPointerException. The listener list these fields replace tolerated having no listeners.
+  protected Consumer<String> _healthyServerNotifier = instanceId -> { };
+  protected Consumer<String> _unhealthyServerNotifier = instanceId -> { };
   protected BrokerMetrics _brokerMetrics;
   protected long _retryInitialDelayNs;
   protected double _retryDelayFactor;
@@ -64,14 +68,24 @@ public abstract class 
BaseExponentialBackoffRetryFailureDetector implements Fail
     _retryDelayFactor = 
config.getProperty(Broker.FailureDetector.CONFIG_OF_RETRY_DELAY_FACTOR,
         Broker.FailureDetector.DEFAULT_RETRY_DELAY_FACTOR);
     _maxRetries =
-        config.getProperty(Broker.FailureDetector.CONFIG_OF_MAX_RETRIES, 
Broker.FailureDetector.DEFAULT_MAX_RETIRES);
+        config.getProperty(Broker.FailureDetector.CONFIG_OF_MAX_RETRIES, 
Broker.FailureDetector.DEFAULT_MAX_RETRIES);
     LOGGER.info("Initialized {} with retry initial delay: {}ms, exponential 
backoff factor: {}, max retries: {}", _name,
         retryInitialDelayMs, _retryDelayFactor, _maxRetries);
   }
 
   @Override
-  public void register(Listener listener) {
-    _listeners.add(listener);
+  public void registerUnhealthyServerRetrier(Function<String, Boolean> 
unhealthyServerRetrier) {
+    _unhealthyServerRetriers.add(unhealthyServerRetrier);
+  }
+
+  @Override
+  public void registerHealthyServerNotifier(Consumer<String> 
healthyServerNotifier) {
+    _healthyServerNotifier = healthyServerNotifier;
+  }
+
+  @Override
+  public void registerUnhealthyServerNotifier(Consumer<String> 
unhealthyServerNotifier) {
+    _unhealthyServerNotifier = unhealthyServerNotifier;
   }
 
   @Override
@@ -88,21 +102,28 @@ public abstract class 
BaseExponentialBackoffRetryFailureDetector implements Fail
             LOGGER.info("Server: {} has been marked healthy, skipping the 
retry", instanceId);
             continue;
           }
-          if (retryInfo._numRetires == _maxRetries) {
+          if (retryInfo._numRetries == _maxRetries) {
             LOGGER.warn("Unhealthy server: {} already reaches the max retries: 
{}, do not retry again and treat it "
                 + "as healthy so that the listeners do not lose track of the 
server", instanceId, _maxRetries);
             markServerHealthy(instanceId);
             continue;
           }
           LOGGER.info("Retry unhealthy server: {}", instanceId);
-          for (Listener listener : _listeners) {
-            listener.retryUnhealthyServer(instanceId, this);
+          boolean recovered = true;
+          for (Function<String, Boolean> unhealthyServerRetrier : 
_unhealthyServerRetriers) {
+            if (!unhealthyServerRetrier.apply(instanceId)) {
+              recovered = false;
+            }
+          }
+          if (recovered) {
+            markServerHealthy(instanceId);
+          } else {
+            // Update the retry info and add it back to the delay queue
+            retryInfo._retryDelayNs = (long) (retryInfo._retryDelayNs * 
_retryDelayFactor);
+            retryInfo._retryTimeNs = System.nanoTime() + 
retryInfo._retryDelayNs;
+            retryInfo._numRetries++;
+            _retryInfoDelayQueue.offer(retryInfo);
           }
-          // Update the retry info and add it back to the delay queue
-          retryInfo._retryDelayNs = (long) (retryInfo._retryDelayNs * 
_retryDelayFactor);
-          retryInfo._retryTimeNs = System.nanoTime() + retryInfo._retryDelayNs;
-          retryInfo._numRetires++;
-          _retryInfoDelayQueue.offer(retryInfo);
         } catch (Exception e) {
           if (_running) {
             LOGGER.error("Caught exception in the retry thread, continuing 
with errors", e);
@@ -120,9 +141,7 @@ public abstract class 
BaseExponentialBackoffRetryFailureDetector implements Fail
     _unhealthyServerRetryInfoMap.computeIfPresent(instanceId, (id, retryInfo) 
-> {
       LOGGER.info("Mark server: {} as healthy", instanceId);
       _brokerMetrics.setValueOfGlobalGauge(BrokerGauge.UNHEALTHY_SERVERS, 
_unhealthyServerRetryInfoMap.size() - 1);
-      for (Listener listener : _listeners) {
-        listener.notifyHealthyServer(instanceId, this);
-      }
+      _healthyServerNotifier.accept(instanceId);
       return null;
     });
   }
@@ -132,9 +151,7 @@ public abstract class 
BaseExponentialBackoffRetryFailureDetector implements Fail
     _unhealthyServerRetryInfoMap.computeIfAbsent(instanceId, id -> {
       LOGGER.warn("Mark server: {} as unhealthy", instanceId);
       _brokerMetrics.setValueOfGlobalGauge(BrokerGauge.UNHEALTHY_SERVERS, 
_unhealthyServerRetryInfoMap.size() + 1);
-      for (Listener listener : _listeners) {
-        listener.notifyUnhealthyServer(instanceId, this);
-      }
+      _unhealthyServerNotifier.accept(instanceId);
       RetryInfo retryInfo = new RetryInfo(id);
       _retryInfoDelayQueue.offer(retryInfo);
       return retryInfo;
@@ -167,13 +184,13 @@ public abstract class 
BaseExponentialBackoffRetryFailureDetector implements Fail
 
     long _retryTimeNs;
     long _retryDelayNs;
-    int _numRetires;
+    int _numRetries;
 
     RetryInfo(String instanceId) {
       _instanceId = instanceId;
       _retryTimeNs = System.nanoTime() + _retryInitialDelayNs;
       _retryDelayNs = _retryInitialDelayNs;
-      _numRetires = 0;
+      _numRetries = 0;
     }
 
     @Override
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/ConnectionFailureDetector.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/ConnectionFailureDetector.java
similarity index 69%
rename from 
pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/ConnectionFailureDetector.java
rename to 
pinot-common/src/main/java/org/apache/pinot/common/failuredetector/ConnectionFailureDetector.java
index 6add90b4f9..4c62762440 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/ConnectionFailureDetector.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/ConnectionFailureDetector.java
@@ -16,29 +16,18 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.broker.failuredetector;
+package org.apache.pinot.common.failuredetector;
 
 import javax.annotation.concurrent.ThreadSafe;
-import org.apache.pinot.core.transport.QueryResponse;
-import org.apache.pinot.core.transport.ServerRoutingInstance;
 
 
 /**
  * The {@code ConnectionFailureDetector} marks failed server (connection 
failure) from query response as unhealthy, and
  * retries the unhealthy servers with exponential increasing delays.
+ * <p>
+ * This class doesn't currently implement any additional logic over 
BaseExponentialBackoffRetryFailureDetector and is
+ * retained for backward compatibility.
  */
 @ThreadSafe
 public class ConnectionFailureDetector extends 
BaseExponentialBackoffRetryFailureDetector {
-
-  @Override
-  public void notifyQuerySubmitted(QueryResponse queryResponse) {
-  }
-
-  @Override
-  public void notifyQueryFinished(QueryResponse queryResponse) {
-    ServerRoutingInstance failedServer = queryResponse.getFailedServer();
-    if (failedServer != null) {
-      markServerUnhealthy(failedServer.getInstanceId());
-    }
-  }
 }
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/FailureDetector.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/FailureDetector.java
similarity index 63%
rename from 
pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/FailureDetector.java
rename to 
pinot-common/src/main/java/org/apache/pinot/common/failuredetector/FailureDetector.java
index cfcd8719db..33d47e1e8c 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/FailureDetector.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/FailureDetector.java
@@ -16,12 +16,13 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.broker.failuredetector;
+package org.apache.pinot.common.failuredetector;
 
 import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.core.transport.QueryResponse;
 import org.apache.pinot.spi.annotations.InterfaceAudience;
 import org.apache.pinot.spi.annotations.InterfaceStability;
 import org.apache.pinot.spi.env.PinotConfiguration;
@@ -37,51 +38,31 @@ import org.apache.pinot.spi.env.PinotConfiguration;
 @ThreadSafe
 public interface FailureDetector {
 
-  /**
-   * Listener for the failure detector.
-   */
-  interface Listener {
-
-    /**
-     * Notifies the listener of an unhealthy server.
-     */
-    void notifyUnhealthyServer(String instanceId, FailureDetector 
failureDetector);
-
-    /**
-     * Notifies the listener to retry a previous unhealthy server.
-     */
-    void retryUnhealthyServer(String instanceId, FailureDetector 
failureDetector);
-
-    /**
-     * Notifies the listener of a previous unhealthy server turning healthy.
-     */
-    void notifyHealthyServer(String instanceId, FailureDetector 
failureDetector);
-  }
-
   /**
    * Initializes the failure detector.
    */
   void init(PinotConfiguration config, BrokerMetrics brokerMetrics);
 
   /**
-   * Registers a listener to the failure detector.
+   * Registers a function that will be periodically called to retry unhealthy 
servers. The function is called with the
+   * instanceId of the unhealthy server and should return true if the server 
is now healthy, false otherwise.
    */
-  void register(Listener listener);
+  void registerUnhealthyServerRetrier(Function<String, Boolean> 
unhealthyServerRetrier);
 
   /**
-   * Starts the failure detector. Listeners should be registered before 
starting the failure detector.
+   * Registers a consumer that will be called with the instanceId of a server 
that is detected as healthy.
    */
-  void start();
+  void registerHealthyServerNotifier(Consumer<String> healthyServerNotifier);
 
   /**
-   * Notifies the failure detector that a query is submitted.
+   * Registers a consumer that will be called with the instanceId of a server 
that is detected as unhealthy.
    */
-  void notifyQuerySubmitted(QueryResponse queryResponse);
+  void registerUnhealthyServerNotifier(Consumer<String> 
unhealthyServerNotifier);
 
   /**
-   * Notifies the failure detector that a query is finished (COMPLETED, FAILED 
or TIMED_OUT).
+   * Starts the failure detector.
    */
-  void notifyQueryFinished(QueryResponse queryResponse);
+  void start();
 
   /**
    * Marks a server as healthy.
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/FailureDetectorFactory.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/FailureDetectorFactory.java
similarity index 98%
rename from 
pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/FailureDetectorFactory.java
rename to 
pinot-common/src/main/java/org/apache/pinot/common/failuredetector/FailureDetectorFactory.java
index 165d6cf162..60bc82ca9b 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/FailureDetectorFactory.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/FailureDetectorFactory.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.broker.failuredetector;
+package org.apache.pinot.common.failuredetector;
 
 import com.google.common.base.Preconditions;
 import org.apache.commons.lang3.StringUtils;
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/NoOpFailureDetector.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/NoOpFailureDetector.java
similarity index 79%
rename from 
pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/NoOpFailureDetector.java
rename to 
pinot-common/src/main/java/org/apache/pinot/common/failuredetector/NoOpFailureDetector.java
index d81147ce6d..3b81d11b18 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/failuredetector/NoOpFailureDetector.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/failuredetector/NoOpFailureDetector.java
@@ -16,13 +16,14 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.broker.failuredetector;
+package org.apache.pinot.common.failuredetector;
 
 import java.util.Collections;
 import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.core.transport.QueryResponse;
 import org.apache.pinot.spi.env.PinotConfiguration;
 
 
@@ -34,19 +35,19 @@ public class NoOpFailureDetector implements FailureDetector 
{
   }
 
   @Override
-  public void register(Listener listener) {
+  public void registerUnhealthyServerRetrier(Function<String, Boolean> 
unhealthyServerRetrier) {
   }
 
   @Override
-  public void start() {
+  public void registerHealthyServerNotifier(Consumer<String> 
healthyServerNotifier) {
   }
 
   @Override
-  public void notifyQuerySubmitted(QueryResponse queryResponse) {
+  public void registerUnhealthyServerNotifier(Consumer<String> 
unhealthyServerNotifier) {
   }
 
   @Override
-  public void notifyQueryFinished(QueryResponse queryResponse) {
+  public void start() {
   }
 
   @Override
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
index 7d4a6cf487..76dc112b05 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
@@ -112,6 +112,10 @@ public class GrpcQueryClient implements Closeable {
     return _blockingStub.submit(request);
   }
 
+  public ManagedChannel getChannel() {
+    return _managedChannel;
+  }
+
   @Override
   public void close() {
     if (!_managedChannel.isShutdown()) {
diff --git 
a/pinot-common/src/test/java/org/apache/pinot/common/failuredetector/ConnectionFailureDetectorTest.java
 
b/pinot-common/src/test/java/org/apache/pinot/common/failuredetector/ConnectionFailureDetectorTest.java
new file mode 100644
index 0000000000..45dd435b55
--- /dev/null
+++ 
b/pinot-common/src/test/java/org/apache/pinot/common/failuredetector/ConnectionFailureDetectorTest.java
@@ -0,0 +1,224 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.failuredetector;
+
+import java.util.Collections;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import org.apache.pinot.common.metrics.BrokerGauge;
+import org.apache.pinot.common.metrics.BrokerMetrics;
+import org.apache.pinot.common.metrics.MetricValueUtils;
+import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.metrics.PinotMetricUtils;
+import org.apache.pinot.spi.utils.CommonConstants.Broker;
+import org.apache.pinot.util.TestUtils;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Tests marking servers unhealthy/healthy on the {@link ConnectionFailureDetector}, and the retries it performs
+ * through the registered unhealthy server retriers.
+ */
+public class ConnectionFailureDetectorTest {
+  private static final String INSTANCE_ID = "Server_localhost_1234";
+
+  private BrokerMetrics _brokerMetrics;
+  private FailureDetector _failureDetector;
+  private UnhealthyServerRetrier _unhealthyServerRetrier;
+  private HealthyServerNotifier _healthyServerNotifier;
+  private UnhealthyServerNotifier _unhealthyServerNotifier;
+
+  @BeforeMethod
+  public void setUp() {
+    // A new failure detector is started before every test method, while tearDown() only runs once per class. Stop the
+    // detector started for the previous test method here so that every started detector is also stopped.
+    stopFailureDetector();
+
+    PinotConfiguration config = new PinotConfiguration();
+    config.setProperty(Broker.FailureDetector.CONFIG_OF_TYPE, Broker.FailureDetector.Type.CONNECTION.name());
+    config.setProperty(Broker.FailureDetector.CONFIG_OF_RETRY_INITIAL_DELAY_MS, 100);
+    config.setProperty(Broker.FailureDetector.CONFIG_OF_RETRY_DELAY_FACTOR, 1);
+    _brokerMetrics = new BrokerMetrics(PinotMetricUtils.getPinotMetricsRegistry());
+    _failureDetector = FailureDetectorFactory.getFailureDetector(config, _brokerMetrics);
+    assertTrue(_failureDetector instanceof ConnectionFailureDetector);
+    _healthyServerNotifier = new HealthyServerNotifier();
+    _failureDetector.registerHealthyServerNotifier(_healthyServerNotifier);
+    _unhealthyServerNotifier = new UnhealthyServerNotifier();
+    _failureDetector.registerUnhealthyServerNotifier(_unhealthyServerNotifier);
+    _failureDetector.start();
+  }
+
+  @Test
+  public void testConnectionFailure() {
+    // No unhealthy servers initially
+    verify(Collections.emptySet(), 0, 0);
+
+    _failureDetector.markServerUnhealthy(INSTANCE_ID);
+    verify(Collections.singleton(INSTANCE_ID), 1, 0);
+
+    // Mark server unhealthy again should have no effect
+    _failureDetector.markServerUnhealthy(INSTANCE_ID);
+    verify(Collections.singleton(INSTANCE_ID), 1, 0);
+
+    // Mark server healthy should remove it from the unhealthy servers and trigger a callback
+    _failureDetector.markServerHealthy(INSTANCE_ID);
+    verify(Collections.emptySet(), 1, 1);
+  }
+
+  @Test
+  public void testRetryWithoutRecovery() {
+    _unhealthyServerRetrier = new UnhealthyServerRetrier(10);
+    _failureDetector.registerUnhealthyServerRetrier(_unhealthyServerRetrier);
+
+    _failureDetector.markServerUnhealthy(INSTANCE_ID);
+    verify(Collections.singleton(INSTANCE_ID), 1, 0);
+
+    // Should get 10 retries in 1s, then remove the failed server from the unhealthy servers.
+    // Wait for up to 5s to avoid flakiness
+    TestUtils.waitForCondition(aVoid -> {
+      int numRetries = _unhealthyServerRetrier._retryUnhealthyServerCalled;
+      if (numRetries < Broker.FailureDetector.DEFAULT_MAX_RETRIES) {
+        assertServerStillUnhealthy();
+        return false;
+      }
+      assertEquals(numRetries, Broker.FailureDetector.DEFAULT_MAX_RETRIES);
+      // There might be a small delay between the last retry and removing failed server from the unhealthy servers.
+      // Perform a check instead of an assertion.
+      return isRecoveredWithSingleNotification();
+    }, 5_000L, "Failed to get 10 retries");
+  }
+
+  @Test
+  public void testRetryWithRecovery() {
+    _unhealthyServerRetrier = new UnhealthyServerRetrier(6);
+    _failureDetector.registerUnhealthyServerRetrier(_unhealthyServerRetrier);
+
+    _failureDetector.markServerUnhealthy(INSTANCE_ID);
+    verify(Collections.singleton(INSTANCE_ID), 1, 0);
+
+    TestUtils.waitForCondition(aVoid -> {
+      int numRetries = _unhealthyServerRetrier._retryUnhealthyServerCalled;
+      if (numRetries < 7) {
+        // Avoid test flakiness by not making these assertions close to the end of the expected retry period
+        if (numRetries > 0 && numRetries <= 5) {
+          assertServerStillUnhealthy();
+        }
+        return false;
+      }
+      assertEquals(numRetries, 7);
+      // There might be a small delay between the successful attempt and removing failed server from the unhealthy
+      // servers. Perform a check instead of an assertion.
+      return isRecoveredWithSingleNotification();
+    }, 5_000L, "Failed to get 7 retries");
+
+    // Verify no further retries
+    assertEquals(_unhealthyServerRetrier._retryUnhealthyServerCalled, 7);
+  }
+
+  @Test
+  public void testRetryWithMultipleUnhealthyServerRetriers() {
+    _unhealthyServerRetrier = new UnhealthyServerRetrier(7);
+    _failureDetector.registerUnhealthyServerRetrier(_unhealthyServerRetrier);
+
+    UnhealthyServerRetrier unhealthyServerRetrier2 = new UnhealthyServerRetrier(8);
+    _failureDetector.registerUnhealthyServerRetrier(unhealthyServerRetrier2);
+
+    _failureDetector.markServerUnhealthy(INSTANCE_ID);
+    verify(Collections.singleton(INSTANCE_ID), 1, 0);
+
+    // Should retry until both unhealthy server retriers return that the server is healthy
+    TestUtils.waitForCondition(aVoid -> {
+      int numRetries = _unhealthyServerRetrier._retryUnhealthyServerCalled;
+      if (numRetries < 9) {
+        // Avoid test flakiness by not making these assertions close to the end of the expected retry period
+        if (numRetries > 0 && numRetries <= 7) {
+          assertServerStillUnhealthy();
+        }
+        return false;
+      }
+      assertEquals(numRetries, 9);
+      // There might be a small delay between the successful attempt and removing failed server from the unhealthy
+      // servers. Perform a check instead of an assertion.
+      return isRecoveredWithSingleNotification();
+    }, 5_000L, "Failed to get 9 retries");
+
+    // Verify no further retries
+    assertEquals(_unhealthyServerRetrier._retryUnhealthyServerCalled, 9);
+  }
+
+  /**
+   * Asserts the set of unhealthy servers, the unhealthy servers gauge, and how often each notifier has been called.
+   */
+  private void verify(Set<String> expectedUnhealthyServers, int expectedNotifyUnhealthyServerCalled,
+      int expectedNotifyHealthyServerCalled) {
+    assertEquals(_failureDetector.getUnhealthyServers(), expectedUnhealthyServers);
+    assertEquals(MetricValueUtils.getGlobalGaugeValue(_brokerMetrics, BrokerGauge.UNHEALTHY_SERVERS),
+        expectedUnhealthyServers.size());
+    assertEquals(_unhealthyServerNotifier._notifyUnhealthyServerCalled, expectedNotifyUnhealthyServerCalled);
+    assertEquals(_healthyServerNotifier._notifyHealthyServerCalled, expectedNotifyHealthyServerCalled);
+  }
+
+  /**
+   * Asserts that the test server is still the only unhealthy server and that the gauge reports exactly one.
+   */
+  private void assertServerStillUnhealthy() {
+    assertEquals(_failureDetector.getUnhealthyServers(), Collections.singleton(INSTANCE_ID));
+    assertEquals(MetricValueUtils.getGlobalGaugeValue(_brokerMetrics, BrokerGauge.UNHEALTHY_SERVERS), 1);
+  }
+
+  /**
+   * Returns whether there are no unhealthy servers left, the gauge is back to 0, and both notifiers have been called
+   * exactly once. This is a check rather than an assertion so that it can be polled.
+   */
+  private boolean isRecoveredWithSingleNotification() {
+    return _failureDetector.getUnhealthyServers().isEmpty()
+        && MetricValueUtils.getGaugeValue(_brokerMetrics, BrokerGauge.UNHEALTHY_SERVERS.getGaugeName()) == 0
+        && _unhealthyServerNotifier._notifyUnhealthyServerCalled == 1
+        && _healthyServerNotifier._notifyHealthyServerCalled == 1;
+  }
+
+  @AfterClass
+  public void tearDown() {
+    stopFailureDetector();
+  }
+
+  /**
+   * Stops the current failure detector, if any, and clears the reference so that it is not stopped twice.
+   */
+  private void stopFailureDetector() {
+    if (_failureDetector != null) {
+      _failureDetector.stop();
+      _failureDetector = null;
+    }
+  }
+
+  private static class HealthyServerNotifier implements Consumer<String> {
+    // May be updated from the failure detector's retry thread while the test thread polls it, hence volatile. Within
+    // these tests only one thread updates it at a time, so the non-atomic increment is sufficient.
+    volatile int _notifyHealthyServerCalled = 0;
+
+    @Override
+    public void accept(String instanceId) {
+      assertEquals(instanceId, INSTANCE_ID);
+      _notifyHealthyServerCalled++;
+    }
+  }
+
+  private static class UnhealthyServerNotifier implements Consumer<String> {
+    // Volatile because it is read from the polling conditions; see HealthyServerNotifier.
+    volatile int _notifyUnhealthyServerCalled = 0;
+
+    @Override
+    public void accept(String instanceId) {
+      assertEquals(instanceId, INSTANCE_ID);
+      _notifyUnhealthyServerCalled++;
+    }
+  }
+
+  /**
+   * Retrier that reports the server as still unhealthy for the first {@code numFailures} calls and as healthy for
+   * every call after that.
+   */
+  private static class UnhealthyServerRetrier implements Function<String, Boolean> {
+    // Incremented on the thread that invokes the retrier and polled by the test thread, hence volatile.
+    volatile int _retryUnhealthyServerCalled = 0;
+    final int _numFailures;
+
+    UnhealthyServerRetrier(int numFailures) {
+      _numFailures = numFailures;
+    }
+
+    @Override
+    public Boolean apply(String instanceId) {
+      assertEquals(instanceId, INSTANCE_ID);
+      _retryUnhealthyServerCalled++;
+      return _retryUnhealthyServerCalled > _numFailures;
+    }
+  }
+}
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
index b9442b728c..f576fed47f 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
@@ -30,20 +30,26 @@ import org.apache.pinot.core.transport.ServerInstance;
  * <p>Note that {@code QueryServerInstance} should only be used during 
dispatch.</p>
  */
 public class QueryServerInstance {
+  private final String _instanceId;
   private final String _hostname;
   private final int _queryServicePort;
   private final int _queryMailboxPort;
 
   public QueryServerInstance(ServerInstance server) {
-    this(server.getHostname(), server.getQueryServicePort(), 
server.getQueryMailboxPort());
+    this(server.getInstanceId(), server.getHostname(), 
server.getQueryServicePort(), server.getQueryMailboxPort());
   }
 
-  public QueryServerInstance(String hostName, int servicePort, int 
mailboxPort) {
+  public QueryServerInstance(String instanceId, String hostName, int 
servicePort, int mailboxPort) {
+    _instanceId = instanceId;
     _hostname = hostName;
     _queryServicePort = servicePort;
     _queryMailboxPort = mailboxPort;
   }
 
+  public String getInstanceId() {
+    return _instanceId;
+  }
+
   public String getHostname() {
     return _hostname;
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
index 73174304e4..f2b296ea3b 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
@@ -69,11 +69,13 @@ public class WorkerManager {
   // default table partition function if not specified in hint
   private static final String DEFAULT_TABLE_PARTITION_FUNCTION = "Murmur";
 
+  private final String _instanceId;
   private final String _hostName;
   private final int _port;
   private final RoutingManager _routingManager;
 
-  public WorkerManager(String hostName, int port, RoutingManager 
routingManager) {
+  public WorkerManager(String instanceId, String hostName, int port, 
RoutingManager routingManager) {
+    _instanceId = instanceId;
     _hostName = hostName;
     _port = port;
     _routingManager = routingManager;
@@ -84,7 +86,7 @@ public class WorkerManager {
     // worker instance with identical server/mailbox port number.
     DispatchablePlanMetadata metadata = context.getDispatchablePlanMetadataMap().get(0);
     metadata.setWorkerIdToServerInstanceMap(
-        Collections.singletonMap(0, new QueryServerInstance(_hostName, _port, _port)));
+        Collections.singletonMap(0, new QueryServerInstance(_instanceId, _hostName, _port, _port)));
     for (PlanFragment child : rootFragment.getChildren()) {
       assignWorkersToNonRootFragment(child, context);
     }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 37863f9d99..3bd11833c1 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -307,7 +307,7 @@ public class QueryEnvironmentTestBase {
     RoutingManager routingManager = factory.buildRoutingManager(partitionInfoMap);
     TableCache tableCache = factory.buildTableCache();
     return new QueryEnvironment(CommonConstants.DEFAULT_DATABASE, tableCache,
-        new WorkerManager("localhost", reducerPort, routingManager));
+        new WorkerManager("Broker_localhost", "localhost", reducerPort, routingManager));
   }
 
   /**
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index e3c8d07ef4..06e230dce9 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -18,11 +18,11 @@
  */
 package org.apache.pinot.query.service.dispatch;
 
-import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.InvalidProtocolBufferException;
+import io.grpc.ConnectivityState;
 import io.grpc.Deadline;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -48,6 +48,7 @@ import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.exception.QueryException;
 import org.apache.pinot.common.exception.QueryInfoException;
+import org.apache.pinot.common.failuredetector.FailureDetector;
 import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.common.response.PinotBrokerTimeSeriesResponse;
@@ -99,11 +100,11 @@ import org.slf4j.LoggerFactory;
 public class QueryDispatcher {
   private static final Logger LOGGER = LoggerFactory.getLogger(QueryDispatcher.class);
   private static final String PINOT_BROKER_QUERY_DISPATCHER_FORMAT = "multistage-query-dispatch-%d";
-  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
 
   private final MailboxService _mailboxService;
   private final ExecutorService _executorService;
   private final Map<String, DispatchClient> _dispatchClientMap = new ConcurrentHashMap<>();
+  private final Map<String, String> _instanceIdToHostnamePortMap = new ConcurrentHashMap<>();
   private final Map<String, TimeSeriesDispatchClient> _timeSeriesDispatchClientMap = new ConcurrentHashMap<>();
   @Nullable
   private final TlsConfig _tlsConfig;
@@ -111,16 +112,21 @@ public class QueryDispatcher {
   private final Map<Long, Set<QueryServerInstance>> _serversByQuery;
   private final PhysicalTimeSeriesBrokerPlanVisitor _timeSeriesBrokerPlanVisitor
       = new PhysicalTimeSeriesBrokerPlanVisitor();
+  @Nullable
+  private final FailureDetector _failureDetector;
 
   public QueryDispatcher(MailboxService mailboxService) {
-    this(mailboxService, null, false);
+    this(mailboxService, null, null, false);
   }
 
-  public QueryDispatcher(MailboxService mailboxService, @Nullable TlsConfig tlsConfig, boolean enableCancellation) {
+  public QueryDispatcher(MailboxService mailboxService, @Nullable TlsConfig tlsConfig,
+      @Nullable FailureDetector failureDetector, boolean enableCancellation) {
     _mailboxService = mailboxService;
     _executorService = Executors.newFixedThreadPool(2 * Runtime.getRuntime().availableProcessors(),
         new TracedThreadFactory(Thread.NORM_PRIORITY, false, PINOT_BROKER_QUERY_DISPATCHER_FORMAT));
     _tlsConfig = tlsConfig;
+    _failureDetector = failureDetector;
+
     if (enableCancellation) {
       _serversByQuery = new ConcurrentHashMap<>();
     } else {
@@ -212,6 +218,18 @@ public class QueryDispatcher {
     }
   }
 
+  public boolean checkConnectivityToInstance(String instanceId) {
+    String hostnamePort = _instanceIdToHostnamePortMap.get(instanceId);
+    DispatchClient client = _dispatchClientMap.get(hostnamePort);
+
+    if (client == null) {
+      LOGGER.warn("No DispatchClient found for server with instanceId: {}", instanceId);
+      return false;
+    }
+
+    return client.getChannel().getState(true) == ConnectivityState.READY;
+  }
+
   private boolean isQueryCancellationEnabled() {
     return _serversByQuery != null;
   }
@@ -244,14 +262,18 @@ public class QueryDispatcher {
               serverInstance);
         }
       };
+      Worker.QueryRequest requestBuilder =
+          createRequest(serverInstance, stagePlans, stageInfos, protoRequestMetadata);
+      DispatchClient dispatchClient = getOrCreateDispatchClient(serverInstance);
+
       try {
-        Worker.QueryRequest requestBuilder =
-            createRequest(serverInstance, stagePlans, stageInfos, protoRequestMetadata);
-        DispatchClient dispatchClient = getOrCreateDispatchClient(serverInstance);
         sendRequest.send(dispatchClient, requestBuilder, serverInstance, deadline, callbackConsumer);
       } catch (Throwable t) {
         LOGGER.warn("Caught exception while dispatching query: {} to server: {}", requestId, serverInstance, t);
         callbackConsumer.accept(new AsyncResponse<>(serverInstance, null, t));
+        if (_failureDetector != null) {
+          _failureDetector.markServerUnhealthy(serverInstance.getInstanceId());
+        }
       }
     }
 
@@ -410,8 +432,9 @@ public class QueryDispatcher {
   private DispatchClient getOrCreateDispatchClient(QueryServerInstance queryServerInstance) {
     String hostname = queryServerInstance.getHostname();
     int port = queryServerInstance.getQueryServicePort();
-    String key = String.format("%s_%d", hostname, port);
-    return _dispatchClientMap.computeIfAbsent(key, k -> new DispatchClient(hostname, port, _tlsConfig));
+    String hostnamePort = String.format("%s_%d", hostname, port);
+    _instanceIdToHostnamePortMap.put(queryServerInstance.getInstanceId(), hostnamePort);
+    return _dispatchClientMap.computeIfAbsent(hostnamePort, k -> new DispatchClient(hostname, port, _tlsConfig));
   }
 
   private TimeSeriesDispatchClient getOrCreateTimeSeriesDispatchClient(
@@ -513,6 +536,7 @@ public class QueryDispatcher {
       dispatchClient.getChannel().shutdown();
     }
     _dispatchClientMap.clear();
+    _instanceIdToHostnamePortMap.clear();
     _mailboxService.shutdown();
     _executorService.shutdown();
   }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerAccountingTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerAccountingTest.java
index 3aa6556b1b..3c6d809626 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerAccountingTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerAccountingTest.java
@@ -74,8 +74,8 @@ public class QueryRunnerAccountingTest extends QueryRunnerTestBase {
     // this is only use for test identifier purpose.
     int port1 = server1.getPort();
     int port2 = server2.getPort();
-    _servers.put(new QueryServerInstance("localhost", port1, port1), server1);
-    _servers.put(new QueryServerInstance("localhost", port2, port2), server2);
+    _servers.put(new QueryServerInstance("Server_localhost_" + port1, "localhost", port1, port1), server1);
+    _servers.put(new QueryServerInstance("Server_localhost_" + port2, "localhost", port2, port2), server2);
 
     _queryEnvironment = QueryEnvironmentTestBase.getQueryEnvironment(_reducerPort, server1.getPort(), server2.getPort(),
         factory1.getRegisteredSchemaMap(), factory1.buildTableSegmentNameMap(), factory2.buildTableSegmentNameMap(),
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java
index dd32ee98b6..58a042d975 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java
@@ -146,8 +146,8 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
     // this is only use for test identifier purpose.
     int port1 = server1.getPort();
     int port2 = server2.getPort();
-    _servers.put(new QueryServerInstance("localhost", port1, port1), server1);
-    _servers.put(new QueryServerInstance("localhost", port2, port2), server2);
+    _servers.put(new QueryServerInstance("Server_localhost_" + port1, "localhost", port1, port1), server1);
+    _servers.put(new QueryServerInstance("Server_localhost_" + port2, "localhost", port2, port2), server2);
 
     _queryEnvironment = QueryEnvironmentTestBase.getQueryEnvironment(_reducerPort, server1.getPort(), server2.getPort(),
         factory1.getRegisteredSchemaMap(), factory1.buildTableSegmentNameMap(), factory2.buildTableSegmentNameMap(),
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
index 845c5ff52f..f94d85c92b 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
@@ -222,8 +222,8 @@ public class ResourceBasedQueriesTest extends QueryRunnerTestBase {
     // this is only use for test identifier purpose.
     int port1 = server1.getPort();
     int port2 = server2.getPort();
-    _servers.put(new QueryServerInstance("localhost", port1, port1), server1);
-    _servers.put(new QueryServerInstance("localhost", port2, port2), server2);
+    _servers.put(new QueryServerInstance("Server_localhost_" + port1, "localhost", port1, port1), server1);
+    _servers.put(new QueryServerInstance("Server_localhost_" + port2, "localhost", port2, port2), server2);
 
     _queryEnvironment = QueryEnvironmentTestBase.getQueryEnvironment(_reducerPort, server1.getPort(), server2.getPort(),
         factory1.getRegisteredSchemaMap(), factory1.buildTableSegmentNameMap(), factory2.buildTableSegmentNameMap(),
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index 653667f4b3..4a7c58d965 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -614,7 +614,7 @@ public class CommonConstants {
       public static final String CONFIG_OF_RETRY_DELAY_FACTOR = "pinot.broker.failure.detector.retry.delay.factor";
       public static final double DEFAULT_RETRY_DELAY_FACTOR = 2.0;
       public static final String CONFIG_OF_MAX_RETRIES = "pinot.broker.failure.detector.max.retries";
-      public static final int DEFAULT_MAX_RETIRES = 10;
+      public static final int DEFAULT_MAX_RETRIES = 10;
     }
 
     // Configs related to AdaptiveServerSelection.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to