This is an automated email from the ASF dual-hosted git repository.

gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new d008709064 Add OOM Protection Support for Multi-Stage Queries (#13598)
d008709064 is described below

commit d00870906412506fbd03f58fb5f4b78022b2b43d
Author: Rajat Venkatesh <1638298+vra...@users.noreply.github.com>
AuthorDate: Tue Sep 3 14:40:47 2024 +0530

    Add OOM Protection Support for Multi-Stage Queries (#13598)
    
    Track CPU and memory usage in multi-stage queries when query resource usage tracking is enabled.
---
 .../MultiStageBrokerRequestHandler.java            |   6 +
 .../CPUMemThreadLevelAccountingObjects.java        |  20 +-
 .../PerQueryCPUMemAccountantFactory.java           |   4 +-
 .../apache/pinot/query/runtime/QueryRunner.java    |   8 +-
 .../runtime/executor/OpChainSchedulerService.java  |   8 +
 .../query/runtime/operator/AggregateOperator.java  |   2 +
 .../query/runtime/operator/HashJoinOperator.java   |   2 +
 .../runtime/operator/MailboxReceiveOperator.java   |   2 +
 .../runtime/operator/MailboxSendOperator.java      |   1 +
 .../query/runtime/operator/MultiStageOperator.java |   9 +
 .../pinot/query/runtime/operator/OpChain.java      |   7 +
 .../pinot/query/runtime/operator/OpChainId.java    |   4 +
 .../pinot/query/runtime/operator/SetOperator.java  |   2 +
 .../pinot/query/runtime/operator/SortOperator.java |   1 +
 .../runtime/operator/WindowAggregateOperator.java  |   1 +
 .../runtime/plan/OpChainExecutionContext.java      |  10 +-
 .../plan/pipeline/PipelineBreakerExecutor.java     |  15 +-
 .../query/service/dispatch/QueryDispatcher.java    |   5 +-
 .../pinot/query/service/server/QueryServer.java    |  10 +-
 .../apache/pinot/query/QueryServerEnclosure.java   |   3 +-
 .../executor/OpChainSchedulerServiceTest.java      |   2 +-
 .../runtime/operator/MailboxSendOperatorTest.java  |   2 +-
 .../runtime/operator/MultiStageAccountingTest.java | 239 +++++++++++++++++++++
 .../query/runtime/operator/OperatorTestUtil.java   |   4 +-
 .../query/service/server/QueryServerTest.java      |   4 +-
 .../spi/accounting/ThreadExecutionContext.java     |  13 ++
 .../spi/accounting/ThreadResourceTracker.java      |   2 +
 .../accounting/ThreadResourceUsageAccountant.java  |   3 +-
 .../java/org/apache/pinot/spi/trace/Tracing.java   |  35 ++-
 29 files changed, 399 insertions(+), 25 deletions(-)

diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index a1e82dbd53..fad63d18e2 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -64,10 +64,12 @@ import org.apache.pinot.query.routing.WorkerManager;
 import org.apache.pinot.query.runtime.MultiStageStatsTreeBuilder;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.auth.TableAuthorizationResult;
 import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.exception.DatabaseConflictException;
 import org.apache.pinot.spi.trace.RequestContext;
+import org.apache.pinot.spi.trace.Tracing;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
 import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
@@ -210,6 +212,8 @@ public class MultiStageBrokerRequestHandler extends 
BaseBrokerRequestHandler {
       return new 
BrokerResponseNative(QueryException.getException(QueryException.QUOTA_EXCEEDED_ERROR,
 errorMessage));
     }
 
+    Tracing.ThreadAccountantOps.setupRunner(String.valueOf(requestId), 
ThreadExecutionContext.TaskType.MSE);
+
     long executionStartTimeNs = System.nanoTime();
     QueryDispatcher.QueryResult queryResults;
     try {
@@ -228,6 +232,8 @@ public class MultiStageBrokerRequestHandler extends 
BaseBrokerRequestHandler {
       requestContext.setErrorCode(QueryException.QUERY_EXECUTION_ERROR_CODE);
       return new BrokerResponseNative(
           QueryException.getException(QueryException.QUERY_EXECUTION_ERROR, 
consolidatedMessage));
+    } finally {
+      Tracing.getThreadAccountant().clear();
     }
     long executionEndTimeNs = System.nanoTime();
     updatePhaseTimingForTables(tableNames, BrokerQueryPhase.QUERY_EXECUTION, 
executionEndTimeNs - executionStartTimeNs);
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java
index 6a375b95bb..378fcd991f 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java
@@ -103,8 +103,15 @@ public class CPUMemThreadLevelAccountingObjects {
       return taskEntry == null ? -1 : taskEntry.getTaskId();
     }
 
-    public void setThreadTaskStatus(@Nonnull String queryId, int taskId, 
@Nonnull Thread anchorThread) {
-      _currentThreadTaskStatus.set(new TaskEntry(queryId, taskId, 
anchorThread));
+    @Override
+    public ThreadExecutionContext.TaskType getTaskType() {
+      TaskEntry taskEntry = _currentThreadTaskStatus.get();
+      return taskEntry == null ? ThreadExecutionContext.TaskType.UNKNOWN : 
taskEntry.getTaskType();
+    }
+
+    public void setThreadTaskStatus(@Nonnull String queryId, int taskId, 
ThreadExecutionContext.TaskType taskType,
+        @Nonnull Thread anchorThread) {
+      _currentThreadTaskStatus.set(new TaskEntry(queryId, taskId, taskType, 
anchorThread));
     }
   }
 
@@ -117,15 +124,17 @@ public class CPUMemThreadLevelAccountingObjects {
     private final String _queryId;
     private final int _taskId;
     private final Thread _anchorThread;
+    private final TaskType _taskType;
 
     public boolean isAnchorThread() {
       return _taskId == CommonConstants.Accounting.ANCHOR_TASK_ID;
     }
 
-    public TaskEntry(String queryId, int taskId, Thread anchorThread) {
+    public TaskEntry(String queryId, int taskId, TaskType taskType, Thread 
anchorThread) {
       _queryId = queryId;
       _taskId = taskId;
       _anchorThread = anchorThread;
+      _taskType = taskType;
     }
 
     public String getQueryId() {
@@ -140,6 +149,11 @@ public class CPUMemThreadLevelAccountingObjects {
       return _anchorThread;
     }
 
+    @Override
+    public TaskType getTaskType() {
+      return _taskType;
+    }
+
     @Override
     public String toString() {
       return "TaskEntry{" + "_queryId='" + _queryId + '\'' + ", _taskId=" + 
_taskId + ", _rootThread=" + _anchorThread
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
index 34e496d3bf..fef0c3beeb 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
@@ -293,10 +293,10 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
         // is anchor thread
         assert queryId != null;
         _threadLocalEntry.get().setThreadTaskStatus(queryId, 
CommonConstants.Accounting.ANCHOR_TASK_ID,
-            Thread.currentThread());
+            ThreadExecutionContext.TaskType.UNKNOWN, Thread.currentThread());
       } else {
         // not anchor thread
-        
_threadLocalEntry.get().setThreadTaskStatus(parentContext.getQueryId(), taskId,
+        
_threadLocalEntry.get().setThreadTaskStatus(parentContext.getQueryId(), taskId, 
parentContext.getTaskType(),
             parentContext.getAnchorThread());
       }
     }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
index 0032545c0f..77d9b53d43 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
@@ -50,6 +50,7 @@ import 
org.apache.pinot.query.runtime.plan.PhysicalPlanVisitor;
 import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerExecutor;
 import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerResult;
 import org.apache.pinot.query.runtime.plan.server.ServerPlanRequestUtils;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.utils.CommonConstants;
 import 
org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
@@ -152,7 +153,8 @@ public class QueryRunner {
    * <p>This execution entry point should be asynchronously called by the 
request handler and caller should not wait
    * for results/exceptions.</p>
    */
-  public void processQuery(WorkerMetadata workerMetadata, StagePlan stagePlan, 
Map<String, String> requestMetadata) {
+  public void processQuery(WorkerMetadata workerMetadata, StagePlan stagePlan, 
Map<String, String> requestMetadata,
+      @Nullable ThreadExecutionContext parentContext) {
     long requestId = 
Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
     long timeoutMs = 
Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
     long deadlineMs = System.currentTimeMillis() + timeoutMs;
@@ -163,7 +165,7 @@ public class QueryRunner {
     // run pre-stage execution for all pipeline breakers
     PipelineBreakerResult pipelineBreakerResult =
         PipelineBreakerExecutor.executePipelineBreakers(_opChainScheduler, 
_mailboxService, workerMetadata, stagePlan,
-            opChainMetadata, requestId, deadlineMs);
+            opChainMetadata, requestId, deadlineMs, parentContext);
 
     // Send error block to all the receivers if pipeline breaker fails
     if (pipelineBreakerResult != null && pipelineBreakerResult.getErrorBlock() 
!= null) {
@@ -196,7 +198,7 @@ public class QueryRunner {
     // run OpChain
     OpChainExecutionContext executionContext =
         new OpChainExecutionContext(_mailboxService, requestId, deadlineMs, 
opChainMetadata, stageMetadata,
-            workerMetadata, pipelineBreakerResult);
+            workerMetadata, pipelineBreakerResult, parentContext);
     OpChain opChain;
     if (workerMetadata.isLeafStageWorker()) {
       opChain = ServerPlanRequestUtils.compileLeafStage(executionContext, 
stagePlan, _helixManager, _serverMetrics,
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
index ef0f9a7d12..bfd74c68d5 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
@@ -27,6 +27,9 @@ import org.apache.pinot.core.util.trace.TraceRunnable;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.operator.OpChain;
 import org.apache.pinot.query.runtime.operator.OpChainId;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider;
+import org.apache.pinot.spi.trace.Tracing;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -50,6 +53,10 @@ public class OpChainSchedulerService {
         TransferableBlock returnedErrorBlock = null;
         Throwable thrown = null;
         try {
+          ThreadResourceUsageProvider threadResourceUsageProvider = new 
ThreadResourceUsageProvider();
+          
Tracing.ThreadAccountantOps.setupWorker(operatorChain.getId().getStageId(),
+              ThreadExecutionContext.TaskType.MSE, threadResourceUsageProvider,
+              operatorChain.getParentContext());
           LOGGER.trace("({}): Executing", operatorChain);
           TransferableBlock result = operatorChain.getRoot().nextBlock();
           while (!result.isEndOfStreamBlock()) {
@@ -76,6 +83,7 @@ public class OpChainSchedulerService {
           } else if (isFinished) {
             operatorChain.close();
           }
+          Tracing.ThreadAccountantOps.clear();
         }
       }
     });
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index ce6d30d451..38ff7d2d5c 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -173,6 +173,7 @@ public class AggregateOperator extends MultiStageOperator {
     TransferableBlock block = _input.nextBlock();
     while (block.isDataBlock()) {
       _groupByExecutor.processBlock(block);
+      sampleAndCheckInterruption();
       block = _input.nextBlock();
     }
     return block;
@@ -187,6 +188,7 @@ public class AggregateOperator extends MultiStageOperator {
     TransferableBlock block = _input.nextBlock();
     while (block.isDataBlock()) {
       _aggregationExecutor.processBlock(block);
+      sampleAndCheckInterruption();
       block = _input.nextBlock();
     }
     return block;
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index c18deb2ea4..0dee6c06bd 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -253,6 +253,7 @@ public class HashJoinOperator extends MultiStageOperator {
         hashCollection.add(row);
       }
       _currentRowsInHashTable += container.size();
+      sampleAndCheckInterruption();
       rightBlock = _rightInput.nextBlock();
     }
     if (rightBlock.isErrorBlock()) {
@@ -297,6 +298,7 @@ public class HashJoinOperator extends MultiStageOperator {
       }
       assert leftBlock.isDataBlock();
       List<Object[]> rows = buildJoinedRows(leftBlock);
+      sampleAndCheckInterruption();
       if (!rows.isEmpty()) {
         return new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW);
       }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
index c1c7647a1e..82e60e34f6 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
@@ -60,6 +60,8 @@ public class MailboxReceiveOperator extends 
BaseMailboxReceiveOperator {
     }
     if (block.isSuccessfulEndOfStreamBlock()) {
       updateEosBlock(block, _statMap);
+    } else if (block.isDataBlock()) {
+      sampleAndCheckInterruption();
     }
     return block;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
index 7ebbaa2e91..864f200fe6 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
@@ -143,6 +143,7 @@ public class MailboxSendOperator extends MultiStageOperator 
{
           earlyTerminate();
         }
       }
+      sampleAndCheckInterruption();
       return block;
     } catch (QueryCancelledException e) {
       LOGGER.debug("Query was cancelled! for opChain: {}", _context.getId());
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java
index 0321bedc1b..2690a5a7f7 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java
@@ -66,6 +66,15 @@ public abstract class MultiStageOperator
 
   public abstract void registerExecution(long time, int numRows);
 
+  // Samples resource usage of the operator. The operator should call this 
function for every block of data or
+  // assuming the block holds 10000 rows or more.
+  protected void sampleAndCheckInterruption() {
+    Tracing.ThreadAccountantOps.sample();
+    if (Tracing.ThreadAccountantOps.isInterrupted()) {
+      earlyTerminate();
+    }
+  }
+
   /**
    * Returns the next block from the operator. It should return non-empty data 
blocks followed by an end-of-stream (EOS)
    * block when all the data is processed, or an error block if an error 
occurred. After it returns EOS or error block,
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java
index 5d989f169f..86eca19b49 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java
@@ -22,6 +22,7 @@ import java.util.function.Consumer;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -36,6 +37,7 @@ public class OpChain implements AutoCloseable {
   private final OpChainId _id;
   private final MultiStageOperator _root;
   private final Consumer<OpChainId> _finishCallback;
+  private final ThreadExecutionContext _parentContext;
 
   public OpChain(OpChainExecutionContext context, MultiStageOperator root) {
     this(context, root, (id) -> {
@@ -46,6 +48,7 @@ public class OpChain implements AutoCloseable {
     _id = context.getId();
     _root = root;
     _finishCallback = finishCallback;
+    _parentContext = context.getParentContext();
   }
 
   public OpChainId getId() {
@@ -56,6 +59,10 @@ public class OpChain implements AutoCloseable {
     return _root;
   }
 
+  public ThreadExecutionContext getParentContext() {
+    return _parentContext;
+  }
+
   @Override
   public String toString() {
     return "OpChain{" + _id + "}";
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java
index dae9c2f755..e78dacd8a4 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java
@@ -40,6 +40,10 @@ public class OpChainId {
     return _virtualServerId;
   }
 
+  public int getStageId() {
+    return _stageId;
+  }
+
   @Override
   public String toString() {
     return String.format("%s_%s_%s", _requestId, _virtualServerId, _stageId);
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java
index ea5cf046d7..e92edb1123 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java
@@ -119,6 +119,7 @@ public abstract class SetOperator extends 
MultiStageOperator {
           _rightRowSet.add(new Record(row));
         }
       }
+      sampleAndCheckInterruption();
       block = _rightChildOperator.nextBlock();
     }
     if (block.isErrorBlock()) {
@@ -153,6 +154,7 @@ public abstract class SetOperator extends 
MultiStageOperator {
           rows.add(row);
         }
       }
+      sampleAndCheckInterruption();
       if (!rows.isEmpty()) {
         return new TransferableBlock(rows, _dataSchema, DataBlock.Type.ROW);
       }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
index 9f553424eb..a5a5e15e5f 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
@@ -181,6 +181,7 @@ public class SortOperator extends MultiStageOperator {
         for (Object[] row : container) {
           SelectionOperatorUtils.addToPriorityQueue(row, _priorityQueue, 
_numRowsToKeep);
         }
+        sampleAndCheckInterruption();
       }
       block = _input.nextBlock();
     }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
index 27001778d4..7d97356eb2 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
@@ -269,6 +269,7 @@ public class WindowAggregateOperator extends 
MultiStageOperator {
         _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row);
       }
       _numRows += containerSize;
+      sampleAndCheckInterruption();
       block = _input.nextBlock();
     }
     // Early termination if the block is an error block
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
index 3290478de8..f50c169c45 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
@@ -28,6 +28,7 @@ import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.operator.OpChainId;
 import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerResult;
 import org.apache.pinot.query.runtime.plan.server.ServerPlanRequestContext;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.utils.CommonConstants;
 
 
@@ -48,12 +49,13 @@ public class OpChainExecutionContext {
   @Nullable
   private final PipelineBreakerResult _pipelineBreakerResult;
   private final boolean _traceEnabled;
+  private final ThreadExecutionContext _parentContext;
 
   private ServerPlanRequestContext _leafStageContext;
 
   public OpChainExecutionContext(MailboxService mailboxService, long 
requestId, long deadlineMs,
       Map<String, String> opChainMetadata, StageMetadata stageMetadata, 
WorkerMetadata workerMetadata,
-      @Nullable PipelineBreakerResult pipelineBreakerResult) {
+      @Nullable PipelineBreakerResult pipelineBreakerResult, @Nullable 
ThreadExecutionContext parentContext) {
     _mailboxService = mailboxService;
     _requestId = requestId;
     _deadlineMs = deadlineMs;
@@ -65,6 +67,7 @@ public class OpChainExecutionContext {
     _id = new OpChainId(requestId, workerMetadata.getWorkerId(), 
stageMetadata.getStageId());
     _pipelineBreakerResult = pipelineBreakerResult;
     _traceEnabled = 
Boolean.parseBoolean(opChainMetadata.get(CommonConstants.Broker.Request.TRACE));
+    _parentContext = parentContext;
   }
 
   public MailboxService getMailboxService() {
@@ -123,4 +126,9 @@ public class OpChainExecutionContext {
   public void setLeafStageContext(ServerPlanRequestContext leafStageContext) {
     _leafStageContext = leafStageContext;
   }
+
+  @Nullable
+  public ThreadExecutionContext getParentContext() {
+    return _parentContext;
+  }
 }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
index 2e0cc7003d..0c3d1ac7e4 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
@@ -37,6 +37,7 @@ import 
org.apache.pinot.query.runtime.executor.OpChainSchedulerService;
 import org.apache.pinot.query.runtime.operator.OpChain;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.query.runtime.plan.PhysicalPlanVisitor;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -50,6 +51,14 @@ public class PipelineBreakerExecutor {
 
   private static final Logger LOGGER = 
LoggerFactory.getLogger(PipelineBreakerExecutor.class);
 
+  @Nullable
+  public static PipelineBreakerResult 
executePipelineBreakers(OpChainSchedulerService scheduler,
+      MailboxService mailboxService, WorkerMetadata workerMetadata, StagePlan 
stagePlan,
+      Map<String, String> opChainMetadata, long requestId, long deadlineMs) {
+    return executePipelineBreakers(scheduler, mailboxService, workerMetadata, 
stagePlan, opChainMetadata, requestId,
+        deadlineMs, null);
+  }
+
   /**
    * Execute a pipeline breaker and collect the results (synchronously). 
Currently, pipeline breaker executor can only
    *    execute mailbox receive pipeline breaker.
@@ -61,6 +70,7 @@ public class PipelineBreakerExecutor {
    * @param opChainMetadata request metadata, including query options
    * @param requestId request ID
    * @param deadlineMs execution deadline
+   * @param parentContext Parent thread metadata
    * @return pipeline breaker result;
    *   - If exception occurs, exception block will be wrapped in {@link 
TransferableBlock} and assigned to each PB node.
    *   - Normal stats will be attached to each PB node and downstream 
execution should return with stats attached.
@@ -68,7 +78,8 @@ public class PipelineBreakerExecutor {
   @Nullable
   public static PipelineBreakerResult 
executePipelineBreakers(OpChainSchedulerService scheduler,
       MailboxService mailboxService, WorkerMetadata workerMetadata, StagePlan 
stagePlan,
-      Map<String, String> opChainMetadata, long requestId, long deadlineMs) {
+      Map<String, String> opChainMetadata, long requestId, long deadlineMs,
+      @Nullable ThreadExecutionContext parentContext) {
     PipelineBreakerContext pipelineBreakerContext = new 
PipelineBreakerContext();
     PipelineBreakerVisitor.visitPlanRoot(stagePlan.getRootNode(), 
pipelineBreakerContext);
     if (!pipelineBreakerContext.getPipelineBreakerMap().isEmpty()) {
@@ -78,7 +89,7 @@ public class PipelineBreakerExecutor {
         // see also: MailboxIdUtils TODOs, de-couple mailbox id from query 
information
         OpChainExecutionContext opChainExecutionContext =
             new OpChainExecutionContext(mailboxService, requestId, deadlineMs, 
opChainMetadata,
-                stagePlan.getStageMetadata(), workerMetadata, null);
+                stagePlan.getStageMetadata(), workerMetadata, null, 
parentContext);
         return execute(scheduler, pipelineBreakerContext, 
opChainExecutionContext);
       } catch (Exception e) {
         LOGGER.error("Caught exception executing pipeline breaker for request: 
{}, stage: {}", requestId,
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 0d2aac509a..fc8874911e 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -60,7 +60,9 @@ import 
org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.MailboxReceiveOperator;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.trace.RequestContext;
+import org.apache.pinot.spi.trace.Tracing;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -260,9 +262,10 @@ public class QueryDispatcher {
     Preconditions.checkState(workerMetadataList.size() == 1, "Expecting single 
worker for reduce stage, got: %s",
         workerMetadataList.size());
     StageMetadata stageMetadata = new StageMetadata(0, workerMetadataList, 
dispatchableStagePlan.getCustomProperties());
+    ThreadExecutionContext parentContext = 
Tracing.getThreadAccountant().getThreadExecutionContext();
     OpChainExecutionContext opChainExecutionContext =
         new OpChainExecutionContext(mailboxService, requestId, deadlineMs, 
queryOptions, stageMetadata,
-            workerMetadataList.get(0), null);
+            workerMetadataList.get(0), null, parentContext);
 
     PairList<Integer, String> resultFields = 
dispatchableSubPlan.getQueryResultFields();
     DataSchema sourceDataSchema = receiveNode.getDataSchema();
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
index c8caed9100..763192e16e 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
@@ -37,6 +37,8 @@ import org.apache.pinot.query.routing.StagePlan;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.QueryRunner;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.trace.Tracing;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -112,9 +114,12 @@ public class QueryServer extends 
PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
     long timeoutMs = 
Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
     long deadlineMs = System.currentTimeMillis() + timeoutMs;
 
+    Tracing.ThreadAccountantOps.setupRunner(Long.toString(requestId), 
ThreadExecutionContext.TaskType.MSE);
+
     List<Worker.StagePlan> protoStagePlans = request.getStagePlanList();
     int numStages = protoStagePlans.size();
     CompletableFuture<?>[] stageSubmissionStubs = new 
CompletableFuture[numStages];
+    ThreadExecutionContext parentContext = 
Tracing.getThreadAccountant().getThreadExecutionContext();
     for (int i = 0; i < numStages; i++) {
       Worker.StagePlan protoStagePlan = protoStagePlans.get(i);
       stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> {
@@ -133,8 +138,8 @@ public class QueryServer extends 
PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
         for (int j = 0; j < numWorkers; j++) {
           WorkerMetadata workerMetadata = workerMetadataList.get(j);
           workerSubmissionStubs[j] =
-              CompletableFuture.runAsync(() -> 
_queryRunner.processQuery(workerMetadata, stagePlan, requestMetadata),
-                  _querySubmissionExecutorService);
+              CompletableFuture.runAsync(() -> 
_queryRunner.processQuery(workerMetadata, stagePlan, requestMetadata,
+                      parentContext), _querySubmissionExecutorService);
         }
         try {
           CompletableFuture.allOf(workerSubmissionStubs)
@@ -167,6 +172,7 @@ public class QueryServer extends 
PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
           future.cancel(true);
         }
       }
+      Tracing.getThreadAccountant().clear();
     }
     responseObserver.onNext(
         
Worker.QueryResponse.newBuilder().putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_OK,
 "")
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java
index 3de9cc6fdb..1e03b31411 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java
@@ -112,7 +112,8 @@ public class QueryServerEnclosure {
 
+  /**
+   * Runs {@code stagePlan} for one worker asynchronously on the query runner's executor service.
+   * No parent ThreadExecutionContext is propagated here: the last argument passed to
+   * {@code _queryRunner.processQuery} is null, whereas QueryServer passes the submitting thread's context.
+   */
   public CompletableFuture<Void> processQuery(WorkerMetadata workerMetadata, 
StagePlan stagePlan,
       Map<String, String> requestMetadataMap) {
-    return CompletableFuture.runAsync(() -> 
_queryRunner.processQuery(workerMetadata, stagePlan, requestMetadataMap),
+    return CompletableFuture.runAsync(
+        () -> _queryRunner.processQuery(workerMetadata, stagePlan, 
requestMetadataMap, null),
         _queryRunner.getExecutorService());
   }
 }
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
index aff6a68853..61aad6a9a7 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
@@ -79,7 +79,7 @@ public class OpChainSchedulerServiceTest {
     WorkerMetadata workerMetadata = new WorkerMetadata(0, ImmutableMap.of(), 
ImmutableMap.of());
     OpChainExecutionContext context =
         new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE, 
ImmutableMap.of(),
-            new StageMetadata(0, ImmutableList.of(workerMetadata), 
ImmutableMap.of()), workerMetadata, null);
+            new StageMetadata(0, ImmutableList.of(workerMetadata), 
ImmutableMap.of()), workerMetadata, null, null);
     return new OpChain(context, operator);
   }
 
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
index cc92873c94..a54dc182a9 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
@@ -197,7 +197,7 @@ public class MailboxSendOperatorTest {
     StageMetadata stageMetadata = new StageMetadata(SENDER_STAGE_ID, 
List.of(workerMetadata), Map.of());
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 123L, Long.MAX_VALUE, 
Map.of(), stageMetadata, workerMetadata,
-            null);
+            null, null);
     return new MailboxSendOperator(context, _input, statMap -> _exchange);
   }
 
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
new file mode 100644
index 0000000000..2821de5e73
--- /dev/null
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java
@@ -0,0 +1,239 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import com.google.common.collect.ImmutableList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.calcite.rel.RelFieldCollation;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.metrics.ServerMetrics;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.planner.plannode.AggregateNode;
+import org.apache.pinot.query.planner.plannode.JoinNode;
+import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.SortNode;
+import org.apache.pinot.query.planner.plannode.WindowNode;
+import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils;
+import org.apache.pinot.spi.accounting.QueryResourceTracker;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.accounting.ThreadResourceTracker;
+import org.apache.pinot.spi.accounting.ThreadResourceUsageAccountant;
+import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider;
+import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.trace.Tracing;
+import org.apache.pinot.spi.utils.CommonConstants;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.testng.ITest;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Factory;
+import org.testng.annotations.Test;
+
+import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.DOUBLE;
+import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.INT;
+import static org.mockito.Mockito.when;
+import static org.mockito.MockitoAnnotations.openMocks;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Checks that multi-stage operators report memory allocations to the thread resource usage accountant.
+ * One instance is created per row of {@link #getOperators()} (via {@code @Factory}); each wraps an operator
+ * whose mocked inputs return a single data block followed by an end-of-stream block.
+ */
+public class MultiStageAccountingTest implements ITest {
+  private AutoCloseable _mocks;
+  // Stubbed in setUp(), but not referenced by any operator built in getOperators().
+  @Mock
+  private VirtualServerAddress _serverAddress;
+
+  // Name reported through ITest#getTestName(), and the operator under test for this instance.
+  protected String _testName;
+  protected MultiStageOperator _operator;
+
+  /**
+   * @param testName name reported by {@link #getTestName()}
+   * @param operator operator exercised by {@link #testOperatorAccounting()}
+   */
+  @Factory(dataProvider = "operatorProvider")
+  public MultiStageAccountingTest(String testName, MultiStageOperator 
operator) {
+    _testName = testName;
+    _operator = operator;
+  }
+
+  /**
+   * Enables thread memory measurement, initializes the global thread accountant with
+   * PerQueryCPUMemAccountantFactory (memory sampling on, CPU sampling off, OOM query killing on), and registers
+   * the calling thread as the MSE runner of query "MultiStageAccountingTest" and then as worker task 1 of it.
+   * NOTE(review): this global state is never cleared (there is no @AfterClass) — confirm it cannot leak into
+   * other tests running in the same JVM.
+   */
+  @BeforeClass
+  public static void setUpClass() {
+    ThreadResourceUsageProvider.setThreadMemoryMeasurementEnabled(true);
+    HashMap<String, Object> configs = new HashMap<>();
+    ServerMetrics.register(Mockito.mock(ServerMetrics.class));
+    // NOTE(review): alarming/critical heap-usage ratios of 0 presumably put the watcher in its critical state
+    // for any heap usage — confirm this is intended together with OOM query killing being enabled.
+    
configs.put(CommonConstants.Accounting.CONFIG_OF_ALARMING_LEVEL_HEAP_USAGE_RATIO,
 0.00f);
+    
configs.put(CommonConstants.Accounting.CONFIG_OF_CRITICAL_LEVEL_HEAP_USAGE_RATIO,
 0.00f);
+    configs.put(CommonConstants.Accounting.CONFIG_OF_FACTORY_NAME,
+        "org.apache.pinot.core.accounting.PerQueryCPUMemAccountantFactory");
+    
configs.put(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_MEMORY_SAMPLING, 
true);
+    
configs.put(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_CPU_SAMPLING, 
false);
+    
configs.put(CommonConstants.Accounting.CONFIG_OF_OOM_PROTECTION_KILLING_QUERY, 
true);
+    // init accountant and start watcher task
+    // NOTE(review): "testGroupBy" looks copied from another test; confirm the value is only an identifier.
+    Tracing.ThreadAccountantOps.initializeThreadAccountant(new 
PinotConfiguration(configs), "testGroupBy");
+
+    // Setup Thread Context
+    // Register this thread as the MSE runner (anchor) of the query, then as worker task 1 of that same query,
+    // presumably so that allocations made on this thread are charged to the query — confirm.
+    Tracing.ThreadAccountantOps.setupRunner("MultiStageAccountingTest", 
ThreadExecutionContext.TaskType.MSE);
+    ThreadExecutionContext threadExecutionContext = 
Tracing.getThreadAccountant().getThreadExecutionContext();
+    ThreadResourceUsageProvider threadResourceUsageProvider = new 
ThreadResourceUsageProvider();
+    Tracing.ThreadAccountantOps.setupWorker(1, 
ThreadExecutionContext.TaskType.MSE, threadResourceUsageProvider,
+        threadExecutionContext);
+  }
+
+  /** Opens the {@code @Mock} fields and stubs {@code _serverAddress.toString()}. */
+  @BeforeMethod
+  public void setUp() {
+    _mocks = openMocks(this);
+    when(_serverAddress.toString()).thenReturn(new 
VirtualServerAddress("mock", 80, 0).toString());
+  }
+
+  /** Releases the mocks opened in {@link #setUp()}. */
+  @AfterMethod
+  public void tearDown()
+      throws Exception {
+    _mocks.close();
+  }
+
+  /**
+   * Pulls the first block from the operator, then expects exactly one query tracker and one thread tracker,
+   * each with a positive allocated-bytes count, and finally expects the next block to be a successful
+   * end-of-stream block.
+   */
+  @Test
+  public void testOperatorAccounting() {
+    // Run the operator on its mocked input; the block's contents are not inspected.
+    _operator.nextBlock().getContainer();
+
+    ThreadResourceUsageAccountant threadAccountant = 
Tracing.getThreadAccountant();
+    Collection<? extends QueryResourceTracker> queryResourceTrackers = 
threadAccountant.getQueryResources().values();
+    Collection<? extends ThreadResourceTracker> threadResourceTrackers = 
threadAccountant.getThreadResources();
+
+    // Then:
+    assertEquals(queryResourceTrackers.size(), 1);
+    assertEquals(threadResourceTrackers.size(), 1);
+    assertTrue(queryResourceTrackers.iterator().next().getAllocatedBytes() > 
0);
+    assertTrue(threadResourceTrackers.iterator().next().getAllocatedBytes() > 
0);
+    assertTrue(_operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second 
block is EOS (done processing)");
+  }
+
+  /** One row per operator under test: {display name, operator}. */
+  @DataProvider(name = "operatorProvider")
+  public static Object[][] getOperators() {
+    return new Object[][]{
+        {"AggregateOperator", getAggregateOperator()},
+        {"SortOperator", getSortOperator()},
+        {"HashJoinOperator", getHashJoinOperator()},
+        {"WindowAggregateOperator", getWindowAggregateOperator()},
+        {"SetOperator", getIntersectOperator()}
+    };
+  }
+
+  /** Group-by SUM over column 1 keyed on column 0, fed a single input row {2, 1.0}. */
+  private static MultiStageOperator getAggregateOperator() {
+    MultiStageOperator input = Mockito.mock();
+    // Given:
+    List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new 
RexExpression.InputRef(1)));
+    List<Integer> filterArgs = List.of(-1);
+    List<Integer> groupKeys = List.of(0);
+    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new 
DataSchema.ColumnDataType[]{INT, DOUBLE});
+    when(input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new 
Object[]{2, 1.0}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    DataSchema resultSchema =
+        new DataSchema(new String[]{"group", "sum"}, new 
DataSchema.ColumnDataType[]{INT, DOUBLE});
+    return new AggregateOperator(OperatorTestUtil.getTracingContext(), input,
+        new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, 
List.of(), aggCalls, filterArgs, groupKeys,
+            AggregateNode.AggType.DIRECT));
+  }
+
+  /** INNER hash join on column 0 of two (INT, STRING) inputs. */
+  private static MultiStageOperator getHashJoinOperator() {
+    MultiStageOperator leftInput = Mockito.mock();
+    MultiStageOperator rightInput = Mockito.mock();
+
+    DataSchema leftSchema = new DataSchema(new String[]{"int_col", 
"string_col"}, new DataSchema.ColumnDataType[]{
+        DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
+    });
+    DataSchema rightSchema = new DataSchema(new String[]{"int_col", 
"string_col"}, new DataSchema.ColumnDataType[]{
+        DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
+    });
+    when(leftInput.nextBlock()).thenReturn(
+            OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new 
Object[]{2, "BB"}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    when(rightInput.nextBlock()).thenReturn(
+            OperatorTestUtil.block(rightSchema, new Object[]{2, "Aa"}, new 
Object[]{2, "BB"}, new Object[]{3, "BB"}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    // NOTE(review): "string_co2" looks like a typo for "string_col2"; column names are not asserted by this test.
+    DataSchema resultSchema = new DataSchema(new String[]{"int_col1", 
"string_col1", "int_col2", "string_co2"},
+        new DataSchema.ColumnDataType[]{
+            DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, 
DataSchema.ColumnDataType.INT,
+            DataSchema.ColumnDataType.STRING
+        });
+    return new HashJoinOperator(OperatorTestUtil.getTracingContext(), 
leftInput, leftSchema, rightInput,
+        new JoinNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), 
JoinRelType.INNER, List.of(0), List.of(0),
+            List.of()));
+  }
+
+  /** Sorts the rows {2} and {1} ascending on column 0 with nulls last. */
+  private static MultiStageOperator getSortOperator() {
+    MultiStageOperator input = Mockito.mock();
+    // Given:
+    DataSchema schema = new DataSchema(new String[]{"sort"}, new 
DataSchema.ColumnDataType[]{INT});
+    when(input.nextBlock()).thenReturn(
+            new TransferableBlock(List.of(new Object[]{2}, new Object[]{1}), 
schema, DataBlock.Type.ROW))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    List<RelFieldCollation> collations =
+        List.of(new RelFieldCollation(0, 
RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
+    return new SortOperator(OperatorTestUtil.getTracingContext(), input,
+        new SortNode(-1, schema, PlanNode.NodeHint.EMPTY, List.of(), 
collations, 10, 0));
+  }
+
+  /**
+   * Window SUM over column 1 with keys [0] (presumably the partition key — confirm), using a RANGE frame
+   * bounded by Integer.MIN_VALUE and Integer.MAX_VALUE.
+   */
+  private static MultiStageOperator getWindowAggregateOperator() {
+    MultiStageOperator input = Mockito.mock();
+    // Given:
+    DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new 
DataSchema.ColumnDataType[]{INT, INT});
+    when(input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new 
Object[]{2, 1}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", 
"sum"}, new DataSchema.ColumnDataType[]{
+        INT, INT, DOUBLE
+    });
+    List<Integer> keys = List.of(0);
+    List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new 
RexExpression.InputRef(1)));
+    return new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), 
input, inputSchema,
+        new WindowNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), 
keys, List.of(), aggCalls,
+            WindowNode.WindowFrameType.RANGE, Integer.MIN_VALUE, 
Integer.MAX_VALUE, List.of()));
+  }
+
+  /** INTERSECT of two (INT, STRING) inputs that share the rows {1, "AA"} and {2, "BB"}. */
+  private static MultiStageOperator getIntersectOperator() {
+    MultiStageOperator leftOperator = Mockito.mock();
+    MultiStageOperator rightOperator = Mockito.mock();
+
+    DataSchema schema = new DataSchema(new String[]{"int_col", "string_col"}, 
new DataSchema.ColumnDataType[]{
+        DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
+    });
+    Mockito.when(leftOperator.nextBlock())
+        .thenReturn(OperatorTestUtil.block(schema, new Object[]{1, "AA"}, new 
Object[]{2, "BB"}, new Object[]{3, "CC"}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+    Mockito.when(rightOperator.nextBlock())
+        .thenReturn(OperatorTestUtil.block(schema, new Object[]{1, "AA"}, new 
Object[]{2, "BB"}, new Object[]{4, "DD"}))
+        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
+
+    return new IntersectOperator(OperatorTestUtil.getTracingContext(), 
ImmutableList.of(leftOperator, rightOperator),
+        schema);
+  }
+
+  /**
+   * Builds a SUM function call over {@code arg}.
+   * NOTE(review): the call is typed INT while the aggregate and window result schemas declare the sum column as
+   * DOUBLE; confirm the mismatch is intentional.
+   */
+  private static RexExpression.FunctionCall getSum(RexExpression arg) {
+    return new RexExpression.FunctionCall(DataSchema.ColumnDataType.INT, 
SqlKind.SUM.name(), List.of(arg));
+  }
+
+  // Lets TestNG report each factory-created instance under its operator's display name.
+  @Override
+  public String getTestName() {
+    return _testName;
+  }
+}
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
index da4537cb19..f279e5992b 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
@@ -83,7 +83,7 @@ public class OperatorTestUtil {
+  /**
+   * Builds an OpChainExecutionContext for the first worker of {@code stageMetadata}, with empty op-chain metadata.
+   * Both trailing constructor arguments are passed as null; QueryDispatcher supplies its {@code parentContext}
+   * in the last position.
+   */
   public static OpChainExecutionContext getOpChainContext(MailboxService 
mailboxService, long deadlineMs,
       StageMetadata stageMetadata) {
     return new OpChainExecutionContext(mailboxService, 0, deadlineMs, 
ImmutableMap.of(), stageMetadata,
-        stageMetadata.getWorkerMetadataList().get(0), null);
+        stageMetadata.getWorkerMetadataList().get(0), null, null);
   }
 
   public static OpChainExecutionContext getTracingContext() {
@@ -101,7 +101,7 @@ public class OperatorTestUtil {
     WorkerMetadata workerMetadata = new WorkerMetadata(0, ImmutableMap.of(), 
ImmutableMap.of());
     StageMetadata stageMetadata = new StageMetadata(0, 
ImmutableList.of(workerMetadata), ImmutableMap.of());
     OpChainExecutionContext opChainExecutionContext = new 
OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE,
-        opChainMetadata, stageMetadata, workerMetadata, null);
+        opChainMetadata, stageMetadata, workerMetadata, null, null);
 
     StagePlan stagePlan = new StagePlan(null, stageMetadata);
 
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
index 2a568f45fc..3a0b23408e 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
@@ -105,7 +105,7 @@ public class QueryServerTest extends QueryTestSet {
     Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, 1);
     Map<String, String> requestMetadata = 
QueryPlanSerDeUtils.fromProtoProperties(queryRequest.getMetadata());
     QueryRunner mockRunner = 
_queryRunnerMap.get(Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT)));
-    doThrow(new RuntimeException("foo")).when(mockRunner).processQuery(any(), 
any(), any());
+    doThrow(new RuntimeException("foo")).when(mockRunner).processQuery(any(), 
any(), any(), any());
     // submit the request for testing.
     Worker.QueryResponse resp = submitRequest(queryRequest, requestMetadata);
     // reset the mock runner before assert.
@@ -148,7 +148,7 @@ public class QueryServerTest extends QueryTestSet {
             return planNode.equals(stagePlan.getRootNode()) && 
isStageMetadataEqual(stageMetadata,
                 stagePlan.getStageMetadata());
           }), argThat(requestMetadataMap -> requestId.equals(
-              
requestMetadataMap.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID))));
+              
requestMetadataMap.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID))),
 any());
           return true;
         } catch (Throwable t) {
           return false;
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java
 
b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java
index 68fd9b03e2..7ea59ad378 100644
--- 
a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java
+++ 
b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java
@@ -23,6 +23,17 @@ package org.apache.pinot.spi.accounting;
  */
 public interface ThreadExecutionContext {
 
+   /**
+    * Query engine that a tracked task belongs to.
+    * SSE: Single Stage Engine
+    * MSE: Multi Stage Engine
+    * UNKNOWN: Default; e.g. the ThreadExecutionContext built by Tracing's own accountant always reports it.
+    */
+   public enum TaskType {
+      SSE,
+      MSE,
+      UNKNOWN
+   }
+
    /**
     * get query id of the execution context
     * @return query id in string
@@ -34,4 +45,6 @@ public interface ThreadExecutionContext {
     * @return get the anchor thread of execution context
     */
    Thread getAnchorThread();
+
+   /**
+    * get the task type of the execution context
+    * @return the engine the task belongs to; implementations that do not track it may return
+    * {@link TaskType#UNKNOWN}
+    */
+   TaskType getTaskType();
 }
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java
 
b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java
index ff78a0d33c..418210b376 100644
--- 
a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java
+++ 
b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java
@@ -49,4 +49,6 @@ public interface ThreadResourceTracker {
    * @return an int containing the task id.
    */
   int getTaskId();
+
+  /**
+   * @return the {@link ThreadExecutionContext.TaskType} (e.g. SSE or MSE) of the task tracked by this tracker.
+   */
+  ThreadExecutionContext.TaskType getTaskType();
 }
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java
 
b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java
index 8be0632e6b..bab51cf21c 100644
--- 
a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java
+++ 
b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java
@@ -42,7 +42,8 @@ public interface ThreadResourceUsageAccountant {
    * @param taskId a unique task id
    * @param parentContext the parent execution context, null for root(runner) 
thread
    */
-  void createExecutionContext(String queryId, int taskId, @Nullable 
ThreadExecutionContext parentContext);
+  // taskType: the engine (ThreadExecutionContext.TaskType, e.g. SSE or MSE) that the task belongs to.
+  void createExecutionContext(String queryId, int taskId, 
ThreadExecutionContext.TaskType taskType,
+      @Nullable ThreadExecutionContext parentContext);
 
   /**
    * get the executon context of current thread
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java 
b/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java
index 910ebc35cf..59ad65eef5 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java
@@ -200,7 +200,8 @@ public class Tracing {
     }
 
     @Override
-    public final void createExecutionContext(String queryId, int taskId, 
ThreadExecutionContext parentContext) {
+    public final void createExecutionContext(String queryId, int taskId, 
ThreadExecutionContext.TaskType taskType,
+        ThreadExecutionContext parentContext) {
+      // A root (runner) context anchors on the calling thread; a child context inherits its parent's anchor thread.
       _anchorThread.set(parentContext == null ? Thread.currentThread() : 
parentContext.getAnchorThread());
+      // NOTE(review): taskType is not forwarded to createExecutionContextInner, so it is unavailable there —
+      // confirm this is intended.
       createExecutionContextInner(queryId, taskId, parentContext);
     }
@@ -220,6 +221,11 @@ public class Tracing {
         public Thread getAnchorThread() {
           return _anchorThread.get();
         }
+
+        // This accountant's createExecutionContext does not record taskType, so the context always reports UNKNOWN.
+        @Override
+        public TaskType getTaskType() {
+          return TaskType.UNKNOWN;
+        }
       };
     }
 
@@ -254,14 +260,37 @@ public class Tracing {
     }
 
+    /**
+     * Sets up the calling thread as the runner (anchor task) of a Single Stage Engine query.
+     * @param queryId id of the query the runner executes
+     */
     public static void setupRunner(String queryId) {
+      setupRunner(queryId, ThreadExecutionContext.TaskType.SSE);
+    }
+
+    /**
+     * Sets up the calling thread as the runner of a query: installs a fresh ThreadResourceUsageProvider and creates
+     * a root execution context (ANCHOR_TASK_ID, no parent context) of the given task type.
+     * @param queryId id of the query the runner executes
+     * @param taskType engine (SSE or MSE) the query runs on
+     */
+    public static void setupRunner(String queryId, 
ThreadExecutionContext.TaskType taskType) {
       Tracing.getThreadAccountant().setThreadResourceUsageProvider(new 
ThreadResourceUsageProvider());
-      Tracing.getThreadAccountant().createExecutionContext(queryId, 
CommonConstants.Accounting.ANCHOR_TASK_ID, null);
+      Tracing.getThreadAccountant()
+          .createExecutionContext(queryId, 
CommonConstants.Accounting.ANCHOR_TASK_ID, taskType, null);
     }
 
+    /**
+     * Sets up the calling thread as a query worker of the Single Stage Engine ({@code TaskType.SSE}).
+     * @param taskId Query task ID of the thread. In SSE, ID is an incrementing counter. In MSE, id is the stage id.
+     * @param threadResourceUsageProvider Object that measures resource usage.
+     * @param threadExecutionContext Execution context of the parent thread (typically the runner), used as the
+     *     parent context of the worker's execution context.
+     */
     public static void setupWorker(int taskId, ThreadResourceUsageProvider 
threadResourceUsageProvider,
         ThreadExecutionContext threadExecutionContext) {
+      setupWorker(taskId, ThreadExecutionContext.TaskType.SSE, 
threadResourceUsageProvider, threadExecutionContext);
+    }
+
+    /**
+     * Sets up the calling thread as a query worker of the given engine.
+     * @param taskId Query task ID of the thread. In SSE, ID is an incrementing counter. In MSE, id is the stage id.
+     * @param taskType engine (SSE or MSE) the worker's task belongs to.
+     * @param threadResourceUsageProvider Object that measures resource usage.
+     * @param threadExecutionContext Execution context of the parent thread (typically the runner), used as the
+     *     parent context of the worker's execution context.
+     */
+    public static void setupWorker(int taskId, ThreadExecutionContext.TaskType 
taskType,
+        ThreadResourceUsageProvider threadResourceUsageProvider,
+        ThreadExecutionContext threadExecutionContext) {
       
Tracing.getThreadAccountant().setThreadResourceUsageProvider(threadResourceUsageProvider);
+      // queryId is passed as null for workers; presumably the query id is taken from the parent context — confirm.
-      Tracing.getThreadAccountant().createExecutionContext(null, taskId, 
threadExecutionContext);
+      Tracing.getThreadAccountant().createExecutionContext(null, taskId, 
taskType, threadExecutionContext);
     }
 
     public static void sample() {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to