This is an automated email from the ASF dual-hosted git repository.

gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d423555fc Include stats on cancel (#15609)
4d423555fc is described below

commit 4d423555fcd2e1f908c5a36e27160d38b151a09e
Author: Gonzalo Ortiz Jaureguizar <gor...@users.noreply.github.com>
AuthorDate: Mon May 5 16:05:55 2025 +0200

    Include stats on cancel (#15609)
---
 .../MultiStageBrokerRequestHandler.java            |  36 +++--
 pinot-common/src/main/proto/worker.proto           |   2 +-
 .../query/runtime/MultiStageStatsTreeBuilder.java  |   2 +-
 .../apache/pinot/query/runtime/QueryRunner.java    |   7 +-
 .../runtime/executor/OpChainSchedulerService.java  |  75 ++++++++-
 .../query/runtime/plan/MultiStageQueryStats.java   |  28 ++++
 .../plan/pipeline/PipelineBreakerOperator.java     |   2 +-
 .../query/service/dispatch/DispatchClient.java     |  15 +-
 .../query/service/dispatch/QueryDispatcher.java    | 180 ++++++++++++++++++---
 .../pinot/query/service/server/QueryServer.java    |  20 ++-
 .../executor/OpChainSchedulerServiceTest.java      |   2 +
 .../apache/pinot/spi/utils/CommonConstants.java    |  13 +-
 12 files changed, 336 insertions(+), 46 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 3c39ecec72..4ee9b69e0c 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -20,6 +20,7 @@ package org.apache.pinot.broker.requesthandler;
 
 import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.node.JsonNodeFactory;
+import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -122,9 +123,13 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
         CommonConstants.Broker.BROKER_TLS_PREFIX) : null;
 
     failureDetector.registerUnhealthyServerRetrier(this::retryUnhealthyServer);
+    long cancelMillis = config.getProperty(
+        CommonConstants.MultiStageQueryRunner.KEY_OF_CANCEL_TIMEOUT_MS,
+        CommonConstants.MultiStageQueryRunner.DEFAULT_OF_CANCEL_TIMEOUT_MS);
+    Duration cancelTimeout = Duration.ofMillis(cancelMillis);
     _queryDispatcher =
         new QueryDispatcher(new MailboxService(hostname, port, config, 
tlsConfig), failureDetector, tlsConfig,
-            this.isQueryCancellationEnabled());
+            this.isQueryCancellationEnabled(), cancelTimeout);
     LOGGER.info("Initialized MultiStageBrokerRequestHandler on host: {}, port: 
{} with broker id: {}, timeout: {}ms, "
             + "query log max length: {}, query log max rate: {}, query 
cancellation enabled: {}", hostname, port,
         _brokerId, _brokerTimeoutMs, _queryLogger.getMaxQueryLengthToLog(), 
_queryLogger.getLogRateLimit(),
@@ -421,13 +426,6 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
       try {
         queryResults = _queryDispatcher.submitAndReduce(requestContext, 
dispatchableSubPlan, timer.getRemainingTimeMs(),
                 query.getOptions());
-      } catch (TimeoutException e) {
-        for (String table : tableNames) {
-          _brokerMetrics.addMeteredTableValue(table, 
BrokerMeter.BROKER_RESPONSES_WITH_TIMEOUTS, 1);
-        }
-        LOGGER.warn("Timed out executing request {}: {}", requestId, query);
-        requestContext.setErrorCode(QueryErrorCode.EXECUTION_TIMEOUT);
-        return new BrokerResponseNative(QueryErrorCode.EXECUTION_TIMEOUT);
       } catch (QueryException e) {
         throw e;
       } catch (Throwable t) {
@@ -440,11 +438,27 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
         Tracing.ThreadAccountantOps.clear();
         onQueryFinish(requestId);
       }
-      long executionEndTimeNs = System.nanoTime();
-      updatePhaseTimingForTables(tableNames, BrokerQueryPhase.QUERY_EXECUTION,
-          executionEndTimeNs - executionStartTimeNs);
 
       BrokerResponseNativeV2 brokerResponse = new BrokerResponseNativeV2();
+
+      QueryProcessingException processingException = 
queryResults.getProcessingException();
+      if (processingException != null) {
+        brokerResponse.addException(processingException);
+        QueryErrorCode errorCode = 
QueryErrorCode.fromErrorCode(processingException.getErrorCode());
+        if (errorCode == QueryErrorCode.EXECUTION_TIMEOUT) {
+          for (String table : tableNames) {
+            _brokerMetrics.addMeteredTableValue(table, 
BrokerMeter.BROKER_RESPONSES_WITH_TIMEOUTS, 1);
+          }
+          LOGGER.warn("Timed out executing request {}: {}", requestId, query);
+        }
+        requestContext.setErrorCode(errorCode);
+      } else {
+        brokerResponse.setResultTable(queryResults.getResultTable());
+        long executionEndTimeNs = System.nanoTime();
+        updatePhaseTimingForTables(tableNames, 
BrokerQueryPhase.QUERY_EXECUTION,
+            executionEndTimeNs - executionStartTimeNs);
+      }
+
       brokerResponse.setClientRequestId(clientRequestId);
       brokerResponse.setResultTable(queryResults.getResultTable());
       brokerResponse.setTablesQueried(tableNames);
diff --git a/pinot-common/src/main/proto/worker.proto b/pinot-common/src/main/proto/worker.proto
index 08623a8fad..fc6da7785f 100644
--- a/pinot-common/src/main/proto/worker.proto
+++ b/pinot-common/src/main/proto/worker.proto
@@ -38,7 +38,7 @@ message CancelRequest {
 }
 
 message CancelResponse {
-  // intentionally left empty
+  map<int32, bytes> statsByStage = 1; // stageId -> serialized 
MultiStageQueryStats.StageStats.Closed
 }
 
 // QueryRequest is the dispatched content for all query stages to a physical 
worker.
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
index bc901cde3f..927f107c53 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
@@ -53,7 +53,7 @@ public class MultiStageStatsTreeBuilder {
       ObjectNode jsonNodes = JsonUtils.newObjectNode();
       jsonNodes.put("type", "EMPTY_MAILBOX_SEND");
       jsonNodes.put("stage", stage);
-      jsonNodes.put("description", "No stats available for this stage. They 
may have been pruned.");
+      jsonNodes.put("description", "No stats available for this stage");
       String tableName = _planFragments.get(stage).getTableName();
       if (tableName != null) {
         jsonNodes.put("table", tableName);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
index c86015bdfc..64c754cff3 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
@@ -65,6 +65,7 @@ import org.apache.pinot.query.runtime.operator.LeafStageTransferableBlockOperato
 import org.apache.pinot.query.runtime.operator.MailboxSendOperator;
 import org.apache.pinot.query.runtime.operator.MultiStageOperator;
 import org.apache.pinot.query.runtime.operator.OpChain;
+import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.query.runtime.plan.PlanNodeToOpChain;
 import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerExecutor;
@@ -197,7 +198,7 @@ public class QueryRunner {
       _executorService = new HardLimitExecutor(hardLimit, _executorService);
     }
 
-    _opChainScheduler = new OpChainSchedulerService(_executorService);
+    _opChainScheduler = new OpChainSchedulerService(_executorService, config);
     _mailboxService = new MailboxService(hostname, port, config, tlsConfig);
     try {
       _leafQueryExecutor = new ServerQueryExecutorV1Impl();
@@ -443,8 +444,8 @@ public class QueryRunner {
     return opChainMetadata;
   }
 
-  public void cancel(long requestId) {
-    _opChainScheduler.cancel(requestId);
+  public Map<Integer, MultiStageQueryStats.StageStats.Closed> cancel(long 
requestId) {
+    return _opChainScheduler.cancel(requestId);
   }
 
   public StagePlan explainQuery(
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
index 60c4a70118..f3daece353 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
@@ -18,35 +18,65 @@
  */
 package org.apache.pinot.query.runtime.executor;
 
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 import org.apache.pinot.core.util.trace.TraceRunnable;
 import org.apache.pinot.query.runtime.blocks.ErrorMseBlock;
 import org.apache.pinot.query.runtime.blocks.MseBlock;
+import org.apache.pinot.query.runtime.operator.MultiStageOperator;
 import org.apache.pinot.query.runtime.operator.OpChain;
 import org.apache.pinot.query.runtime.operator.OpChainId;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.trace.Tracing;
+import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 public class OpChainSchedulerService {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(OpChainSchedulerService.class);
 
   private final ExecutorService _executorService;
-  private final ConcurrentHashMap<OpChainId, Future<?>> _submittedOpChainMap;
+  private final ConcurrentHashMap<OpChainId, Future<?>> _submittedOpChainMap = 
new ConcurrentHashMap<>();
+  private final Cache<OpChainId, MultiStageOperator> _opChainCache;
+
+
+  public OpChainSchedulerService(ExecutorService executorService, 
PinotConfiguration config) {
+    this(
+        executorService,
+        config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_SIZE,
+            MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_SIZE),
+        
config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_EXPIRE_MS,
+            MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS)
+    );
+  }
 
   public OpChainSchedulerService(ExecutorService executorService) {
+    this(executorService, MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_SIZE,
+        MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS);
+  }
+
+  public OpChainSchedulerService(ExecutorService executorService, int 
maxWeight, long expireAfterWriteMs) {
     _executorService = executorService;
-    _submittedOpChainMap = new ConcurrentHashMap<>();
+    _opChainCache = CacheBuilder.newBuilder()
+        .weigher((OpChainId key, MultiStageOperator value) -> 
countOperators(value))
+        .maximumWeight(maxWeight)
+        .expireAfterWrite(expireAfterWriteMs, TimeUnit.MILLISECONDS)
+        .build();
   }
 
   public void register(OpChain operatorChain) {
+    _opChainCache.put(operatorChain.getId(), operatorChain.getRoot());
+
     Future<?> scheduledFuture = _executorService.submit(new TraceRunnable() {
       @Override
       public void runJob() {
@@ -87,7 +117,7 @@ public class OpChainSchedulerService {
     _submittedOpChainMap.put(operatorChain.getId(), scheduledFuture);
   }
 
-  public void cancel(long requestId) {
+  public Map<Integer, MultiStageQueryStats.StageStats.Closed> cancel(long 
requestId) {
     // simple cancellation. for leaf stage this cannot be a dangling opchain 
b/c they will eventually be cleared up
     // via query timeout.
     Iterator<Map.Entry<OpChainId, Future<?>>> iterator = 
_submittedOpChainMap.entrySet().iterator();
@@ -98,5 +128,42 @@ public class OpChainSchedulerService {
         iterator.remove();
       }
     }
+    Map<OpChainId, MultiStageOperator> cancelledByOpChainId = 
_opChainCache.asMap()
+        .entrySet()
+        .stream()
+        .filter(entry -> entry.getKey().getRequestId() == requestId)
+        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, 
e2) -> e1));
+    _opChainCache.invalidateAll(cancelledByOpChainId.keySet());
+
+    return cancelledByOpChainId.entrySet()
+        .stream()
+        .collect(Collectors.toMap(
+            e -> e.getKey().getStageId(),
+            e -> e.getValue().calculateStats().getCurrentStats().close(),
+            (e1, e2) -> {
+              e1.merge(e2);
+              return e1;
+            }
+        ));
+  }
+
+  /**
+   * Counts the number of operators in the tree rooted at the given operator.
+   */
+  private int countOperators(MultiStageOperator root) {
+    // This stack will have at most 2 elements on most stages given that there 
is only 1 join in a stage
+    // and joins only have 2 children.
+    // Some operators (like SetOperator) can have more than 2 children, but 
they are not common.
+    ArrayList<MultiStageOperator> stack = new ArrayList<>(8);
+    stack.add(root);
+    int result = 0;
+    while (!stack.isEmpty()) {
+      result++;
+      MultiStageOperator operator = stack.remove(stack.size() - 1);
+      if (operator.getChildOperators() != null) {
+        stack.addAll(operator.getChildOperators());
+      }
+    }
+    return result;
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/MultiStageQueryStats.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/MultiStageQueryStats.java
index b4c86d183c..9ff54f01f5 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/MultiStageQueryStats.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/MultiStageQueryStats.java
@@ -205,6 +205,34 @@ public class MultiStageQueryStats {
     }
   }
 
+  /**
+   * Merge stats from an upstream stage into this one.
+   *
+   * This is similar to wrap the stats in a MultiStageQueryStats object and 
call
+   * {@link #mergeUpstream(MultiStageQueryStats)} but it is easier to call and 
slightly more efficient.
+   *
+   * @param stageId the stage id of the upstream stage
+   * @param closedStats the stats to merge
+   */
+  public void mergeUpstream(int stageId, StageStats.Closed closedStats) {
+    Preconditions.checkArgument(_currentStageId <= stageId,
+        "Cannot merge stats from early stage %s into stats of later stage %s",
+        stageId, _currentStageId);
+
+    growUpToStage(stageId);
+    int selfIdx = stageId - _currentStageId - 1;
+    StageStats.Closed myStats = _closedStats.get(selfIdx);
+    try {
+      if (myStats == null) {
+        _closedStats.set(selfIdx, closedStats);
+      } else {
+        myStats.merge(closedStats);
+      }
+    } catch (IllegalArgumentException | IllegalStateException ex) {
+      LOGGER.warn("Error merging stats on stage {}. Ignoring the new stats", 
stageId, ex);
+    }
+  }
+
   /**
    * Merge upstream stats from another MultiStageQueryStats object into this 
one.
    *
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerOperator.java
index e0862bc655..8fe38c776e 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerOperator.java
@@ -65,7 +65,7 @@ public class PipelineBreakerOperator extends MultiStageOperator {
 
   @Override
   public List<MultiStageOperator> getChildOperators() {
-    throw new UnsupportedOperationException();
+    return Collections.emptyList();
   }
 
   @Override
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
index 1b90a7ea95..8590d971df 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
@@ -70,7 +70,7 @@ class DispatchClient {
     _dispatchStub.withDeadline(deadline).submit(request, new 
LastValueDispatchObserver<>(virtualServer, callback));
   }
 
-  public void cancel(long requestId) {
+  public void cancelAsync(long requestId) {
     String cid = QueryThreadContext.isInitialized() && 
QueryThreadContext.getCid() != null
         ? QueryThreadContext.getCid()
         : Long.toString(requestId);
@@ -81,6 +81,19 @@ class DispatchClient {
     _dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
   }
 
+  public void cancel(long requestId, QueryServerInstance virtualServer, 
Deadline deadline,
+      Consumer<AsyncResponse<Worker.CancelResponse>> callback) {
+    String cid = QueryThreadContext.isInitialized() && 
QueryThreadContext.getCid() != null
+        ? QueryThreadContext.getCid()
+        : Long.toString(requestId);
+    Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder()
+        .setRequestId(requestId)
+        .setCid(cid)
+        .build();
+    StreamObserver<Worker.CancelResponse> observer = new 
LastValueDispatchObserver<>(virtualServer, callback);
+    _dispatchStub.withDeadline(deadline).cancel(cancelRequest, observer);
+  }
+
   public void explain(Worker.QueryRequest request, QueryServerInstance 
virtualServer, Deadline deadline,
       Consumer<AsyncResponse<List<Worker.ExplainResponse>>> callback) {
     _dispatchStub.withDeadline(deadline).explain(request, new 
AllValuesDispatchObserver<>(virtualServer, callback));
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index d8f6550b30..e212e38f02 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -25,6 +25,9 @@ import com.google.protobuf.ByteString;
 import com.google.protobuf.InvalidProtocolBufferException;
 import io.grpc.ConnectivityState;
 import io.grpc.Deadline;
+import java.io.DataInputStream;
+import java.io.InputStream;
+import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -43,15 +46,18 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import javax.annotation.Nullable;
 import org.apache.calcite.runtime.PairList;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.datatable.StatMap;
 import org.apache.pinot.common.failuredetector.FailureDetector;
 import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.common.response.PinotBrokerTimeSeriesResponse;
+import org.apache.pinot.common.response.broker.QueryProcessingException;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -72,7 +78,9 @@ import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.ErrorMseBlock;
 import org.apache.pinot.query.runtime.blocks.MseBlock;
+import org.apache.pinot.query.runtime.operator.BaseMailboxReceiveOperator;
 import org.apache.pinot.query.runtime.operator.MailboxReceiveOperator;
+import org.apache.pinot.query.runtime.operator.MultiStageOperator;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import 
org.apache.pinot.query.runtime.timeseries.PhysicalTimeSeriesBrokerPlanVisitor;
@@ -81,6 +89,7 @@ import org.apache.pinot.query.service.dispatch.timeseries.TimeSeriesDispatchClie
 import 
org.apache.pinot.query.service.dispatch.timeseries.TimeSeriesDispatchObserver;
 import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.exception.QueryErrorCode;
+import org.apache.pinot.spi.exception.QueryException;
 import org.apache.pinot.spi.query.QueryThreadContext;
 import org.apache.pinot.spi.trace.RequestContext;
 import org.apache.pinot.spi.trace.Tracing;
@@ -115,6 +124,7 @@ public class QueryDispatcher {
   private final PhysicalTimeSeriesBrokerPlanVisitor 
_timeSeriesBrokerPlanVisitor
       = new PhysicalTimeSeriesBrokerPlanVisitor();
   private final FailureDetector _failureDetector;
+  private final Duration _cancelTimeout;
 
   public QueryDispatcher(MailboxService mailboxService, FailureDetector 
failureDetector) {
     this(mailboxService, failureDetector, null, false);
@@ -122,6 +132,12 @@ public class QueryDispatcher {
 
   public QueryDispatcher(MailboxService mailboxService, FailureDetector 
failureDetector, @Nullable TlsConfig tlsConfig,
       boolean enableCancellation) {
+    this(mailboxService, failureDetector, tlsConfig, enableCancellation, 
Duration.ofSeconds(1));
+  }
+
+  public QueryDispatcher(MailboxService mailboxService, FailureDetector 
failureDetector, @Nullable TlsConfig tlsConfig,
+      boolean enableCancellation, Duration cancelTimeout) {
+    _cancelTimeout = cancelTimeout;
     _mailboxService = mailboxService;
     _executorService = Executors.newFixedThreadPool(2 * 
Runtime.getRuntime().availableProcessors(),
         new TracedThreadFactory(Thread.NORM_PRIORITY, false, 
PINOT_BROKER_QUERY_DISPATCHER_FORMAT));
@@ -139,6 +155,10 @@ public class QueryDispatcher {
     _mailboxService.start();
   }
 
+  /// Submits a query to the server and waits for the result.
+  ///
+  /// This method may throw almost any exception but QueryException or 
TimeoutException, which are caught and converted
+  /// into a QueryResult with the error code (and stats, if any can be 
collected).
   public QueryResult submitAndReduce(RequestContext context, 
DispatchableSubPlan dispatchableSubPlan, long timeoutMs,
       Map<String, String> queryOptions)
       throws Exception {
@@ -146,20 +166,52 @@ public class QueryDispatcher {
     Set<QueryServerInstance> servers = new HashSet<>();
     try {
       submit(requestId, dispatchableSubPlan, timeoutMs, servers, queryOptions);
-      try {
-        return runReducer(requestId, dispatchableSubPlan, timeoutMs, 
queryOptions, _mailboxService);
-      } finally {
-        if (isQueryCancellationEnabled()) {
-          _serversByQuery.remove(requestId);
-        }
-      }
+      return runReducer(requestId, dispatchableSubPlan, timeoutMs, 
queryOptions, _mailboxService);
+    } catch (Exception ex) {
+      return tryRecover(context.getRequestId(), servers, ex);
     } catch (Throwable e) {
       // TODO: Consider always cancel when it returns (early terminate)
-      cancel(requestId, servers);
+      cancel(requestId);
       throw e;
+    } finally {
+      if (isQueryCancellationEnabled()) {
+        _serversByQuery.remove(requestId);
+      }
     }
   }
 
+  /// Tries to recover from an exception thrown during query dispatching.
+  ///
+  /// [QueryException] and [TimeoutException] are handled by returning a 
[QueryResult] with the error code and stats,
+  /// while other exceptions are not known, so they are directly rethrown.
+  private QueryResult tryRecover(long requestId, Set<QueryServerInstance> 
servers, Exception ex)
+      throws Exception {
+    if (servers.isEmpty()) {
+      throw ex;
+    }
+    if (ex instanceof ExecutionException && ex.getCause() instanceof 
Exception) {
+      ex = (Exception) ex.getCause();
+    }
+    QueryErrorCode errorCode;
+    if (ex instanceof TimeoutException) {
+      errorCode = QueryErrorCode.EXECUTION_TIMEOUT;
+    } else if (ex instanceof QueryException) {
+      errorCode = ((QueryException) ex).getErrorCode();
+    } else {
+      // in case of unknown exceptions, the exception will be rethrown, so we 
don't need stats
+      cancel(requestId, servers);
+      throw ex;
+    }
+    // in case of known exceptions (timeout or query exception), we need can 
build here the erroneous QueryResult
+    // that include the stats.
+    MultiStageQueryStats stats = cancelWithStats(requestId, servers);
+    if (stats == null) {
+      throw ex;
+    }
+    QueryProcessingException processingException = new 
QueryProcessingException(errorCode, ex.getMessage());
+    return new QueryResult(processingException, stats, 0L);
+  }
+
   public List<PlanNode> explain(RequestContext context, 
DispatchablePlanFragment fragment, long timeoutMs,
       Map<String, String> queryOptions)
       throws TimeoutException, InterruptedException, ExecutionException {
@@ -169,7 +221,7 @@ public class QueryDispatcher {
     Set<DispatchablePlanFragment> plans = Collections.singleton(fragment);
     Set<QueryServerInstance> servers = new HashSet<>();
     try {
-      SendRequest<List<Worker.ExplainResponse>> requestSender = 
DispatchClient::explain;
+      SendRequest<Worker.QueryRequest, List<Worker.ExplainResponse>> 
requestSender = DispatchClient::explain;
       execute(requestId, plans, timeoutMs, queryOptions, requestSender, 
servers, (responses, serverInstance) -> {
         for (Worker.ExplainResponse response : responses) {
           if 
(response.containsMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR))
 {
@@ -206,7 +258,7 @@ public class QueryDispatcher {
       long requestId, DispatchableSubPlan dispatchableSubPlan, long timeoutMs, 
Set<QueryServerInstance> serversOut,
       Map<String, String> queryOptions)
       throws Exception {
-    SendRequest<Worker.QueryResponse> requestSender = DispatchClient::submit;
+    SendRequest<Worker.QueryRequest, Worker.QueryResponse> requestSender = 
DispatchClient::submit;
     Set<DispatchablePlanFragment> plansWithoutRoot = 
dispatchableSubPlan.getQueryStagesWithoutRoot();
     execute(requestId, plansWithoutRoot, timeoutMs, queryOptions, 
requestSender, serversOut,
         (response, serverInstance) -> {
@@ -253,7 +305,7 @@ public class QueryDispatcher {
 
   private <E> void execute(long requestId, Set<DispatchablePlanFragment> 
stagePlans,
       long timeoutMs, Map<String, String> queryOptions,
-      SendRequest<E> sendRequest, Set<QueryServerInstance> serverInstancesOut,
+      SendRequest<Worker.QueryRequest, E> sendRequest, 
Set<QueryServerInstance> serverInstancesOut,
       BiConsumer<E, QueryServerInstance> resultConsumer)
       throws ExecutionException, InterruptedException, TimeoutException {
 
@@ -271,28 +323,39 @@ public class QueryDispatcher {
     ByteString protoRequestMetadata = 
QueryPlanSerDeUtils.toProtoProperties(requestMetadata);
 
     // Submit the query plan to all servers in parallel
-    int numServers = serverInstancesOut.size();
-    BlockingQueue<AsyncResponse<E>> dispatchCallbacks = new 
ArrayBlockingQueue<>(numServers);
+    BlockingQueue<AsyncResponse<E>> dispatchCallbacks = dispatch(sendRequest, 
serverInstancesOut, deadline,
+        serverInstance -> createRequest(serverInstance, stageInfos, 
protoRequestMetadata));
+
+    processResults(requestId, serverInstancesOut.size(), resultConsumer, 
deadline, dispatchCallbacks);
+  }
+
+  private <R, E> BlockingQueue<AsyncResponse<E>> dispatch(SendRequest<R, E> 
sendRequest,
+      Set<QueryServerInstance> serverInstancesOut, Deadline deadline, 
Function<QueryServerInstance, R> requestBuilder) {
+    BlockingQueue<AsyncResponse<E>> dispatchCallbacks = new 
ArrayBlockingQueue<>(serverInstancesOut.size());
 
     for (QueryServerInstance serverInstance : serverInstancesOut) {
       Consumer<AsyncResponse<E>> callbackConsumer = response -> {
         if (!dispatchCallbacks.offer(response)) {
-          LOGGER.warn("Failed to offer response to dispatchCallbacks queue for 
query: {} on server: {}", requestId,
-              serverInstance);
+          LOGGER.warn("Failed to offer response to dispatchCallbacks queue for 
query on server: {}", serverInstance);
         }
       };
-      Worker.QueryRequest requestBuilder = createRequest(serverInstance, 
stageInfos, protoRequestMetadata);
+      R request = requestBuilder.apply(serverInstance);
       DispatchClient dispatchClient = 
getOrCreateDispatchClient(serverInstance);
 
       try {
-        sendRequest.send(dispatchClient, requestBuilder, serverInstance, 
deadline, callbackConsumer);
+        sendRequest.send(dispatchClient, request, serverInstance, deadline, 
callbackConsumer);
       } catch (Throwable t) {
-        LOGGER.warn("Caught exception while dispatching query: {} to server: 
{}", requestId, serverInstance, t);
+        LOGGER.warn("Caught exception while dispatching query to server: {}", 
serverInstance, t);
         callbackConsumer.accept(new AsyncResponse<>(serverInstance, null, t));
         _failureDetector.markServerUnhealthy(serverInstance.getInstanceId());
       }
     }
+    return dispatchCallbacks;
+  }
 
+  private <E> void processResults(long requestId, int numServers, 
BiConsumer<E, QueryServerInstance> resultConsumer,
+      Deadline deadline, BlockingQueue<AsyncResponse<E>> dispatchCallbacks)
+      throws InterruptedException, TimeoutException {
     int numSuccessCalls = 0;
     // TODO: Cancel all dispatched requests if one of the dispatch errors out 
or deadline is breached.
     while (!deadline.isExpired() && numSuccessCalls < numServers) {
@@ -315,6 +378,8 @@ public class QueryDispatcher {
           resultConsumer.accept(response, resp.getServerInstance());
           numSuccessCalls++;
         }
+      } else {
+        LOGGER.info("No response from server for query");
       }
     }
     if (deadline.isExpired()) {
@@ -389,7 +454,7 @@ public class QueryDispatcher {
   private Map<DispatchablePlanFragment, StageInfo> serializePlanFragments(
       Set<DispatchablePlanFragment> stagePlans,
       Set<QueryServerInstance> serverInstances, Deadline deadline)
-      throws InterruptedException, ExecutionException, TimeoutException {
+      throws InterruptedException, ExecutionException {
     List<CompletableFuture<Pair<DispatchablePlanFragment, StageInfo>>> 
stageInfoFutures =
         new ArrayList<>(stagePlans.size());
     for (DispatchablePlanFragment stagePlan : stagePlans) {
@@ -437,13 +502,14 @@ public class QueryDispatcher {
     }
   }
 
+  /// Cancels a request on the given servers without waiting for the stats in their responses.
   private boolean cancel(long requestId, @Nullable Set<QueryServerInstance> 
servers) {
     if (servers == null) {
       return false;
     }
     for (QueryServerInstance queryServerInstance : servers) {
       try {
-        getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
+        getOrCreateDispatchClient(queryServerInstance).cancelAsync(requestId);
       } catch (Throwable t) {
         LOGGER.warn("Caught exception while cancelling query: {} on server: 
{}", requestId, queryServerInstance, t);
       }
@@ -454,6 +520,43 @@ public class QueryDispatcher {
     return true;
   }
 
+
+  /// Cancels the given request on every server in {@code servers} and collects the stats those servers send
+  /// back in their cancel responses.
+  ///
+  /// Waits at most {@code _cancelTimeout} for the responses. Stats from servers that answer in time are merged
+  /// as upstream stages into the returned object. Servers that do not answer before the deadline, and stage stats
+  /// that fail to deserialize, are skipped, so the returned stats may be partial.
+  ///
+  /// @param requestId the id of the request to cancel
+  /// @param servers the servers the request was dispatched to; may be {@code null}
+  /// @return the collected stats rooted at stage 0, or {@code null} when {@code servers} is {@code null}
+  @Nullable
+  private MultiStageQueryStats cancelWithStats(long requestId, @Nullable Set<QueryServerInstance> servers) {
+    if (servers == null) {
+      return null;
+    }
+
+    Deadline deadline = Deadline.after(_cancelTimeout.toMillis(), TimeUnit.MILLISECONDS);
+    SendRequest<Long, Worker.CancelResponse> sendRequest = DispatchClient::cancel;
+    BlockingQueue<AsyncResponse<Worker.CancelResponse>> dispatchCallbacks = dispatch(sendRequest, servers, deadline,
+        serverInstance -> requestId);
+
+    // Seed stage 0 with an empty MAILBOX_RECEIVE root operator; server stats are merged in as upstream stages.
+    MultiStageQueryStats stats = MultiStageQueryStats.emptyStats(0);
+    StatMap<BaseMailboxReceiveOperator.StatKey> rootStats = new StatMap<>(BaseMailboxReceiveOperator.StatKey.class);
+    stats.getCurrentStats().addLastOperator(MultiStageOperator.Type.MAILBOX_RECEIVE, rootStats);
+    try {
+      processResults(requestId, servers.size(), (response, server) -> {
+        Map<Integer, ByteString> statsByStage = response.getStatsByStageMap();
+        for (Map.Entry<Integer, ByteString> entry : statsByStage.entrySet()) {
+          try (InputStream is = entry.getValue().newInput();
+              DataInputStream dis = new DataInputStream(is)) {
+            MultiStageQueryStats.StageStats.Closed closed = MultiStageQueryStats.StageStats.Closed.deserialize(dis);
+            stats.mergeUpstream(entry.getKey(), closed);
+          } catch (Exception e) {
+            // Best effort: one undecodable stage must not prevent collecting the remaining stages.
+            LOGGER.debug("Caught exception while deserializing stats on server: {}", server, e);
+          }
+        }
+      }, deadline, dispatchCallbacks);
+      return stats;
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so code further up the stack can still observe the interruption.
+      Thread.currentThread().interrupt();
+      throw QueryErrorCode.INTERNAL.asException("Interrupted while waiting for cancel response", e);
+    } catch (TimeoutException e) {
+      // Some servers did not answer in time; return the stats merged so far.
+      LOGGER.debug("Timed out waiting for cancel response", e);
+      return stats;
+    }
+  }
+
   private DispatchClient getOrCreateDispatchClient(QueryServerInstance 
queryServerInstance) {
     String hostname = queryServerInstance.getHostname();
     int port = queryServerInstance.getQueryServicePort();
@@ -635,10 +738,16 @@ public class QueryDispatcher {
   }
 
   public static class QueryResult {
+    @Nullable
     private final ResultTable _resultTable;
+    @Nullable
+    private final QueryProcessingException _processingException;
     private final List<MultiStageQueryStats.StageStats.Closed> _queryStats;
     private final long _brokerReduceTimeMs;
 
+    /**
+     * Creates a successful query result.
+     */
     public QueryResult(ResultTable resultTable, MultiStageQueryStats 
queryStats, long brokerReduceTimeMs) {
       _resultTable = resultTable;
       Preconditions.checkArgument(queryStats.getCurrentStageId() == 0, 
"Expecting query stats for stage 0, got: %s",
@@ -650,12 +759,39 @@ public class QueryDispatcher {
         _queryStats.add(queryStats.getUpstreamStageStats(i));
       }
       _brokerReduceTimeMs = brokerReduceTimeMs;
+      _processingException = null;
+    }
+
+    /**
+     * Creates a failed query result.
+     *
+     * @param processingException the exception that occurred during query processing; must not be null
+     * @param queryStats the query stats collected so far, which may be empty; must be rooted at stage 0
+     * @param brokerReduceTimeMs the broker reduce time, in milliseconds
+     */
+    public QueryResult(QueryProcessingException processingException, MultiStageQueryStats queryStats,
+        long brokerReduceTimeMs) {
+      // Reject null so a failed result always carries its cause; otherwise both _resultTable and
+      // _processingException would be null.
+      _processingException = Preconditions.checkNotNull(processingException, "processingException must not be null");
+      _resultTable = null;
+      _brokerReduceTimeMs = brokerReduceTimeMs;
+      Preconditions.checkArgument(queryStats.getCurrentStageId() == 0, "Expecting query stats for stage 0, got: %s",
+          queryStats.getCurrentStageId());
+      int numStages = queryStats.getMaxStageId() + 1;
+      _queryStats = new ArrayList<>(numStages);
+      // The current (stage 0) stats must be closed here; upstream stage stats are already in closed form.
+      _queryStats.add(queryStats.getCurrentStats().close());
+      for (int i = 1; i < numStages; i++) {
+        _queryStats.add(queryStats.getUpstreamStageStats(i));
+      }
     }
 
+    /// Returns the result table, or {@code null} when this result was created as a failure.
+    @Nullable
     public ResultTable getResultTable() {
       return _resultTable;
     }
 
+    /// Returns the exception that made the query fail, or {@code null} when this result was created as a success.
+    @Nullable
+    public QueryProcessingException getProcessingException() {
+      return _processingException;
+    }
+
+    /// Returns the closed per-stage query stats, starting with stage 0.
     public List<MultiStageQueryStats.StageStats.Closed> getQueryStats() {
       return _queryStats;
     }
@@ -665,8 +801,8 @@ public class QueryDispatcher {
     }
   }
 
-  private interface SendRequest<E> {
-    void send(DispatchClient dispatchClient, Worker.QueryRequest request, 
QueryServerInstance serverInstance,
+  /// Sends one request of type {@code R} to a single server.
+  ///
+  /// {@code E} is the response type carried by the {@link AsyncResponse}s handed to {@code callbackConsumer}.
+  /// For example, cancellation uses {@code DispatchClient::cancel} as a
+  /// {@code SendRequest<Long, Worker.CancelResponse>}, where the request is the request id.
+  private interface SendRequest<R, E> {
+    void send(DispatchClient dispatchClient, R request, QueryServerInstance 
serverInstance,
         Deadline deadline, Consumer<AsyncResponse<E>> callbackConsumer);
   }
 }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
index 4242941359..6ac7ca5598 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
@@ -20,10 +20,12 @@ package org.apache.pinot.query.service.server;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.UnsafeByteOperations;
 import io.grpc.Server;
 import io.grpc.ServerBuilder;
 import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
 import io.grpc.stub.StreamObserver;
+import java.io.DataOutputStream;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -36,6 +38,7 @@ import java.util.concurrent.TimeoutException;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import javax.annotation.Nullable;
+import org.apache.commons.io.output.UnsynchronizedByteArrayOutputStream;
 import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.metrics.ServerMeter;
 import org.apache.pinot.common.metrics.ServerMetrics;
@@ -50,6 +53,7 @@ import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.StagePlan;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.QueryRunner;
+import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
 import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.env.PinotConfiguration;
@@ -264,7 +268,21 @@ public class QueryServer extends 
PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
     long requestId = request.getRequestId();
     try (QueryThreadContext.CloseableContext closeable = 
QueryThreadContext.open()) {
       QueryThreadContext.setIds(requestId, request.getCid().isBlank() ? 
request.getCid() : Long.toString(requestId));
-      _queryRunner.cancel(requestId);
+      Map<Integer, MultiStageQueryStats.StageStats.Closed> stats = 
_queryRunner.cancel(requestId);
+
+      Worker.CancelResponse.Builder cancelBuilder = 
Worker.CancelResponse.newBuilder();
+      for (Map.Entry<Integer, MultiStageQueryStats.StageStats.Closed> 
statEntry : stats.entrySet()) {
+        try (UnsynchronizedByteArrayOutputStream baos = new 
UnsynchronizedByteArrayOutputStream.Builder().get();
+            DataOutputStream daos = new DataOutputStream(baos)) {
+          statEntry.getValue().serialize(daos);
+
+          daos.flush();
+          byte[] byteArray = baos.toByteArray();
+          ByteString bytes = UnsafeByteOperations.unsafeWrap(byteArray);
+          cancelBuilder.putStatsByStage(statEntry.getKey(), bytes);
+        }
+      }
+      responseObserver.onNext(cancelBuilder.build());
     } catch (Throwable t) {
       LOGGER.error("Caught exception while cancelling opChain for request: 
{}", requestId, t);
     }
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
index 9ce3ffaae1..b38ef22893 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
@@ -32,6 +32,7 @@ import org.apache.pinot.query.runtime.blocks.ErrorMseBlock;
 import org.apache.pinot.query.runtime.blocks.SuccessMseBlock;
 import org.apache.pinot.query.runtime.operator.MultiStageOperator;
 import org.apache.pinot.query.runtime.operator.OpChain;
+import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.spi.executor.ExecutorServiceUtils;
 import org.mockito.Mockito;
@@ -174,6 +175,7 @@ public class OpChainSchedulerServiceTest {
       cancelLatch.countDown();
       return null;
     }).when(_operatorA).cancel(Mockito.any());
+    Mockito.doAnswer(inv -> 
MultiStageQueryStats.emptyStats(1)).when(_operatorA).calculateStats();
 
     schedulerService.register(opChain);
 
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java 
b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index b7a179dcf0..11e8f0d71e 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -1552,7 +1552,6 @@ public class CommonConstants {
     /// running 1.3.0 may fail, which breaks backward compatibility.
     public static final String KEY_OF_SEND_STATS_MODE = 
"pinot.query.mse.stats.mode";
     public static final String DEFAULT_SEND_STATS_MODE = "SAFE";
-
     public enum JoinOverFlowMode {
       THROW, BREAK
     }
@@ -1577,6 +1576,18 @@ public class CommonConstants {
     public static final String KEY_OF_MULTISTAGE_EXPLAIN_INCLUDE_SEGMENT_PLAN =
         "pinot.query.multistage.explain.include.segment.plan";
     public static final boolean 
DEFAULT_OF_MULTISTAGE_EXPLAIN_INCLUDE_SEGMENT_PLAN = false;
+
+    /// Max total number of operators stored in the op stats cache.
+    /// Although the cache stores stages, each entry has a weight equal to the number of operators in the stage.
+    public static final String KEY_OF_OP_STATS_CACHE_SIZE = 
"pinot.server.query.op.stats.cache.size";
+    public static final int DEFAULT_OF_OP_STATS_CACHE_SIZE = 1000;
+
+    /// Max time, in milliseconds, to keep the op stats in the cache (default: 60 seconds).
+    /// NOTE(review): the key ends in ".cache.ms" rather than ".cache.expire.ms"; confirm this naming is intended.
+    public static final String KEY_OF_OP_STATS_CACHE_EXPIRE_MS = 
"pinot.server.query.op.stats.cache.ms";
+    public static final int DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS = 60 * 1000;
+
+    /// Timeout of the cancel request, in milliseconds.
+    public static final String KEY_OF_CANCEL_TIMEOUT_MS = 
"pinot.server.query.cancel.timeout.ms";
+    public static final long DEFAULT_OF_CANCEL_TIMEOUT_MS = 1000;
   }
 
   public static class NullValuePlaceHolder {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to