This is an automated email from the ASF dual-hosted git repository. gortiz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new d008709064 Add OOM Protection Support for Multi-Stage Queries (#13598) d008709064 is described below commit d00870906412506fbd03f58fb5f4b78022b2b43d Author: Rajat Venkatesh <1638298+vra...@users.noreply.github.com> AuthorDate: Tue Sep 3 14:40:47 2024 +0530 Add OOM Protection Support for Multi-Stage Queries (#13598) track cpu and memory usage in multi-stage queries if query resource usage tracking is enabled --- .../MultiStageBrokerRequestHandler.java | 6 + .../CPUMemThreadLevelAccountingObjects.java | 20 +- .../PerQueryCPUMemAccountantFactory.java | 4 +- .../apache/pinot/query/runtime/QueryRunner.java | 8 +- .../runtime/executor/OpChainSchedulerService.java | 8 + .../query/runtime/operator/AggregateOperator.java | 2 + .../query/runtime/operator/HashJoinOperator.java | 2 + .../runtime/operator/MailboxReceiveOperator.java | 2 + .../runtime/operator/MailboxSendOperator.java | 1 + .../query/runtime/operator/MultiStageOperator.java | 9 + .../pinot/query/runtime/operator/OpChain.java | 7 + .../pinot/query/runtime/operator/OpChainId.java | 4 + .../pinot/query/runtime/operator/SetOperator.java | 2 + .../pinot/query/runtime/operator/SortOperator.java | 1 + .../runtime/operator/WindowAggregateOperator.java | 1 + .../runtime/plan/OpChainExecutionContext.java | 10 +- .../plan/pipeline/PipelineBreakerExecutor.java | 15 +- .../query/service/dispatch/QueryDispatcher.java | 5 +- .../pinot/query/service/server/QueryServer.java | 10 +- .../apache/pinot/query/QueryServerEnclosure.java | 3 +- .../executor/OpChainSchedulerServiceTest.java | 2 +- .../runtime/operator/MailboxSendOperatorTest.java | 2 +- .../runtime/operator/MultiStageAccountingTest.java | 239 +++++++++++++++++++++ .../query/runtime/operator/OperatorTestUtil.java | 4 +- .../query/service/server/QueryServerTest.java | 4 +- .../spi/accounting/ThreadExecutionContext.java | 13 ++ .../spi/accounting/ThreadResourceTracker.java | 2 + 
.../accounting/ThreadResourceUsageAccountant.java | 3 +- .../java/org/apache/pinot/spi/trace/Tracing.java | 35 ++- 29 files changed, 399 insertions(+), 25 deletions(-) diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java index a1e82dbd53..fad63d18e2 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java @@ -64,10 +64,12 @@ import org.apache.pinot.query.routing.WorkerManager; import org.apache.pinot.query.runtime.MultiStageStatsTreeBuilder; import org.apache.pinot.query.runtime.plan.MultiStageQueryStats; import org.apache.pinot.query.service.dispatch.QueryDispatcher; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; import org.apache.pinot.spi.auth.TableAuthorizationResult; import org.apache.pinot.spi.env.PinotConfiguration; import org.apache.pinot.spi.exception.DatabaseConflictException; import org.apache.pinot.spi.trace.RequestContext; +import org.apache.pinot.spi.trace.Tracing; import org.apache.pinot.spi.utils.CommonConstants; import org.apache.pinot.spi.utils.builder.TableNameBuilder; import org.apache.pinot.sql.parsers.SqlNodeAndOptions; @@ -210,6 +212,8 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { return new BrokerResponseNative(QueryException.getException(QueryException.QUOTA_EXCEEDED_ERROR, errorMessage)); } + Tracing.ThreadAccountantOps.setupRunner(String.valueOf(requestId), ThreadExecutionContext.TaskType.MSE); + long executionStartTimeNs = System.nanoTime(); QueryDispatcher.QueryResult queryResults; try { @@ -228,6 +232,8 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { requestContext.setErrorCode(QueryException.QUERY_EXECUTION_ERROR_CODE); 
return new BrokerResponseNative( QueryException.getException(QueryException.QUERY_EXECUTION_ERROR, consolidatedMessage)); + } finally { + Tracing.getThreadAccountant().clear(); } long executionEndTimeNs = System.nanoTime(); updatePhaseTimingForTables(tableNames, BrokerQueryPhase.QUERY_EXECUTION, executionEndTimeNs - executionStartTimeNs); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java b/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java index 6a375b95bb..378fcd991f 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/accounting/CPUMemThreadLevelAccountingObjects.java @@ -103,8 +103,15 @@ public class CPUMemThreadLevelAccountingObjects { return taskEntry == null ? -1 : taskEntry.getTaskId(); } - public void setThreadTaskStatus(@Nonnull String queryId, int taskId, @Nonnull Thread anchorThread) { - _currentThreadTaskStatus.set(new TaskEntry(queryId, taskId, anchorThread)); + @Override + public ThreadExecutionContext.TaskType getTaskType() { + TaskEntry taskEntry = _currentThreadTaskStatus.get(); + return taskEntry == null ? 
ThreadExecutionContext.TaskType.UNKNOWN : taskEntry.getTaskType(); + } + + public void setThreadTaskStatus(@Nonnull String queryId, int taskId, ThreadExecutionContext.TaskType taskType, + @Nonnull Thread anchorThread) { + _currentThreadTaskStatus.set(new TaskEntry(queryId, taskId, taskType, anchorThread)); } } @@ -117,15 +124,17 @@ public class CPUMemThreadLevelAccountingObjects { private final String _queryId; private final int _taskId; private final Thread _anchorThread; + private final TaskType _taskType; public boolean isAnchorThread() { return _taskId == CommonConstants.Accounting.ANCHOR_TASK_ID; } - public TaskEntry(String queryId, int taskId, Thread anchorThread) { + public TaskEntry(String queryId, int taskId, TaskType taskType, Thread anchorThread) { _queryId = queryId; _taskId = taskId; _anchorThread = anchorThread; + _taskType = taskType; } public String getQueryId() { @@ -140,6 +149,11 @@ public class CPUMemThreadLevelAccountingObjects { return _anchorThread; } + @Override + public TaskType getTaskType() { + return _taskType; + } + @Override public String toString() { return "TaskEntry{" + "_queryId='" + _queryId + '\'' + ", _taskId=" + _taskId + ", _rootThread=" + _anchorThread diff --git a/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java index 34e496d3bf..fef0c3beeb 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java @@ -293,10 +293,10 @@ public class PerQueryCPUMemAccountantFactory implements ThreadAccountantFactory // is anchor thread assert queryId != null; _threadLocalEntry.get().setThreadTaskStatus(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID, - Thread.currentThread()); + ThreadExecutionContext.TaskType.UNKNOWN, Thread.currentThread()); } else { 
// not anchor thread - _threadLocalEntry.get().setThreadTaskStatus(parentContext.getQueryId(), taskId, + _threadLocalEntry.get().setThreadTaskStatus(parentContext.getQueryId(), taskId, parentContext.getTaskType(), parentContext.getAnchorThread()); } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java index 0032545c0f..77d9b53d43 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java @@ -50,6 +50,7 @@ import org.apache.pinot.query.runtime.plan.PhysicalPlanVisitor; import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerExecutor; import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerResult; import org.apache.pinot.query.runtime.plan.server.ServerPlanRequestUtils; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; import org.apache.pinot.spi.env.PinotConfiguration; import org.apache.pinot.spi.utils.CommonConstants; import org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey; @@ -152,7 +153,8 @@ public class QueryRunner { * <p>This execution entry point should be asynchronously called by the request handler and caller should not wait * for results/exceptions.</p> */ - public void processQuery(WorkerMetadata workerMetadata, StagePlan stagePlan, Map<String, String> requestMetadata) { + public void processQuery(WorkerMetadata workerMetadata, StagePlan stagePlan, Map<String, String> requestMetadata, + @Nullable ThreadExecutionContext parentContext) { long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID)); long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS)); long deadlineMs = System.currentTimeMillis() + timeoutMs; @@ -163,7 +165,7 @@ public class 
QueryRunner { // run pre-stage execution for all pipeline breakers PipelineBreakerResult pipelineBreakerResult = PipelineBreakerExecutor.executePipelineBreakers(_opChainScheduler, _mailboxService, workerMetadata, stagePlan, - opChainMetadata, requestId, deadlineMs); + opChainMetadata, requestId, deadlineMs, parentContext); // Send error block to all the receivers if pipeline breaker fails if (pipelineBreakerResult != null && pipelineBreakerResult.getErrorBlock() != null) { @@ -196,7 +198,7 @@ public class QueryRunner { // run OpChain OpChainExecutionContext executionContext = new OpChainExecutionContext(_mailboxService, requestId, deadlineMs, opChainMetadata, stageMetadata, - workerMetadata, pipelineBreakerResult); + workerMetadata, pipelineBreakerResult, parentContext); OpChain opChain; if (workerMetadata.isLeafStageWorker()) { opChain = ServerPlanRequestUtils.compileLeafStage(executionContext, stagePlan, _helixManager, _serverMetrics, diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java index ef0f9a7d12..bfd74c68d5 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java @@ -27,6 +27,9 @@ import org.apache.pinot.core.util.trace.TraceRunnable; import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.operator.OpChain; import org.apache.pinot.query.runtime.operator.OpChainId; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; +import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider; +import org.apache.pinot.spi.trace.Tracing; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +53,10 @@ public class OpChainSchedulerService { TransferableBlock 
returnedErrorBlock = null; Throwable thrown = null; try { + ThreadResourceUsageProvider threadResourceUsageProvider = new ThreadResourceUsageProvider(); + Tracing.ThreadAccountantOps.setupWorker(operatorChain.getId().getStageId(), + ThreadExecutionContext.TaskType.MSE, threadResourceUsageProvider, + operatorChain.getParentContext()); LOGGER.trace("({}): Executing", operatorChain); TransferableBlock result = operatorChain.getRoot().nextBlock(); while (!result.isEndOfStreamBlock()) { @@ -76,6 +83,7 @@ public class OpChainSchedulerService { } else if (isFinished) { operatorChain.close(); } + Tracing.ThreadAccountantOps.clear(); } } }); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java index ce6d30d451..38ff7d2d5c 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java @@ -173,6 +173,7 @@ public class AggregateOperator extends MultiStageOperator { TransferableBlock block = _input.nextBlock(); while (block.isDataBlock()) { _groupByExecutor.processBlock(block); + sampleAndCheckInterruption(); block = _input.nextBlock(); } return block; @@ -187,6 +188,7 @@ public class AggregateOperator extends MultiStageOperator { TransferableBlock block = _input.nextBlock(); while (block.isDataBlock()) { _aggregationExecutor.processBlock(block); + sampleAndCheckInterruption(); block = _input.nextBlock(); } return block; diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java index c18deb2ea4..0dee6c06bd 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java +++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java @@ -253,6 +253,7 @@ public class HashJoinOperator extends MultiStageOperator { hashCollection.add(row); } _currentRowsInHashTable += container.size(); + sampleAndCheckInterruption(); rightBlock = _rightInput.nextBlock(); } if (rightBlock.isErrorBlock()) { @@ -297,6 +298,7 @@ public class HashJoinOperator extends MultiStageOperator { } assert leftBlock.isDataBlock(); List<Object[]> rows = buildJoinedRows(leftBlock); + sampleAndCheckInterruption(); if (!rows.isEmpty()) { return new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java index c1c7647a1e..82e60e34f6 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java @@ -60,6 +60,8 @@ public class MailboxReceiveOperator extends BaseMailboxReceiveOperator { } if (block.isSuccessfulEndOfStreamBlock()) { updateEosBlock(block, _statMap); + } else if (block.isDataBlock()) { + sampleAndCheckInterruption(); } return block; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java index 7ebbaa2e91..864f200fe6 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java @@ -143,6 +143,7 @@ public class MailboxSendOperator extends MultiStageOperator { earlyTerminate(); } } + sampleAndCheckInterruption(); return block; } catch 
(QueryCancelledException e) { LOGGER.debug("Query was cancelled! for opChain: {}", _context.getId()); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java index 0321bedc1b..2690a5a7f7 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultiStageOperator.java @@ -66,6 +66,15 @@ public abstract class MultiStageOperator public abstract void registerExecution(long time, int numRows); + // Samples resource usage of the operator. The operator should call this function for every block of data or + // assuming the block holds 10000 rows or more. + protected void sampleAndCheckInterruption() { + Tracing.ThreadAccountantOps.sample(); + if (Tracing.ThreadAccountantOps.isInterrupted()) { + earlyTerminate(); + } + } + /** * Returns the next block from the operator. It should return non-empty data blocks followed by an end-of-stream (EOS) * block when all the data is processed, or an error block if an error occurred. 
After it returns EOS or error block, diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java index 5d989f169f..86eca19b49 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChain.java @@ -22,6 +22,7 @@ import java.util.function.Consumer; import org.apache.pinot.core.common.Operator; import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,6 +37,7 @@ public class OpChain implements AutoCloseable { private final OpChainId _id; private final MultiStageOperator _root; private final Consumer<OpChainId> _finishCallback; + private final ThreadExecutionContext _parentContext; public OpChain(OpChainExecutionContext context, MultiStageOperator root) { this(context, root, (id) -> { @@ -46,6 +48,7 @@ public class OpChain implements AutoCloseable { _id = context.getId(); _root = root; _finishCallback = finishCallback; + _parentContext = context.getParentContext(); } public OpChainId getId() { @@ -56,6 +59,10 @@ public class OpChain implements AutoCloseable { return _root; } + public ThreadExecutionContext getParentContext() { + return _parentContext; + } + @Override public String toString() { return "OpChain{" + _id + "}"; diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java index dae9c2f755..e78dacd8a4 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java +++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/OpChainId.java @@ -40,6 +40,10 @@ public class OpChainId { return _virtualServerId; } + public int getStageId() { + return _stageId; + } + @Override public String toString() { return String.format("%s_%s_%s", _requestId, _virtualServerId, _stageId); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java index ea5cf046d7..e92edb1123 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SetOperator.java @@ -119,6 +119,7 @@ public abstract class SetOperator extends MultiStageOperator { _rightRowSet.add(new Record(row)); } } + sampleAndCheckInterruption(); block = _rightChildOperator.nextBlock(); } if (block.isErrorBlock()) { @@ -153,6 +154,7 @@ public abstract class SetOperator extends MultiStageOperator { rows.add(row); } } + sampleAndCheckInterruption(); if (!rows.isEmpty()) { return new TransferableBlock(rows, _dataSchema, DataBlock.Type.ROW); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java index 9f553424eb..a5a5e15e5f 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java @@ -181,6 +181,7 @@ public class SortOperator extends MultiStageOperator { for (Object[] row : container) { SelectionOperatorUtils.addToPriorityQueue(row, _priorityQueue, _numRowsToKeep); } + sampleAndCheckInterruption(); } block = _input.nextBlock(); } diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java index 27001778d4..7d97356eb2 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java @@ -269,6 +269,7 @@ public class WindowAggregateOperator extends MultiStageOperator { _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row); } _numRows += containerSize; + sampleAndCheckInterruption(); block = _input.nextBlock(); } // Early termination if the block is an error block diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java index 3290478de8..f50c169c45 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java @@ -28,6 +28,7 @@ import org.apache.pinot.query.routing.WorkerMetadata; import org.apache.pinot.query.runtime.operator.OpChainId; import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerResult; import org.apache.pinot.query.runtime.plan.server.ServerPlanRequestContext; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; import org.apache.pinot.spi.utils.CommonConstants; @@ -48,12 +49,13 @@ public class OpChainExecutionContext { @Nullable private final PipelineBreakerResult _pipelineBreakerResult; private final boolean _traceEnabled; + private final ThreadExecutionContext _parentContext; private ServerPlanRequestContext _leafStageContext; public OpChainExecutionContext(MailboxService mailboxService, long requestId, long deadlineMs, Map<String, 
String> opChainMetadata, StageMetadata stageMetadata, WorkerMetadata workerMetadata, - @Nullable PipelineBreakerResult pipelineBreakerResult) { + @Nullable PipelineBreakerResult pipelineBreakerResult, @Nullable ThreadExecutionContext parentContext) { _mailboxService = mailboxService; _requestId = requestId; _deadlineMs = deadlineMs; @@ -65,6 +67,7 @@ public class OpChainExecutionContext { _id = new OpChainId(requestId, workerMetadata.getWorkerId(), stageMetadata.getStageId()); _pipelineBreakerResult = pipelineBreakerResult; _traceEnabled = Boolean.parseBoolean(opChainMetadata.get(CommonConstants.Broker.Request.TRACE)); + _parentContext = parentContext; } public MailboxService getMailboxService() { @@ -123,4 +126,9 @@ public class OpChainExecutionContext { public void setLeafStageContext(ServerPlanRequestContext leafStageContext) { _leafStageContext = leafStageContext; } + + @Nullable + public ThreadExecutionContext getParentContext() { + return _parentContext; + } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java index 2e0cc7003d..0c3d1ac7e4 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java @@ -37,6 +37,7 @@ import org.apache.pinot.query.runtime.executor.OpChainSchedulerService; import org.apache.pinot.query.runtime.operator.OpChain; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; import org.apache.pinot.query.runtime.plan.PhysicalPlanVisitor; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +51,14 @@ public class PipelineBreakerExecutor { private static final Logger LOGGER = 
LoggerFactory.getLogger(PipelineBreakerExecutor.class); + @Nullable + public static PipelineBreakerResult executePipelineBreakers(OpChainSchedulerService scheduler, + MailboxService mailboxService, WorkerMetadata workerMetadata, StagePlan stagePlan, + Map<String, String> opChainMetadata, long requestId, long deadlineMs) { + return executePipelineBreakers(scheduler, mailboxService, workerMetadata, stagePlan, opChainMetadata, requestId, + deadlineMs, null); + } + /** * Execute a pipeline breaker and collect the results (synchronously). Currently, pipeline breaker executor can only * execute mailbox receive pipeline breaker. @@ -61,6 +70,7 @@ public class PipelineBreakerExecutor { * @param opChainMetadata request metadata, including query options * @param requestId request ID * @param deadlineMs execution deadline + * @param parentContext Parent thread metadata * @return pipeline breaker result; * - If exception occurs, exception block will be wrapped in {@link TransferableBlock} and assigned to each PB node. * - Normal stats will be attached to each PB node and downstream execution should return with stats attached. 
@@ -68,7 +78,8 @@ public class PipelineBreakerExecutor { @Nullable public static PipelineBreakerResult executePipelineBreakers(OpChainSchedulerService scheduler, MailboxService mailboxService, WorkerMetadata workerMetadata, StagePlan stagePlan, - Map<String, String> opChainMetadata, long requestId, long deadlineMs) { + Map<String, String> opChainMetadata, long requestId, long deadlineMs, + @Nullable ThreadExecutionContext parentContext) { PipelineBreakerContext pipelineBreakerContext = new PipelineBreakerContext(); PipelineBreakerVisitor.visitPlanRoot(stagePlan.getRootNode(), pipelineBreakerContext); if (!pipelineBreakerContext.getPipelineBreakerMap().isEmpty()) { @@ -78,7 +89,7 @@ public class PipelineBreakerExecutor { // see also: MailboxIdUtils TODOs, de-couple mailbox id from query information OpChainExecutionContext opChainExecutionContext = new OpChainExecutionContext(mailboxService, requestId, deadlineMs, opChainMetadata, - stagePlan.getStageMetadata(), workerMetadata, null); + stagePlan.getStageMetadata(), workerMetadata, null, parentContext); return execute(scheduler, pipelineBreakerContext, opChainExecutionContext); } catch (Exception e) { LOGGER.error("Caught exception executing pipeline breaker for request: {}, stage: {}", requestId, diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java index 0d2aac509a..fc8874911e 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java @@ -60,7 +60,9 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils; import org.apache.pinot.query.runtime.operator.MailboxReceiveOperator; import org.apache.pinot.query.runtime.plan.MultiStageQueryStats; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; 
+import org.apache.pinot.spi.accounting.ThreadExecutionContext; import org.apache.pinot.spi.trace.RequestContext; +import org.apache.pinot.spi.trace.Tracing; import org.apache.pinot.spi.utils.CommonConstants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -260,9 +262,10 @@ public class QueryDispatcher { Preconditions.checkState(workerMetadataList.size() == 1, "Expecting single worker for reduce stage, got: %s", workerMetadataList.size()); StageMetadata stageMetadata = new StageMetadata(0, workerMetadataList, dispatchableStagePlan.getCustomProperties()); + ThreadExecutionContext parentContext = Tracing.getThreadAccountant().getThreadExecutionContext(); OpChainExecutionContext opChainExecutionContext = new OpChainExecutionContext(mailboxService, requestId, deadlineMs, queryOptions, stageMetadata, - workerMetadataList.get(0), null); + workerMetadataList.get(0), null, parentContext); PairList<Integer, String> resultFields = dispatchableSubPlan.getQueryResultFields(); DataSchema sourceDataSchema = receiveNode.getDataSchema(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java index c8caed9100..763192e16e 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java @@ -37,6 +37,8 @@ import org.apache.pinot.query.routing.StagePlan; import org.apache.pinot.query.routing.WorkerMetadata; import org.apache.pinot.query.runtime.QueryRunner; import org.apache.pinot.query.service.dispatch.QueryDispatcher; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; +import org.apache.pinot.spi.trace.Tracing; import org.apache.pinot.spi.utils.CommonConstants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -112,9 +114,12 @@ public class QueryServer extends 
PinotQueryWorkerGrpc.PinotQueryWorkerImplBase { long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS)); long deadlineMs = System.currentTimeMillis() + timeoutMs; + Tracing.ThreadAccountantOps.setupRunner(Long.toString(requestId), ThreadExecutionContext.TaskType.MSE); + List<Worker.StagePlan> protoStagePlans = request.getStagePlanList(); int numStages = protoStagePlans.size(); CompletableFuture<?>[] stageSubmissionStubs = new CompletableFuture[numStages]; + ThreadExecutionContext parentContext = Tracing.getThreadAccountant().getThreadExecutionContext(); for (int i = 0; i < numStages; i++) { Worker.StagePlan protoStagePlan = protoStagePlans.get(i); stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> { @@ -133,8 +138,8 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase { for (int j = 0; j < numWorkers; j++) { WorkerMetadata workerMetadata = workerMetadataList.get(j); workerSubmissionStubs[j] = - CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerMetadata, stagePlan, requestMetadata), - _querySubmissionExecutorService); + CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerMetadata, stagePlan, requestMetadata, + parentContext), _querySubmissionExecutorService); } try { CompletableFuture.allOf(workerSubmissionStubs) @@ -167,6 +172,7 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase { future.cancel(true); } } + Tracing.getThreadAccountant().clear(); } responseObserver.onNext( Worker.QueryResponse.newBuilder().putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_OK, "") diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java index 3de9cc6fdb..1e03b31411 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java +++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryServerEnclosure.java @@ -112,7 +112,8 @@ public class QueryServerEnclosure { public CompletableFuture<Void> processQuery(WorkerMetadata workerMetadata, StagePlan stagePlan, Map<String, String> requestMetadataMap) { - return CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerMetadata, stagePlan, requestMetadataMap), + return CompletableFuture.runAsync( + () -> _queryRunner.processQuery(workerMetadata, stagePlan, requestMetadataMap, null), _queryRunner.getExecutorService()); } } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java index aff6a68853..61aad6a9a7 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java @@ -79,7 +79,7 @@ public class OpChainSchedulerServiceTest { WorkerMetadata workerMetadata = new WorkerMetadata(0, ImmutableMap.of(), ImmutableMap.of()); OpChainExecutionContext context = new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE, ImmutableMap.of(), - new StageMetadata(0, ImmutableList.of(workerMetadata), ImmutableMap.of()), workerMetadata, null); + new StageMetadata(0, ImmutableList.of(workerMetadata), ImmutableMap.of()), workerMetadata, null, null); return new OpChain(context, operator); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java index cc92873c94..a54dc182a9 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java +++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java @@ -197,7 +197,7 @@ public class MailboxSendOperatorTest { StageMetadata stageMetadata = new StageMetadata(SENDER_STAGE_ID, List.of(workerMetadata), Map.of()); OpChainExecutionContext context = new OpChainExecutionContext(_mailboxService, 123L, Long.MAX_VALUE, Map.of(), stageMetadata, workerMetadata, - null); + null, null); return new MailboxSendOperator(context, _input, statMap -> _exchange); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java new file mode 100644 index 0000000000..2821de5e73 --- /dev/null +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java @@ -0,0 +1,239 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.query.runtime.operator; + +import com.google.common.collect.ImmutableList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.SqlKind; +import org.apache.pinot.common.datablock.DataBlock; +import org.apache.pinot.common.metrics.ServerMetrics; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.planner.plannode.AggregateNode; +import org.apache.pinot.query.planner.plannode.JoinNode; +import org.apache.pinot.query.planner.plannode.PlanNode; +import org.apache.pinot.query.planner.plannode.SortNode; +import org.apache.pinot.query.planner.plannode.WindowNode; +import org.apache.pinot.query.routing.VirtualServerAddress; +import org.apache.pinot.query.runtime.blocks.TransferableBlock; +import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils; +import org.apache.pinot.spi.accounting.QueryResourceTracker; +import org.apache.pinot.spi.accounting.ThreadExecutionContext; +import org.apache.pinot.spi.accounting.ThreadResourceTracker; +import org.apache.pinot.spi.accounting.ThreadResourceUsageAccountant; +import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider; +import org.apache.pinot.spi.env.PinotConfiguration; +import org.apache.pinot.spi.trace.Tracing; +import org.apache.pinot.spi.utils.CommonConstants; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.testng.ITest; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.DOUBLE; +import static 
org.apache.pinot.common.utils.DataSchema.ColumnDataType.INT; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + + +public class MultiStageAccountingTest implements ITest { + private AutoCloseable _mocks; + @Mock + private VirtualServerAddress _serverAddress; + + protected String _testName; + protected MultiStageOperator _operator; + + @Factory(dataProvider = "operatorProvider") + public MultiStageAccountingTest(String testName, MultiStageOperator operator) { + _testName = testName; + _operator = operator; + } + + @BeforeClass + public static void setUpClass() { + ThreadResourceUsageProvider.setThreadMemoryMeasurementEnabled(true); + HashMap<String, Object> configs = new HashMap<>(); + ServerMetrics.register(Mockito.mock(ServerMetrics.class)); + configs.put(CommonConstants.Accounting.CONFIG_OF_ALARMING_LEVEL_HEAP_USAGE_RATIO, 0.00f); + configs.put(CommonConstants.Accounting.CONFIG_OF_CRITICAL_LEVEL_HEAP_USAGE_RATIO, 0.00f); + configs.put(CommonConstants.Accounting.CONFIG_OF_FACTORY_NAME, + "org.apache.pinot.core.accounting.PerQueryCPUMemAccountantFactory"); + configs.put(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_MEMORY_SAMPLING, true); + configs.put(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_CPU_SAMPLING, false); + configs.put(CommonConstants.Accounting.CONFIG_OF_OOM_PROTECTION_KILLING_QUERY, true); + // init accountant and start watcher task + Tracing.ThreadAccountantOps.initializeThreadAccountant(new PinotConfiguration(configs), "testGroupBy"); + + // Setup Thread Context + Tracing.ThreadAccountantOps.setupRunner("MultiStageAccountingTest", ThreadExecutionContext.TaskType.MSE); + ThreadExecutionContext threadExecutionContext = Tracing.getThreadAccountant().getThreadExecutionContext(); + ThreadResourceUsageProvider threadResourceUsageProvider = new ThreadResourceUsageProvider(); + 
Tracing.ThreadAccountantOps.setupWorker(1, ThreadExecutionContext.TaskType.MSE, threadResourceUsageProvider, + threadExecutionContext); + } + + @BeforeMethod + public void setUp() { + _mocks = openMocks(this); + when(_serverAddress.toString()).thenReturn(new VirtualServerAddress("mock", 80, 0).toString()); + } + + @AfterMethod + public void tearDown() + throws Exception { + _mocks.close(); + } + + @Test + public void testOperatorAccounting() { + _operator.nextBlock().getContainer(); + + ThreadResourceUsageAccountant threadAccountant = Tracing.getThreadAccountant(); + Collection<? extends QueryResourceTracker> queryResourceTrackers = threadAccountant.getQueryResources().values(); + Collection<? extends ThreadResourceTracker> threadResourceTrackers = threadAccountant.getThreadResources(); + + // Then: + assertEquals(queryResourceTrackers.size(), 1); + assertEquals(threadResourceTrackers.size(), 1); + assertTrue(queryResourceTrackers.iterator().next().getAllocatedBytes() > 0); + assertTrue(threadResourceTrackers.iterator().next().getAllocatedBytes() > 0); + assertTrue(_operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @DataProvider(name = "operatorProvider") + public static Object[][] getOperators() { + return new Object[][]{ + {"AggregateOperator", getAggregateOperator()}, + {"SortOperator", getSortOperator()}, + {"HashJoinOperator", getHashJoinOperator()}, + {"WindowAggregateOperator", getWindowAggregateOperator()}, + {"SetOperator", getIntersectOperator()} + }; + } + + private static MultiStageOperator getAggregateOperator() { + MultiStageOperator input = Mockito.mock(); + // Given: + List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1))); + List<Integer> filterArgs = List.of(-1); + List<Integer> groupKeys = List.of(0); + DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new DataSchema.ColumnDataType[]{INT, DOUBLE}); + 
when(input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = + new DataSchema(new String[]{"group", "sum"}, new DataSchema.ColumnDataType[]{INT, DOUBLE}); + return new AggregateOperator(OperatorTestUtil.getTracingContext(), input, + new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), aggCalls, filterArgs, groupKeys, + AggregateNode.AggType.DIRECT)); + } + + private static MultiStageOperator getHashJoinOperator() { + MultiStageOperator leftInput = Mockito.mock(); + MultiStageOperator rightInput = Mockito.mock(); + + DataSchema leftSchema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{ + DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING + }); + DataSchema rightSchema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{ + DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING + }); + when(leftInput.nextBlock()).thenReturn( + OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new Object[]{2, "BB"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + when(rightInput.nextBlock()).thenReturn( + OperatorTestUtil.block(rightSchema, new Object[]{2, "Aa"}, new Object[]{2, "BB"}, new Object[]{3, "BB"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = new DataSchema(new String[]{"int_col1", "string_col1", "int_col2", "string_co2"}, + new DataSchema.ColumnDataType[]{ + DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, + DataSchema.ColumnDataType.STRING + }); + return new HashJoinOperator(OperatorTestUtil.getTracingContext(), leftInput, leftSchema, rightInput, + new JoinNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), JoinRelType.INNER, List.of(0), List.of(0), + List.of())); + } + 
+ private static MultiStageOperator getSortOperator() { + MultiStageOperator input = Mockito.mock(); + // Given: + DataSchema schema = new DataSchema(new String[]{"sort"}, new DataSchema.ColumnDataType[]{INT}); + when(input.nextBlock()).thenReturn( + new TransferableBlock(List.of(new Object[]{2}, new Object[]{1}), schema, DataBlock.Type.ROW)) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + List<RelFieldCollation> collations = + List.of(new RelFieldCollation(0, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST)); + return new SortOperator(OperatorTestUtil.getTracingContext(), input, + new SortNode(-1, schema, PlanNode.NodeHint.EMPTY, List.of(), collations, 10, 0)); + } + + private static MultiStageOperator getWindowAggregateOperator() { + MultiStageOperator input = Mockito.mock(); + // Given: + DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new DataSchema.ColumnDataType[]{INT, INT}); + when(input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, 1})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "sum"}, new DataSchema.ColumnDataType[]{ + INT, INT, DOUBLE + }); + List<Integer> keys = List.of(0); + List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1))); + return new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), input, inputSchema, + new WindowNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), keys, List.of(), aggCalls, + WindowNode.WindowFrameType.RANGE, Integer.MIN_VALUE, Integer.MAX_VALUE, List.of())); + } + + private static MultiStageOperator getIntersectOperator() { + MultiStageOperator leftOperator = Mockito.mock(); + MultiStageOperator rightOperator = Mockito.mock(); + + DataSchema schema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{ + 
DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING + }); + Mockito.when(leftOperator.nextBlock()) + .thenReturn(OperatorTestUtil.block(schema, new Object[]{1, "AA"}, new Object[]{2, "BB"}, new Object[]{3, "CC"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + Mockito.when(rightOperator.nextBlock()) + .thenReturn(OperatorTestUtil.block(schema, new Object[]{1, "AA"}, new Object[]{2, "BB"}, new Object[]{4, "DD"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + + return new IntersectOperator(OperatorTestUtil.getTracingContext(), ImmutableList.of(leftOperator, rightOperator), + schema); + } + + private static RexExpression.FunctionCall getSum(RexExpression arg) { + return new RexExpression.FunctionCall(DataSchema.ColumnDataType.INT, SqlKind.SUM.name(), List.of(arg)); + } + + @Override + public String getTestName() { + return _testName; + } +} diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java index da4537cb19..f279e5992b 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java @@ -83,7 +83,7 @@ public class OperatorTestUtil { public static OpChainExecutionContext getOpChainContext(MailboxService mailboxService, long deadlineMs, StageMetadata stageMetadata) { return new OpChainExecutionContext(mailboxService, 0, deadlineMs, ImmutableMap.of(), stageMetadata, - stageMetadata.getWorkerMetadataList().get(0), null); + stageMetadata.getWorkerMetadataList().get(0), null, null); } public static OpChainExecutionContext getTracingContext() { @@ -101,7 +101,7 @@ public class OperatorTestUtil { WorkerMetadata workerMetadata = new WorkerMetadata(0, ImmutableMap.of(), ImmutableMap.of()); StageMetadata 
stageMetadata = new StageMetadata(0, ImmutableList.of(workerMetadata), ImmutableMap.of()); OpChainExecutionContext opChainExecutionContext = new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE, - opChainMetadata, stageMetadata, workerMetadata, null); + opChainMetadata, stageMetadata, workerMetadata, null, null); StagePlan stagePlan = new StagePlan(null, stageMetadata); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java index 2a568f45fc..3a0b23408e 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java @@ -105,7 +105,7 @@ public class QueryServerTest extends QueryTestSet { Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, 1); Map<String, String> requestMetadata = QueryPlanSerDeUtils.fromProtoProperties(queryRequest.getMetadata()); QueryRunner mockRunner = _queryRunnerMap.get(Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT))); - doThrow(new RuntimeException("foo")).when(mockRunner).processQuery(any(), any(), any()); + doThrow(new RuntimeException("foo")).when(mockRunner).processQuery(any(), any(), any(), any()); // submit the request for testing. Worker.QueryResponse resp = submitRequest(queryRequest, requestMetadata); // reset the mock runner before assert. 
@@ -148,7 +148,7 @@ public class QueryServerTest extends QueryTestSet { return planNode.equals(stagePlan.getRootNode()) && isStageMetadataEqual(stageMetadata, stagePlan.getStageMetadata()); }), argThat(requestMetadataMap -> requestId.equals( - requestMetadataMap.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID)))); + requestMetadataMap.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID))), any()); return true; } catch (Throwable t) { return false; diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java index 68fd9b03e2..7ea59ad378 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadExecutionContext.java @@ -23,6 +23,17 @@ package org.apache.pinot.spi.accounting; */ public interface ThreadExecutionContext { + /** + * SSE: Single Stage Engine + * MSE: Multi Stage Engine + * UNKNOWN: Default + */ + public enum TaskType { + SSE, + MSE, + UNKNOWN + } + /** * get query id of the execution context * @return query id in string @@ -34,4 +45,6 @@ public interface ThreadExecutionContext { * @return get the anchor thread of execution context */ Thread getAnchorThread(); + + TaskType getTaskType(); } diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java index ff78a0d33c..418210b376 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceTracker.java @@ -49,4 +49,6 @@ public interface ThreadResourceTracker { * @return an int containing the task id. 
*/ int getTaskId(); + + ThreadExecutionContext.TaskType getTaskType(); } diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java index 8be0632e6b..bab51cf21c 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/accounting/ThreadResourceUsageAccountant.java @@ -42,7 +42,8 @@ public interface ThreadResourceUsageAccountant { * @param taskId a unique task id * @param parentContext the parent execution context, null for root(runner) thread */ - void createExecutionContext(String queryId, int taskId, @Nullable ThreadExecutionContext parentContext); + void createExecutionContext(String queryId, int taskId, ThreadExecutionContext.TaskType taskType, + @Nullable ThreadExecutionContext parentContext); /** * get the execution context of current thread diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java b/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java index 910ebc35cf..59ad65eef5 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/trace/Tracing.java @@ -200,7 +200,8 @@ public class Tracing { } @Override - public final void createExecutionContext(String queryId, int taskId, ThreadExecutionContext parentContext) { + public final void createExecutionContext(String queryId, int taskId, ThreadExecutionContext.TaskType taskType, + ThreadExecutionContext parentContext) { _anchorThread.set(parentContext == null ?
Thread.currentThread() : parentContext.getAnchorThread()); createExecutionContextInner(queryId, taskId, parentContext); } @@ -220,6 +221,11 @@ public class Tracing { public Thread getAnchorThread() { return _anchorThread.get(); } + + @Override + public TaskType getTaskType() { + return TaskType.UNKNOWN; + } }; } @@ -254,14 +260,38 @@ public class Tracing { } public static void setupRunner(String queryId) { + setupRunner(queryId, ThreadExecutionContext.TaskType.SSE); + } + + public static void setupRunner(String queryId, ThreadExecutionContext.TaskType taskType) { Tracing.getThreadAccountant().setThreadResourceUsageProvider(new ThreadResourceUsageProvider()); - Tracing.getThreadAccountant().createExecutionContext(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID, null); + Tracing.getThreadAccountant() + .createExecutionContext(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID, taskType, null); } + /** + * Setup metadata of query worker threads. This function assumes that the workers are for Single Stage Engine. + * @param taskId Query task ID of the thread. In SSE, ID is an incrementing counter. In MSE, id is the stage id. + * @param threadResourceUsageProvider Object that measures resource usage. + * @param threadExecutionContext Context holds metadata about the query. + */ public static void setupWorker(int taskId, ThreadResourceUsageProvider threadResourceUsageProvider, ThreadExecutionContext threadExecutionContext) { + setupWorker(taskId, ThreadExecutionContext.TaskType.SSE, threadResourceUsageProvider, threadExecutionContext); + } + + /** + * Setup metadata of query worker threads. + * @param taskId Query task ID of the thread. In SSE, ID is an incrementing counter. In MSE, id is the stage id. + * @param taskType Type of the task: SSE (Single Stage Engine), MSE (Multi Stage Engine) or UNKNOWN. + * @param threadResourceUsageProvider Object that measures resource usage. + * @param threadExecutionContext Context holds metadata about the query.
+ */ + public static void setupWorker(int taskId, ThreadExecutionContext.TaskType taskType, + ThreadResourceUsageProvider threadResourceUsageProvider, + ThreadExecutionContext threadExecutionContext) { Tracing.getThreadAccountant().setThreadResourceUsageProvider(threadResourceUsageProvider); - Tracing.getThreadAccountant().createExecutionContext(null, taskId, threadExecutionContext); + Tracing.getThreadAccountant().createExecutionContext(null, taskId, taskType, threadExecutionContext); } public static void sample() { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org