This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 1dc13206c52 Enhance LeafOperator (#16582)
1dc13206c52 is described below

commit 1dc13206c524341fc9b6eda4ef79b3b2e17eed4f
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Wed Aug 13 13:32:28 2025 -0600

    Enhance LeafOperator (#16582)
---
 .../org/apache/pinot/common/datatable/StatMap.java |   1 +
 .../apache/pinot/query/runtime/QueryRunner.java    |   4 -
 .../pinot/query/runtime/operator/LeafOperator.java | 588 +++++++++++----------
 .../pinot/query/service/server/QueryServer.java    |   5 -
 .../query/runtime/operator/LeafOperatorTest.java   | 103 +++-
 .../query/runtime/operator/OperatorTestUtil.java   |  35 +-
 6 files changed, 425 insertions(+), 311 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java b/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java
index fcfd6dcaaa9..55a8d9bdc23 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java
@@ -58,6 +58,7 @@ public class StatMap<K extends Enum<K> & StatMap.Key> {
   public StatMap(Class<K> keyClass) {
     _keyClass = keyClass;
     // TODO: Study whether this is fine or we should impose a single thread policy in StatMaps
+    // TODO: We might need to synchronize the methods because some methods access the map multiple times
     _map = Collections.synchronizedMap(new EnumMap<>(keyClass));
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
index d1417cb9964..4d21dd8ae41 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
@@ -88,7 +88,6 @@ import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
 import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.JoinOverFlowMode;
 import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.WindowOverFlowMode;
-import org.apache.pinot.spi.utils.CommonConstants.Query.Request.MetadataKeys;
 import org.apache.pinot.spi.utils.CommonConstants.Server;
 import org.apache.pinot.sql.parsers.rewriter.RlsUtils;
 import org.apache.pinot.tsdb.planner.TimeSeriesPlanConstants.WorkerRequestMetadataKeys;
@@ -532,9 +531,6 @@ public class QueryRunner {
       LOGGER.debug("Explain query on intermediate stages is a NOOP");
       return stagePlan;
     }
-    long requestId = Long.parseLong(requestMetadata.get(MetadataKeys.REQUEST_ID));
-    long timeoutMs = Long.parseLong(requestMetadata.get(QueryOptionKey.TIMEOUT_MS));
-    long deadlineMs = System.currentTimeMillis() + timeoutMs;
 
     StageMetadata stageMetadata = stagePlan.getStageMetadata();
     Map<String, String> opChainMetadata = consolidateMetadata(stageMetadata.getCustomProperties(), requestMetadata);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafOperator.java
index fab26d033b0..157bcf689e6 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafOperator.java
@@ -23,7 +23,6 @@ import com.google.common.base.Preconditions;
 import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ArrayBlockingQueue;
@@ -33,6 +32,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.datatable.StatMap;
@@ -52,7 +52,6 @@ import org.apache.pinot.core.query.executor.QueryExecutor;
 import org.apache.pinot.core.query.executor.ResultsBlockStreamer;
 import org.apache.pinot.core.query.logger.ServerQueryLogger;
 import org.apache.pinot.core.query.request.ServerQueryRequest;
-import org.apache.pinot.core.query.request.context.ExplainMode;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.query.planner.plannode.ExplainedNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
@@ -63,6 +62,7 @@ import org.apache.pinot.query.runtime.blocks.SuccessMseBlock;
 import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.exception.EarlyTerminationException;
 import org.apache.pinot.spi.exception.QueryErrorCode;
 import org.apache.pinot.spi.trace.Tracing;
 import org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionValue;
@@ -77,22 +77,30 @@ import org.slf4j.LoggerFactory;
 public class LeafOperator extends MultiStageOperator {
   private static final Logger LOGGER = LoggerFactory.getLogger(LeafOperator.class);
   private static final String EXPLAIN_NAME = "LEAF";
+  private static final ErrorMseBlock CANCELLED_BLOCK =
+      ErrorMseBlock.fromError(QueryErrorCode.QUERY_CANCELLATION, "Cancelled while waiting for leaf results");
+  private static final ErrorMseBlock TIMEOUT_BLOCK =
+      ErrorMseBlock.fromError(QueryErrorCode.EXECUTION_TIMEOUT, "Timed out waiting for leaf results");
 
   // Use a special results block to indicate that this is the last results block
-  private static final MetadataResultsBlock LAST_RESULTS_BLOCK = new MetadataResultsBlock();
+  @VisibleForTesting
+  static final MetadataResultsBlock LAST_RESULTS_BLOCK = new MetadataResultsBlock();
 
   private final List<ServerQueryRequest> _requests;
   private final DataSchema _dataSchema;
   private final QueryExecutor _queryExecutor;
   private final ExecutorService _executorService;
+  private final String _tableName;
+  private final StatMap<StatKey> _statMap = new StatMap<>(StatKey.class);
+  private final AtomicReference<ErrorMseBlock> _errorBlock = new AtomicReference<>();
 
   // Use a limit-sized BlockingQueue to store the results blocks and apply back pressure to the single-stage threads
-  private final BlockingQueue<BaseResultsBlock> _blockingQueue;
+  @VisibleForTesting
+  final BlockingQueue<BaseResultsBlock> _blockingQueue;
 
   @Nullable
-  private volatile Future<Void> _executionFuture;
-  private volatile Map<Integer, String> _exceptions;
-  private final StatMap<StatKey> _statMap = new StatMap<>(StatKey.class);
+  private volatile Future<?> _executionFuture;
+  private volatile boolean _terminated;
 
   public LeafOperator(OpChainExecutionContext context, List<ServerQueryRequest> requests, DataSchema dataSchema,
       QueryExecutor queryExecutor, ExecutorService executorService) {
@@ -103,11 +111,12 @@ public class LeafOperator extends MultiStageOperator {
     _dataSchema = dataSchema;
     _queryExecutor = queryExecutor;
     _executorService = executorService;
+    _tableName = context.getStageMetadata().getTableName();
+    Preconditions.checkArgument(_tableName != null, "Table name must be set in the stage metadata");
+    _statMap.merge(StatKey.TABLE, _tableName);
     Integer maxStreamingPendingBlocks = QueryOptionsUtils.getMaxStreamingPendingBlocks(context.getOpChainMetadata());
     _blockingQueue = new ArrayBlockingQueue<>(maxStreamingPendingBlocks != null ? maxStreamingPendingBlocks
         : QueryOptionValue.DEFAULT_MAX_STREAMING_PENDING_BLOCKS);
-    String tableName = context.getLeafStageContext().getStagePlan().getStageMetadata().getTableName();
-    _statMap.merge(StatKey.TABLE, tableName);
   }
 
   public List<ServerQueryRequest> getRequests() {
@@ -136,7 +145,7 @@ public class LeafOperator extends MultiStageOperator {
 
   @Override
   public List<MultiStageOperator> getChildOperators() {
-    return Collections.emptyList();
+    return List.of();
   }
 
   @Override
@@ -145,27 +154,35 @@ public class LeafOperator extends MultiStageOperator {
   }
 
   @Override
-  protected MseBlock getNextBlock()
-      throws InterruptedException, TimeoutException {
+  protected MseBlock getNextBlock() {
     if (_executionFuture == null) {
       _executionFuture = startExecution();
     }
     if (_isEarlyTerminated) {
+      terminateAndClearResultsBlocks();
       return SuccessMseBlock.INSTANCE;
     }
-    // Here we use passive deadline because we end up waiting for the SSE operators
-    // which can timeout by their own
-    BaseResultsBlock resultsBlock =
-        _blockingQueue.poll(_context.getPassiveDeadlineMs() - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+    BaseResultsBlock resultsBlock;
+    try {
+      // Here we use passive deadline because we end up waiting for the SSE operators which can timeout by their own.
+      resultsBlock =
+          _blockingQueue.poll(_context.getPassiveDeadlineMs() - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+    } catch (InterruptedException e) {
+      terminateAndClearResultsBlocks();
+      return CANCELLED_BLOCK;
+    }
     if (resultsBlock == null) {
-      throw new TimeoutException("Timed out waiting for results block");
+      terminateAndClearResultsBlocks();
+      return TIMEOUT_BLOCK;
     }
-    // Terminate when receiving exception block
-    Map<Integer, String> exceptions = _exceptions;
-    if (exceptions != null) {
-      return ErrorMseBlock.fromMap(QueryErrorCode.fromKeyMap(exceptions));
+    // Terminate when there is error block
+    ErrorMseBlock errorBlock = getErrorBlock();
+    if (errorBlock != null) {
+      terminateAndClearResultsBlocks();
+      return errorBlock;
     }
     if (resultsBlock == LAST_RESULTS_BLOCK) {
+      _terminated = true;
       return SuccessMseBlock.INSTANCE;
     } else {
       // Regular data block
@@ -173,197 +190,195 @@ public class LeafOperator extends MultiStageOperator {
     }
   }
 
-  @Override
-  protected StatMap<?> copyStatMaps() {
-    return new StatMap<>(_statMap);
-  }
-
-  @Override
-  protected void earlyTerminate() {
-    super.earlyTerminate();
-    cancelSseTasks();
-  }
-
-  @Override
-  public void cancel(Throwable e) {
-    super.cancel(e);
-    cancelSseTasks();
-  }
-
-  @VisibleForTesting
-  protected void cancelSseTasks() {
-    Future<Void> executionFuture = _executionFuture;
-    if (executionFuture != null) {
-      executionFuture.cancel(true);
-    }
-  }
-
-  private void mergeExecutionStats(@Nullable Map<String, String> executionStats) {
-    if (executionStats != null) {
-      for (Map.Entry<String, String> entry : executionStats.entrySet()) {
-        DataTable.MetadataKey key = DataTable.MetadataKey.getByName(entry.getKey());
-        if (key == null) {
-          LOGGER.debug("Skipping unknown execution stat: {}", entry.getKey());
-          continue;
-        }
-        switch (key) {
-          case UNKNOWN:
-            LOGGER.debug("Skipping unknown execution stat: {}", entry.getKey());
-            break;
-          case TABLE:
-            _statMap.merge(StatKey.TABLE, entry.getValue());
-            break;
-          case NUM_DOCS_SCANNED:
-            _statMap.merge(StatKey.NUM_DOCS_SCANNED, Long.parseLong(entry.getValue()));
-            break;
-          case NUM_ENTRIES_SCANNED_IN_FILTER:
-            _statMap.merge(StatKey.NUM_ENTRIES_SCANNED_IN_FILTER, Long.parseLong(entry.getValue()));
-            break;
-          case NUM_ENTRIES_SCANNED_POST_FILTER:
-            _statMap.merge(StatKey.NUM_ENTRIES_SCANNED_POST_FILTER, Long.parseLong(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_QUERIED:
-            _statMap.merge(StatKey.NUM_SEGMENTS_QUERIED, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_PROCESSED:
-            _statMap.merge(StatKey.NUM_SEGMENTS_PROCESSED, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_MATCHED:
-            _statMap.merge(StatKey.NUM_SEGMENTS_MATCHED, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_CONSUMING_SEGMENTS_QUERIED:
-            _statMap.merge(StatKey.NUM_CONSUMING_SEGMENTS_QUERIED, Integer.parseInt(entry.getValue()));
-            break;
-          case MIN_CONSUMING_FRESHNESS_TIME_MS:
-            _statMap.merge(StatKey.MIN_CONSUMING_FRESHNESS_TIME_MS, Long.parseLong(entry.getValue()));
-            break;
-          case TOTAL_DOCS:
-            _statMap.merge(StatKey.TOTAL_DOCS, Long.parseLong(entry.getValue()));
-            break;
-          case GROUPS_TRIMMED:
-            _statMap.merge(StatKey.GROUPS_TRIMMED, Boolean.parseBoolean(entry.getValue()));
-            break;
-          case NUM_GROUPS_LIMIT_REACHED:
-            _statMap.merge(StatKey.NUM_GROUPS_LIMIT_REACHED, Boolean.parseBoolean(entry.getValue()));
-            break;
-          case NUM_GROUPS_WARNING_LIMIT_REACHED:
-            _statMap.merge(StatKey.NUM_GROUPS_WARNING_LIMIT_REACHED, Boolean.parseBoolean(entry.getValue()));
-            break;
-          case TIME_USED_MS:
-            _statMap.merge(StatKey.EXECUTION_TIME_MS, Long.parseLong(entry.getValue()));
-            break;
-          case TRACE_INFO:
-            LOGGER.debug("Skipping trace info: {}", entry.getValue());
-            break;
-          case REQUEST_ID:
-            LOGGER.debug("Skipping request ID: {}", entry.getValue());
-            break;
-          case NUM_RESIZES:
-            _statMap.merge(StatKey.NUM_RESIZES, Integer.parseInt(entry.getValue()));
-            break;
-          case RESIZE_TIME_MS:
-            _statMap.merge(StatKey.RESIZE_TIME_MS, Long.parseLong(entry.getValue()));
-            break;
-          case THREAD_CPU_TIME_NS:
-            _statMap.merge(StatKey.THREAD_CPU_TIME_NS, Long.parseLong(entry.getValue()));
-            break;
-          case SYSTEM_ACTIVITIES_CPU_TIME_NS:
-            _statMap.merge(StatKey.SYSTEM_ACTIVITIES_CPU_TIME_NS, Long.parseLong(entry.getValue()));
-            break;
-          case RESPONSE_SER_CPU_TIME_NS:
-            _statMap.merge(StatKey.RESPONSE_SER_CPU_TIME_NS, Long.parseLong(entry.getValue()));
-            break;
-          case THREAD_MEM_ALLOCATED_BYTES:
-            _statMap.merge(StatKey.THREAD_MEM_ALLOCATED_BYTES, Long.parseLong(entry.getValue()));
-            break;
-          case RESPONSE_SER_MEM_ALLOCATED_BYTES:
-            _statMap.merge(StatKey.RESPONSE_SER_MEM_ALLOCATED_BYTES, Long.parseLong(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_PRUNED_BY_SERVER:
-            _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_BY_SERVER, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_PRUNED_INVALID:
-            _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_INVALID, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_PRUNED_BY_LIMIT:
-            _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_BY_LIMIT, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_SEGMENTS_PRUNED_BY_VALUE:
-            _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_BY_VALUE, Integer.parseInt(entry.getValue()));
-            break;
-          case EXPLAIN_PLAN_NUM_EMPTY_FILTER_SEGMENTS:
-            LOGGER.debug("Skipping empty filter segments: {}", entry.getValue());
-            break;
-          case EXPLAIN_PLAN_NUM_MATCH_ALL_FILTER_SEGMENTS:
-            LOGGER.debug("Skipping match all filter segments: {}", entry.getValue());
-            break;
-          case NUM_CONSUMING_SEGMENTS_PROCESSED:
-            _statMap.merge(StatKey.NUM_CONSUMING_SEGMENTS_PROCESSED, Integer.parseInt(entry.getValue()));
-            break;
-          case NUM_CONSUMING_SEGMENTS_MATCHED:
-            _statMap.merge(StatKey.NUM_CONSUMING_SEGMENTS_MATCHED, Integer.parseInt(entry.getValue()));
-            break;
-          case SORTED:
-            break;
-          default: {
-            throw new IllegalArgumentException("Unhandled V1 execution stat: " + entry.getKey());
-          }
-        }
-      }
-    }
-  }
-
   public ExplainedNode explain() {
-    Preconditions.checkState(
-        _requests.stream().allMatch(request -> request.getQueryContext().getExplain() == ExplainMode.NODE),
-        "All requests must have explain mode set to ExplainMode.NODE");
-
     if (_executionFuture == null) {
       _executionFuture = startExecution();
     }
-
     List<PlanNode> childNodes = new ArrayList<>();
     while (true) {
+      if (_isEarlyTerminated) {
+        terminateAndClearResultsBlocks();
+        break;
+      }
       BaseResultsBlock resultsBlock;
       try {
-        // Here we could use active or passive, given we don't actually execute anything
-        long timeout = _context.getPassiveDeadlineMs() - System.currentTimeMillis();
-        resultsBlock = _blockingQueue.poll(timeout, TimeUnit.MILLISECONDS);
+        // Here we use passive deadline because we end up waiting for the SSE operators which can timeout by their own.
+        resultsBlock =
+            _blockingQueue.poll(_context.getPassiveDeadlineMs() - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
       } catch (InterruptedException e) {
+        terminateAndClearResultsBlocks();
         Thread.currentThread().interrupt();
         throw new RuntimeException("Interrupted while waiting for results 
block", e);
       }
       if (resultsBlock == null) {
+        terminateAndClearResultsBlocks();
         throw new RuntimeException("Timed out waiting for results block");
       }
-      // Terminate when receiving exception block
-      Map<Integer, String> exceptions = _exceptions;
-      if (exceptions != null) {
-        throw new RuntimeException("Received exception block: " + exceptions);
+      // Terminate when there is error block
+      ErrorMseBlock errorBlock = getErrorBlock();
+      if (errorBlock != null) {
+        terminateAndClearResultsBlocks();
+        throw new RuntimeException("Received error block: " + 
errorBlock.getErrorMessages());
       }
-      if (_isEarlyTerminated || resultsBlock == LAST_RESULTS_BLOCK) {
+      if (resultsBlock == LAST_RESULTS_BLOCK) {
+        _terminated = true;
         break;
-      } else if (!(resultsBlock instanceof ExplainV2ResultBlock)) {
-        throw new IllegalArgumentException("Expected ExplainV2ResultBlock, 
got: " + resultsBlock.getClass().getName());
-      } else {
+      }
+      if (resultsBlock instanceof ExplainV2ResultBlock) {
         ExplainV2ResultBlock block = (ExplainV2ResultBlock) resultsBlock;
         for (ExplainInfo physicalPlan : block.getPhysicalPlans()) {
           childNodes.add(asNode(physicalPlan));
         }
+      } else {
+        terminateAndClearResultsBlocks();
+        throw new IllegalArgumentException("Expected ExplainV2ResultBlock, 
got: " + resultsBlock.getClass().getName());
       }
     }
-    String tableName = _context.getStageMetadata().getTableName();
-    Map<String, Plan.ExplainNode.AttributeValue> attributes;
-    if (tableName == null) { // this should never happen, but let's be paranoid to never fail
-      attributes = Collections.emptyMap();
-    } else {
-      attributes =
-          Collections.singletonMap("table", Plan.ExplainNode.AttributeValue.newBuilder().setString(tableName).build());
-    }
+    Map<String, Plan.ExplainNode.AttributeValue> attributes =
+        Map.of("table", Plan.ExplainNode.AttributeValue.newBuilder().setString(_tableName).build());
     return new ExplainedNode(_context.getStageId(), _dataSchema, null, childNodes, "LeafStageCombineOperator",
         attributes);
   }
 
+  @Override
+  protected StatMap<?> copyStatMaps() {
+    return new StatMap<>(_statMap);
+  }
+
+  @Override
+  protected void earlyTerminate() {
+    _isEarlyTerminated = true;
+    cancelSseTasks();
+  }
+
+  @Override
+  public void cancel(Throwable e) {
+    cancelSseTasks();
+  }
+
+  @Override
+  public void close() {
+    cancelSseTasks();
+  }
+
+  @VisibleForTesting
+  void cancelSseTasks() {
+    Future<?> executionFuture = _executionFuture;
+    if (executionFuture != null) {
+      executionFuture.cancel(true);
+    }
+  }
+
+  private synchronized void mergeExecutionStats(Map<String, String> executionStats) {
+    for (Map.Entry<String, String> entry : executionStats.entrySet()) {
+      String key = entry.getKey();
+      DataTable.MetadataKey metadataKey = DataTable.MetadataKey.getByName(key);
+      if (metadataKey == null || metadataKey == DataTable.MetadataKey.UNKNOWN) {
+        LOGGER.debug("Skipping unknown execution stat: {}", key);
+        continue;
+      }
+      switch (metadataKey) {
+        case TABLE:
+          _statMap.merge(StatKey.TABLE, entry.getValue());
+          break;
+        case NUM_DOCS_SCANNED:
+          _statMap.merge(StatKey.NUM_DOCS_SCANNED, Long.parseLong(entry.getValue()));
+          break;
+        case NUM_ENTRIES_SCANNED_IN_FILTER:
+          _statMap.merge(StatKey.NUM_ENTRIES_SCANNED_IN_FILTER, Long.parseLong(entry.getValue()));
+          break;
+        case NUM_ENTRIES_SCANNED_POST_FILTER:
+          _statMap.merge(StatKey.NUM_ENTRIES_SCANNED_POST_FILTER, Long.parseLong(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_QUERIED:
+          _statMap.merge(StatKey.NUM_SEGMENTS_QUERIED, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_PROCESSED:
+          _statMap.merge(StatKey.NUM_SEGMENTS_PROCESSED, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_MATCHED:
+          _statMap.merge(StatKey.NUM_SEGMENTS_MATCHED, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_CONSUMING_SEGMENTS_QUERIED:
+          _statMap.merge(StatKey.NUM_CONSUMING_SEGMENTS_QUERIED, Integer.parseInt(entry.getValue()));
+          break;
+        case MIN_CONSUMING_FRESHNESS_TIME_MS:
+          _statMap.merge(StatKey.MIN_CONSUMING_FRESHNESS_TIME_MS, Long.parseLong(entry.getValue()));
+          break;
+        case TOTAL_DOCS:
+          _statMap.merge(StatKey.TOTAL_DOCS, Long.parseLong(entry.getValue()));
+          break;
+        case GROUPS_TRIMMED:
+          _statMap.merge(StatKey.GROUPS_TRIMMED, Boolean.parseBoolean(entry.getValue()));
+          break;
+        case NUM_GROUPS_LIMIT_REACHED:
+          _statMap.merge(StatKey.NUM_GROUPS_LIMIT_REACHED, Boolean.parseBoolean(entry.getValue()));
+          break;
+        case NUM_GROUPS_WARNING_LIMIT_REACHED:
+          _statMap.merge(StatKey.NUM_GROUPS_WARNING_LIMIT_REACHED, Boolean.parseBoolean(entry.getValue()));
+          break;
+        case TIME_USED_MS:
+          _statMap.merge(StatKey.EXECUTION_TIME_MS, Long.parseLong(entry.getValue()));
+          break;
+        case TRACE_INFO:
+          LOGGER.debug("Skipping trace info: {}", entry.getValue());
+          break;
+        case REQUEST_ID:
+          LOGGER.debug("Skipping request ID: {}", entry.getValue());
+          break;
+        case NUM_RESIZES:
+          _statMap.merge(StatKey.NUM_RESIZES, Integer.parseInt(entry.getValue()));
+          break;
+        case RESIZE_TIME_MS:
+          _statMap.merge(StatKey.RESIZE_TIME_MS, Long.parseLong(entry.getValue()));
+          break;
+        case THREAD_CPU_TIME_NS:
+          _statMap.merge(StatKey.THREAD_CPU_TIME_NS, Long.parseLong(entry.getValue()));
+          break;
+        case SYSTEM_ACTIVITIES_CPU_TIME_NS:
+          _statMap.merge(StatKey.SYSTEM_ACTIVITIES_CPU_TIME_NS, Long.parseLong(entry.getValue()));
+          break;
+        case RESPONSE_SER_CPU_TIME_NS:
+          _statMap.merge(StatKey.RESPONSE_SER_CPU_TIME_NS, Long.parseLong(entry.getValue()));
+          break;
+        case THREAD_MEM_ALLOCATED_BYTES:
+          _statMap.merge(StatKey.THREAD_MEM_ALLOCATED_BYTES, Long.parseLong(entry.getValue()));
+          break;
+        case RESPONSE_SER_MEM_ALLOCATED_BYTES:
+          _statMap.merge(StatKey.RESPONSE_SER_MEM_ALLOCATED_BYTES, Long.parseLong(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_PRUNED_BY_SERVER:
+          _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_BY_SERVER, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_PRUNED_INVALID:
+          _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_INVALID, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_PRUNED_BY_LIMIT:
+          _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_BY_LIMIT, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_SEGMENTS_PRUNED_BY_VALUE:
+          _statMap.merge(StatKey.NUM_SEGMENTS_PRUNED_BY_VALUE, Integer.parseInt(entry.getValue()));
+          break;
+        case EXPLAIN_PLAN_NUM_EMPTY_FILTER_SEGMENTS:
+          LOGGER.debug("Skipping empty filter segments: {}", entry.getValue());
+          break;
+        case EXPLAIN_PLAN_NUM_MATCH_ALL_FILTER_SEGMENTS:
+          LOGGER.debug("Skipping match all filter segments: {}", entry.getValue());
+          break;
+        case NUM_CONSUMING_SEGMENTS_PROCESSED:
+          _statMap.merge(StatKey.NUM_CONSUMING_SEGMENTS_PROCESSED, Integer.parseInt(entry.getValue()));
+          break;
+        case NUM_CONSUMING_SEGMENTS_MATCHED:
+          _statMap.merge(StatKey.NUM_CONSUMING_SEGMENTS_MATCHED, Integer.parseInt(entry.getValue()));
+          break;
+        case SORTED:
+          break;
+        default:
+          throw new IllegalArgumentException("Unhandled leaf execution stat: " + key);
+      }
+    }
+  }
+
   private ExplainedNode asNode(ExplainInfo info) {
     int size = info.getInputs().size();
     List<PlanNode> inputs = new ArrayList<>(size);
@@ -374,114 +389,155 @@ public class LeafOperator extends MultiStageOperator {
     return new ExplainedNode(_context.getStageId(), _dataSchema, null, inputs, info.getTitle(), info.getAttributes());
   }
 
-  private Future<Void> startExecution() {
-    ResultsBlockConsumer resultsBlockConsumer = new ResultsBlockConsumer();
-    ServerQueryLogger queryLogger = ServerQueryLogger.getInstance();
+  @Nullable
+  private ErrorMseBlock getErrorBlock() {
+    return _errorBlock.get();
+  }
+
+  private void setErrorBlock(ErrorMseBlock errorBlock) {
+    // Keep the first encountered error block
+    _errorBlock.compareAndSet(null, errorBlock);
+  }
+
+  private Future<?> startExecution() {
     ThreadExecutionContext parentContext = Tracing.getThreadAccountant().getThreadExecutionContext();
     return _executorService.submit(() -> {
       try {
-        if (_requests.size() == 1) {
-          ServerQueryRequest request = _requests.get(0);
-          Tracing.ThreadAccountantOps.setupWorker(1, parentContext);
-
-          InstanceResponseBlock instanceResponseBlock =
-              _queryExecutor.execute(request, _executorService, resultsBlockConsumer);
-          if (queryLogger != null) {
-            queryLogger.logQuery(request, instanceResponseBlock, "MultistageEngine");
+        execute(parentContext);
+      } catch (Exception e) {
+        setErrorBlock(
+            ErrorMseBlock.fromError(QueryErrorCode.INTERNAL, "Caught exception while executing leaf stage: " + e));
+      } finally {
+        // Always add the last results block to mark the end of the execution and notify the main thread waiting for the
+        // results block.
+        try {
+          addResultsBlock(LAST_RESULTS_BLOCK);
+        } catch (Exception e) {
+          if (!(e instanceof EarlyTerminationException)) {
+            LOGGER.warn("Failed to add the last results block", e);
           }
-          // TODO: Revisit if we should treat all exceptions as query failure. Currently MERGE_RESPONSE_ERROR and
-          //       SERVER_SEGMENT_MISSING_ERROR are counted as query failure.
-          Map<Integer, String> exceptions = instanceResponseBlock.getExceptions();
-          if (!exceptions.isEmpty()) {
-            _exceptions = exceptions;
-          } else {
-            // NOTE: Instance response block might contain data (not metadata only) when all the segments are pruned.
-            //       Add the results block if it contains data.
-            BaseResultsBlock resultsBlock = instanceResponseBlock.getResultsBlock();
-            if (resultsBlock != null && resultsBlock.getNumRows() > 0) {
-              addResultsBlock(resultsBlock);
+        }
+      }
+    });
+  }
+
+  @VisibleForTesting
+  void execute(ThreadExecutionContext parentContext) {
+    ResultsBlockConsumer resultsBlockConsumer = new ResultsBlockConsumer();
+    ServerQueryLogger queryLogger = ServerQueryLogger.getInstance();
+    if (_requests.size() == 1) {
+      ServerQueryRequest request = _requests.get(0);
+      Tracing.ThreadAccountantOps.setupWorker(1, parentContext);
+
+      InstanceResponseBlock instanceResponseBlock =
+          _queryExecutor.execute(request, _executorService, resultsBlockConsumer);
+      if (queryLogger != null) {
+        queryLogger.logQuery(request, instanceResponseBlock, "MultistageEngine");
+      }
+      // Collect the execution stats
+      mergeExecutionStats(instanceResponseBlock.getResponseMetadata());
+      // TODO: Revisit if we should treat all exceptions as query failure. Currently MERGE_RESPONSE_ERROR and
+      //       SERVER_SEGMENT_MISSING_ERROR are counted as query failure.
+      Map<Integer, String> exceptions = instanceResponseBlock.getExceptions();
+      if (!exceptions.isEmpty()) {
+        setErrorBlock(ErrorMseBlock.fromMap(QueryErrorCode.fromKeyMap(exceptions)));
+      } else {
+        // NOTE: Instance response block might contain data (not metadata only) when all the segments are pruned.
+        //       Add the results block if it contains data.
+        BaseResultsBlock resultsBlock = instanceResponseBlock.getResultsBlock();
+        if (resultsBlock != null && resultsBlock.getNumRows() > 0) {
+          try {
+            addResultsBlock(resultsBlock);
+          } catch (InterruptedException e) {
+            setErrorBlock(CANCELLED_BLOCK);
+          } catch (TimeoutException e) {
+            setErrorBlock(TIMEOUT_BLOCK);
+          } catch (Exception e) {
+            if (!(e instanceof EarlyTerminationException)) {
+              LOGGER.warn("Failed to add results block", e);
+            }
+          }
+        }
+      }
+    } else {
+      // Hit 2 physical tables, one REALTIME and one OFFLINE
+      assert _requests.size() == 2;
+      Future<?>[] futures = new Future[2];
+      // TODO: this latch mechanism is not the most elegant. We should change it to use a CompletionService.
+      //  In order to interrupt the execution in case of error, we could different mechanisms like throwing in the
+      //  future, or using a shared volatile variable.
+      CountDownLatch latch = new CountDownLatch(2);
+      for (int i = 0; i < 2; i++) {
+        ServerQueryRequest request = _requests.get(i);
+        int taskId = i;
+        futures[i] = _executorService.submit(() -> {
+          Tracing.ThreadAccountantOps.setupWorker(taskId, parentContext);
+
+          try {
+            InstanceResponseBlock instanceResponseBlock =
+                _queryExecutor.execute(request, _executorService, resultsBlockConsumer);
+            if (queryLogger != null) {
+              queryLogger.logQuery(request, instanceResponseBlock, "MultistageEngine");
             }
             // Collect the execution stats
             mergeExecutionStats(instanceResponseBlock.getResponseMetadata());
-          }
-        } else {
-          assert _requests.size() == 2;
-          Future<Map<String, String>>[] futures = new Future[2];
-          // TODO: this latch mechanism is not the most elegant. We should change it to use a CompletionService.
-          //  In order to interrupt the execution in case of error, we could different mechanisms like throwing in the
-          //  future, or using a shared volatile variable.
-          CountDownLatch latch = new CountDownLatch(2);
-          for (int i = 0; i < 2; i++) {
-            ServerQueryRequest request = _requests.get(i);
-            int taskId = i;
-            futures[i] = _executorService.submit(() -> {
-              Tracing.ThreadAccountantOps.setupWorker(taskId, parentContext);
-
-              try {
-                InstanceResponseBlock instanceResponseBlock =
-                    _queryExecutor.execute(request, _executorService, resultsBlockConsumer);
-                if (queryLogger != null) {
-                  queryLogger.logQuery(request, instanceResponseBlock, "MultistageEngine");
-                }
-                Map<Integer, String> exceptions = instanceResponseBlock.getExceptions();
-                if (!exceptions.isEmpty()) {
-                  // Drain the latch when receiving exception block and not wait for the other thread to finish
-                  _exceptions = exceptions;
-                  latch.countDown();
-                  return Collections.emptyMap();
-                } else {
-                  // NOTE: Instance response block might contain data (not metadata only) when all the segments are
-                  //       pruned. Add the results block if it contains data.
-                  BaseResultsBlock resultsBlock = instanceResponseBlock.getResultsBlock();
-                  if (resultsBlock != null && resultsBlock.getNumRows() > 0) {
-                    addResultsBlock(resultsBlock);
+            Map<Integer, String> exceptions = instanceResponseBlock.getExceptions();
+            if (!exceptions.isEmpty()) {
+              setErrorBlock(ErrorMseBlock.fromMap(QueryErrorCode.fromKeyMap(exceptions)));
+              // Drain the latch when receiving exception block and not wait for the other thread to finish
+              latch.countDown();
+            } else {
+              // NOTE: Instance response block might contain data (not metadata only) when all the segments are
+              //       pruned. Add the results block if it contains data.
+              BaseResultsBlock resultsBlock = instanceResponseBlock.getResultsBlock();
+              if (resultsBlock != null && resultsBlock.getNumRows() > 0) {
+                try {
+                  addResultsBlock(resultsBlock);
+                } catch (InterruptedException e) {
+                  setErrorBlock(CANCELLED_BLOCK);
+                } catch (TimeoutException e) {
+                  setErrorBlock(TIMEOUT_BLOCK);
+                } catch (Exception e) {
+                  if (!(e instanceof EarlyTerminationException)) {
+                    LOGGER.warn("Failed to add results block", e);
                   }
-                  // Collect the execution stats
-                  return instanceResponseBlock.getResponseMetadata();
                 }
-              } finally {
-                latch.countDown();
               }
-            });
-          }
-          try {
-            if (!latch.await(_context.getPassiveDeadlineMs() - System.currentTimeMillis(), TimeUnit.MILLISECONDS)) {
-              throw new TimeoutException("Timed out waiting for leaf stage to finish");
-            }
-            // Propagate the exception thrown by the leaf stage
-            for (Future<Map<String, String>> future : futures) {
-              Map<String, String> stats =
-                  future.get(_context.getPassiveDeadlineMs() - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
-              mergeExecutionStats(stats);
             }
-          } catch (TimeoutException e) {
-            throw new TimeoutException("Timed out waiting for leaf stage to 
finish");
           } finally {
-            for (Future<?> future : futures) {
-              future.cancel(true);
-            }
+            latch.countDown();
           }
+        });
+      }
+      try {
+        if (!latch.await(_context.getPassiveDeadlineMs() - System.currentTimeMillis(), TimeUnit.MILLISECONDS)) {
+          setErrorBlock(TIMEOUT_BLOCK);
         }
-        return null;
+      } catch (InterruptedException e) {
+        setErrorBlock(CANCELLED_BLOCK);
       } finally {
-        // Always add the last results block to mark the end of the execution
-        addResultsBlock(LAST_RESULTS_BLOCK);
+        for (Future<?> future : futures) {
+          future.cancel(true);
+        }
       }
-    });
+    }
   }
 
   @VisibleForTesting
   void addResultsBlock(BaseResultsBlock resultsBlock)
       throws InterruptedException, TimeoutException {
+    if (_terminated) {
+      throw new EarlyTerminationException("Query has been terminated");
+    }
     if (!_blockingQueue.offer(resultsBlock, _context.getPassiveDeadlineMs() - System.currentTimeMillis(),
         TimeUnit.MILLISECONDS)) {
       throw new TimeoutException("Timed out waiting to add results block");
     }
   }
 
-  @Override
-  public void close() {
-    cancelSseTasks();
+  private void terminateAndClearResultsBlocks() {
+    _terminated = true;
+    _blockingQueue.clear();
   }
 
   /**
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
index 7fa9f157b96..b551737326b 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
@@ -33,7 +33,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
-import java.util.function.BiFunction;
 import javax.annotation.Nullable;
 import org.apache.commons.io.output.UnsynchronizedByteArrayOutputStream;
 import org.apache.pinot.common.config.TlsConfig;
@@ -357,10 +356,6 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
     try (QueryThreadContext.CloseableContext qTlClosable
         = QueryThreadContext.openFromRequestMetadata(_instanceId, reqMetadata);
         QueryThreadContext.CloseableContext mseTlCloseable = MseWorkerThreadContext.open()) {
-      // Explain the stage for each worker
-      BiFunction<StagePlan, WorkerMetadata, StagePlan> explainFun = (stagePlan, workerMetadata) ->
-          _queryRunner.explainQuery(workerMetadata, stagePlan, reqMetadata);
-
       List<Worker.StagePlan> protoStagePlans = request.getStagePlanList();
 
       for (Worker.StagePlan protoStagePlan : protoStagePlans) {
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafOperatorTest.java
index 7529004b7e4..5a01bf6d425 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafOperatorTest.java
@@ -21,9 +21,13 @@ package org.apache.pinot.query.runtime.operator;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.operator.blocks.InstanceResponseBlock;
@@ -36,10 +40,12 @@ import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.blocks.MseBlock;
+import org.apache.pinot.query.runtime.blocks.SuccessMseBlock;
+import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.exception.QueryErrorCode;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
-import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.AfterMethod;
 import org.testng.annotations.BeforeMethod;
@@ -50,6 +56,10 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotSame;
+import static org.testng.Assert.assertSame;
+import static org.testng.Assert.assertTrue;
 
 
 // TODO: add tests for Agg / GroupBy / Distinct result blocks
@@ -121,9 +131,9 @@ public class LeafOperatorTest {
 
     // Then:
     List<Object[]> rows = ((MseBlock.Data) resultBlock).asRowHeap().getRows();
-    Assert.assertEquals(rows.get(0), new Object[]{"foo", 1});
-    Assert.assertEquals(rows.get(1), new Object[]{"", 2});
-    Assert.assertTrue(operator.nextBlock().isEos(), "Expected EOS after 
reading 2 blocks");
+    assertEquals(rows.get(0), new Object[]{"foo", 1});
+    assertEquals(rows.get(1), new Object[]{"", 2});
+    assertTrue(operator.nextBlock().isEos(), "Expected EOS after reading 2 
blocks");
 
     operator.close();
   }
@@ -153,9 +163,9 @@ public class LeafOperatorTest {
 
     // Then:
     List<Object[]> rows = ((MseBlock.Data) resultBlock).asRowHeap().getRows();
-    Assert.assertEquals(rows.get(0), new Object[]{1, 1660000000000L, 1});
-    Assert.assertEquals(rows.get(1), new Object[]{0, 1600000000000L, 0});
-    Assert.assertTrue(operator.nextBlock().isEos(), "Expected EOS after 
reading 2 blocks");
+    assertEquals(rows.get(0), new Object[]{1, 1660000000000L, 1});
+    assertEquals(rows.get(1), new Object[]{0, 1600000000000L, 0});
+    assertTrue(operator.nextBlock().isEos(), "Expected EOS after reading 2 
blocks");
 
     operator.close();
   }
@@ -184,11 +194,11 @@ public class LeafOperatorTest {
     // Then:
     List<Object[]> rows1 = ((MseBlock.Data) resultBlock1).asRowHeap().getRows();
     List<Object[]> rows2 = ((MseBlock.Data) resultBlock2).asRowHeap().getRows();
-    Assert.assertEquals(rows1.get(0), new Object[]{"foo", 1});
-    Assert.assertEquals(rows1.get(1), new Object[]{"", 2});
-    Assert.assertEquals(rows2.get(0), new Object[]{"bar", 3});
-    Assert.assertEquals(rows2.get(1), new Object[]{"foo", 4});
-    Assert.assertTrue(resultBlock3.isEos(), "Expected EOS after reading 2 
blocks");
+    assertEquals(rows1.get(0), new Object[]{"foo", 1});
+    assertEquals(rows1.get(1), new Object[]{"", 2});
+    assertEquals(rows2.get(0), new Object[]{"bar", 3});
+    assertEquals(rows2.get(1), new Object[]{"foo", 4});
+    assertTrue(resultBlock3.isEos(), "Expected EOS after reading 2 blocks");
 
     operator.close();
   }
@@ -210,11 +220,11 @@ public class LeafOperatorTest {
     _operatorRef.set(operator);
 
     // Then: the 5th block should be EOS
-    Assert.assertTrue(operator.nextBlock().isData());
-    Assert.assertTrue(operator.nextBlock().isData());
-    Assert.assertTrue(operator.nextBlock().isData());
-    Assert.assertTrue(operator.nextBlock().isData());
-    Assert.assertTrue(operator.nextBlock().isEos(), "Expected EOS after 
reading 5 blocks");
+    assertTrue(operator.nextBlock().isData());
+    assertTrue(operator.nextBlock().isData());
+    assertTrue(operator.nextBlock().isData());
+    assertTrue(operator.nextBlock().isData());
+    assertTrue(operator.nextBlock().isEos(), "Expected EOS after reading 5 
blocks");
 
     operator.close();
   }
@@ -240,7 +250,7 @@ public class LeafOperatorTest {
 
     // Then: error block can be returned as first or second block depending on the sequence of the execution
     if (!resultBlock.isError()) {
-      Assert.assertTrue(operator.nextBlock().isError());
+      assertTrue(operator.nextBlock().isError());
     }
 
     operator.close();
@@ -267,7 +277,7 @@ public class LeafOperatorTest {
     MseBlock resultBlock = operator.nextBlock();
 
     // Then:
-    Assert.assertTrue(resultBlock.isEos());
+    assertTrue(resultBlock.isEos());
 
     operator.close();
   }
@@ -340,4 +350,59 @@ public class LeafOperatorTest {
     // Then:
     verify(operator).cancelSseTasks();
   }
+
+  @Test
+  public void executionThreadShouldNotBlockOnLastResultsBlockWhenCancelled()
+      throws Exception {
+    // Given: operator with queue size 1
+    DataSchema schema = new DataSchema(new String[]{"strCol", "intCol"},
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT});
+    QueryContext queryContext = QueryContextConverterUtils.getQueryContext("SELECT strCol, intCol FROM tbl");
+
+    Map<String, String> opChainMetadata = new HashMap<>();
+    opChainMetadata.put("maxStreamingPendingBlocks", "1");
+    opChainMetadata.put("timeoutMs", "100000");
+    OpChainExecutionContext context = OperatorTestUtil.getContext(opChainMetadata);
+    CountDownLatch resultsBlockAdded = new CountDownLatch(1);
+
+    LeafOperator operator =
+        new LeafOperator(context, mockQueryRequests(1), schema, mock(QueryExecutor.class), _executorService) {
+          @Override
+          void execute(ThreadExecutionContext parentContext) {
+            try {
+              // Fill queue and block on second add
+              SelectionResultsBlock dataBlock =
+                  new SelectionResultsBlock(schema, Arrays.asList(new Object[]{"foo", 1}, new Object[]{"", 2}),
+                      queryContext);
+              // First data block is consumed by the first call of getNextBlock()
+              addResultsBlock(dataBlock);
+              // Second data block will remain in the blocking queue and block the third data block
+              addResultsBlock(dataBlock);
+              resultsBlockAdded.countDown();
+              addResultsBlock(dataBlock);
+            } catch (Exception e) {
+              assertTrue(e instanceof InterruptedException);
+            }
+          }
+        };
+
+    // Main thread read the next block to start the execution
+    assertTrue(operator.getNextBlock() instanceof MseBlock.Data);
+
+    // Wait for blocking queue to fill up
+    resultsBlockAdded.await();
+
+    // Early terminate the operator, which will also interrupt the child
+    operator.earlyTerminate();
+
+    // Child should still block on adding LAST_RESULTS_BLOCK
+    Thread.sleep(100);
+    assertNotSame(operator._blockingQueue.peek(), LeafOperator.LAST_RESULTS_BLOCK);
+
+    // Main thread read the next block, which should return SUCCESS_MSE_BLOCK and also unblock the child
+    assertSame(operator.getNextBlock(), SuccessMseBlock.INSTANCE);
+    assertSame(operator._blockingQueue.poll(10, TimeUnit.SECONDS), LeafOperator.LAST_RESULTS_BLOCK);
+
+    operator.close();
+  }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
index 142b0dc2cf9..b8e2be7d17d 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.query.runtime.operator;
 
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -28,6 +26,7 @@ import org.apache.pinot.common.datatable.StatMap;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
+import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.StagePlan;
 import org.apache.pinot.query.routing.WorkerMetadata;
@@ -48,9 +47,10 @@ import static org.mockito.Mockito.when;
 
 public class OperatorTestUtil {
   // simple key-value collision schema/data test set: "Aa" and "BB" have same hash code in java.
-  private static final List<List<Object[]>> SIMPLE_KV_DATA_ROWS =
-      ImmutableList.of(ImmutableList.of(new Object[]{1, "Aa"}, new Object[]{2, "BB"}, new Object[]{3, "BB"}),
-          ImmutableList.of(new Object[]{1, "AA"}, new Object[]{2, "Aa"}));
+  private static final List<List<Object[]>> SIMPLE_KV_DATA_ROWS = List.of(
+      List.of(new Object[]{1, "Aa"}, new Object[]{2, "BB"}, new Object[]{3, "BB"}),
+      List.of(new Object[]{1, "AA"}, new Object[]{2, "Aa"})
+  );
   private static final MockDataBlockOperatorFactory MOCK_OPERATOR_FACTORY;
 
   public static final DataSchema SIMPLE_KV_DATA_SCHEMA = new DataSchema(new String[]{"foo", "bar"},
@@ -61,14 +61,14 @@ public class OperatorTestUtil {
 
   public static MultiStageQueryStats getDummyStats(int stageId) {
     MultiStageQueryStats stats = MultiStageQueryStats.emptyStats(stageId);
-    stats.getCurrentStats()
-        .addLastOperator(MultiStageOperator.Type.LEAF, new StatMap<>(LeafOperator.StatKey.class));
+    stats.getCurrentStats().addLastOperator(MultiStageOperator.Type.LEAF, new StatMap<>(LeafOperator.StatKey.class));
     return stats;
   }
 
   static {
     MOCK_OPERATOR_FACTORY = new MockDataBlockOperatorFactory().registerOperator(OP_1, SIMPLE_KV_DATA_SCHEMA)
-        .registerOperator(OP_2, SIMPLE_KV_DATA_SCHEMA).addRows(OP_1, SIMPLE_KV_DATA_ROWS.get(0))
+        .registerOperator(OP_2, SIMPLE_KV_DATA_SCHEMA)
+        .addRows(OP_1, SIMPLE_KV_DATA_ROWS.get(0))
         .addRows(OP_2, SIMPLE_KV_DATA_ROWS.get(1));
   }
 
@@ -109,12 +109,12 @@ public class OperatorTestUtil {
 
   public static OpChainExecutionContext getOpChainContext(MailboxService mailboxService, long deadlineMs,
       StageMetadata stageMetadata) {
-    return new OpChainExecutionContext(mailboxService, 0, deadlineMs, ImmutableMap.of(), stageMetadata,
+    return new OpChainExecutionContext(mailboxService, 0, deadlineMs, Map.of(), stageMetadata,
         stageMetadata.getWorkerMetadataList().get(0), null, null, true);
   }
 
   public static OpChainExecutionContext getTracingContext() {
-    return getTracingContext(ImmutableMap.of(CommonConstants.Broker.Request.TRACE, "true"));
+    return getTracingContext(Map.of(CommonConstants.Broker.Request.TRACE, "true"));
   }
 
   public static OpChainExecutionContext getContext(Map<String, String> opChainMetadata) {
@@ -122,22 +122,23 @@ public class OperatorTestUtil {
   }
 
   public static OpChainExecutionContext getNoTracingContext() {
-    return getTracingContext(ImmutableMap.of());
+    return getTracingContext(Map.of());
   }
 
   private static OpChainExecutionContext getTracingContext(Map<String, String> opChainMetadata) {
     MailboxService mailboxService = mock(MailboxService.class);
     when(mailboxService.getHostname()).thenReturn("localhost");
     when(mailboxService.getPort()).thenReturn(1234);
-    WorkerMetadata workerMetadata = new WorkerMetadata(0, ImmutableMap.of(), ImmutableMap.of());
-    StageMetadata stageMetadata = new StageMetadata(0, ImmutableList.of(workerMetadata), ImmutableMap.of());
-    OpChainExecutionContext opChainExecutionContext = new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE,
-        opChainMetadata, stageMetadata, workerMetadata, null, null, true);
+    WorkerMetadata workerMetadata = new WorkerMetadata(0, Map.of(), Map.of());
+    StageMetadata stageMetadata =
+        new StageMetadata(0, List.of(workerMetadata), Map.of(DispatchablePlanFragment.TABLE_NAME_KEY, "testTable"));
+    OpChainExecutionContext opChainExecutionContext =
+        new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE, opChainMetadata, stageMetadata,
+            workerMetadata, null, null, true);
 
     StagePlan stagePlan = new StagePlan(null, stageMetadata);
 
-    opChainExecutionContext.setLeafStageContext(
-        new ServerPlanRequestContext(stagePlan, null, null, null));
+    opChainExecutionContext.setLeafStageContext(new ServerPlanRequestContext(stagePlan, null, null, null));
     return opChainExecutionContext;
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
