Copilot commented on code in PR #60878:
URL: https://github.com/apache/doris/pull/60878#discussion_r2870342325
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedShuffleJob.java:
##########
@@ -79,14 +84,26 @@ public List<AssignedJob> computeAssignedJobs(
}
}
- protected int degreeOfParallelism() {
+ protected int degreeOfParallelism(int childInstanceNum, ListMultimap<ExchangeNode, AssignedJob> inputJobs) {
// TODO: check we use nested loop join do right outer / semi / anti join,
// we should add an exchange node with gather distribute under the nested loop join
int expectInstanceNum = -1;
ConnectContext connectContext = statementContext.getConnectContext();
if (connectContext != null && connectContext.getSessionVariable() != null) {
expectInstanceNum = connectContext.getSessionVariable().getExchangeInstanceParallel();
}
+ // If child fragment uses query cache, limit instance num to avoid too many instances
+ if (childInstanceNum > 0 && connectContext != null) {
+ boolean childHasQueryCacheParam = inputJobs.values().stream()
+ .anyMatch(job -> job.unassignedJob().getFragment().queryCacheParam != null);
+ if (childHasQueryCacheParam) {
+ int maxInstanceNum = connectContext.getSessionVariable().getParallelExecInstanceNum()
+ * Env.getCurrentSystemInfo().getBackendsNumber(false);
Review Comment:
`degreeOfParallelism` calls
`Env.getCurrentSystemInfo().getBackendsNumber(false)`, but
`SystemInfoService#getBackendsNumber` dereferences `ConnectContext.get()`
without a null-check. In code paths where a `ConnectContext` exists on
`statementContext` but isn’t installed in the thread-local (e.g.
`EnvFactory#createCoordinator` creates a new `ConnectContext` when
`ConnectContext.get()` is null), this can throw an NPE. Compute backend count
without relying on `ConnectContext.get()` (or make `getBackendsNumber` safe)
and, if you still need the `beNumberForTest` override, read it from the local
`connectContext.getSessionVariable()` instead.
```suggestion
int backendNumber = 1;
ConnectContext threadLocalCtx = ConnectContext.get();
if (threadLocalCtx != null) {
    backendNumber = Env.getCurrentSystemInfo().getBackendsNumber(false);
}
int maxInstanceNum = connectContext.getSessionVariable().getParallelExecInstanceNum()
        * backendNumber;
```
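To illustrate the hazard outside Doris, here is a minimal sketch of the same guard using only the JDK. `ThreadLocalContextSketch`, `install`, `clear`, and `backendCountOrFallback` are hypothetical names invented for this example, not Doris APIs; the supplier stands in for a call such as `getBackendsNumber(false)` that reads the thread-local internally.

```java
import java.util.function.IntSupplier;

/** Hypothetical stand-in for a thread-local request context; not a Doris class. */
final class ThreadLocalContextSketch {
    private static final ThreadLocal<Object> CURRENT = new ThreadLocal<>();

    /** Installs a context for the calling thread. */
    static void install(Object ctx) {
        CURRENT.set(ctx);
    }

    /** Removes the context from the calling thread. */
    static void clear() {
        CURRENT.remove();
    }

    /**
     * Calls the supplier only when a context is installed on this thread.
     * Otherwise returns the fallback and never invokes the supplier, so a
     * supplier that dereferences the thread-local cannot throw here.
     */
    static int backendCountOrFallback(IntSupplier threadLocalDependentCount, int fallback) {
        return CURRENT.get() != null ? threadLocalDependentCount.getAsInt() : fallback;
    }
}
```

Note that a fallback of 1, as in the suggestion above, makes the resulting cap equal to the per-backend parallelism alone, which is stricter than a cap based on the real backend count.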
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedShuffleJob.java:
##########
@@ -79,14 +84,26 @@ public List<AssignedJob> computeAssignedJobs(
}
}
- protected int degreeOfParallelism() {
+ protected int degreeOfParallelism(int childInstanceNum, ListMultimap<ExchangeNode, AssignedJob> inputJobs) {
// TODO: check we use nested loop join do right outer / semi / anti join,
// we should add an exchange node with gather distribute under the nested loop join
int expectInstanceNum = -1;
ConnectContext connectContext = statementContext.getConnectContext();
if (connectContext != null && connectContext.getSessionVariable() != null) {
expectInstanceNum = connectContext.getSessionVariable().getExchangeInstanceParallel();
}
+ // If child fragment uses query cache, limit instance num to avoid too many instances
+ if (childInstanceNum > 0 && connectContext != null) {
+ boolean childHasQueryCacheParam = inputJobs.values().stream()
+ .anyMatch(job -> job.unassignedJob().getFragment().queryCacheParam != null);
+ if (childHasQueryCacheParam) {
+ int maxInstanceNum = connectContext.getSessionVariable().getParallelExecInstanceNum()
+ * Env.getCurrentSystemInfo().getBackendsNumber(false);
+ expectInstanceNum = expectInstanceNum > 0
+ ? Math.min(expectInstanceNum, Math.min(childInstanceNum, maxInstanceNum))
+ : Math.min(childInstanceNum, maxInstanceNum);
Review Comment:
In the query-cache limiting branch, `getParallelExecInstanceNum()` is called on
`connectContext.getSessionVariable()` without a null check. Earlier code
guards `getSessionVariable()` before calling `getExchangeInstanceParallel()`, and other
parts of the codebase treat session variables as nullable. Please reuse a
non-null `SessionVariable` local (or extend the null check) before reading
`getParallelExecInstanceNum()` to avoid NPEs.
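A minimal sketch of that shape, with the session variables read once through a single null-checked local. `InstanceLimitSketch` and `SessionVars` are hypothetical stand-ins for illustration, not Doris classes, and `backendNumber` is assumed to be computed safely elsewhere; only the clamping arithmetic is taken from the diff.

```java
final class InstanceLimitSketch {
    /** Hypothetical stand-in holding the two session values the diff reads. */
    static final class SessionVars {
        final int exchangeInstanceParallel;
        final int parallelExecInstanceNum;

        SessionVars(int exchangeInstanceParallel, int parallelExecInstanceNum) {
            this.exchangeInstanceParallel = exchangeInstanceParallel;
            this.parallelExecInstanceNum = parallelExecInstanceNum;
        }
    }

    /**
     * Returns -1 when no session variables are available. Otherwise starts from
     * exchangeInstanceParallel and, if the child uses query cache, clamps it to
     * min(childInstanceNum, parallelExecInstanceNum * backendNumber).
     */
    static int expectedInstanceNum(SessionVars vars, int childInstanceNum,
                                   boolean childHasQueryCache, int backendNumber) {
        if (vars == null) {
            return -1; // single null check covers every later read
        }
        int expect = vars.exchangeInstanceParallel;
        if (childInstanceNum > 0 && childHasQueryCache) {
            int maxInstanceNum = vars.parallelExecInstanceNum * backendNumber;
            int cap = Math.min(childInstanceNum, maxInstanceNum);
            expect = expect > 0 ? Math.min(expect, cap) : cap;
        }
        return expect;
    }
}
```

For example, with `parallelExecInstanceNum = 8`, three backends, and 100 child instances, the cap is 24; an explicit `exchangeInstanceParallel` of 16 stays at 16 because it is already below the cap.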
##########
fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedShuffleJobTest.java:
##########
@@ -0,0 +1,487 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans.distribute.worker.job;
+
+import org.apache.doris.catalog.Env;
+import org.apache.doris.nereids.StatementContext;
+import org.apache.doris.nereids.trees.plans.distribute.DistributeContext;
+import org.apache.doris.nereids.trees.plans.distribute.worker.DistributedPlanWorker;
+import org.apache.doris.nereids.trees.plans.distribute.worker.DistributedPlanWorkerManager;
+import org.apache.doris.planner.ExchangeNode;
+import org.apache.doris.planner.PlanFragment;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.SessionVariable;
+import org.apache.doris.system.SystemInfoService;
+import org.apache.doris.thrift.TQueryCacheParam;
+import org.apache.doris.thrift.TUniqueId;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.ListMultimap;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import java.util.BitSet;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * Unit tests for {@link UnassignedShuffleJob}, specifically the
+ * degreeOfParallelism logic that limits instance count when query cache is enabled.
+ */
+public class UnassignedShuffleJobTest {
+
+ private ConnectContext connectContext;
+ private SessionVariable sessionVariable;
+ private StatementContext statementContext;
+ private PlanFragment fragment;
+ private MockedStatic<Env> envMockedStatic;
+ private MockedStatic<ConnectContext> connectContextMockedStatic;
+ private AtomicLong instanceIdCounter;
+
+ @BeforeEach
+ public void setUp() {
+ sessionVariable = Mockito.mock(SessionVariable.class);
+ connectContext = Mockito.mock(ConnectContext.class);
+ Mockito.when(connectContext.getSessionVariable()).thenReturn(sessionVariable);
+
+ // nextInstanceId() is called from buildInstances; provide unique IDs
+ instanceIdCounter = new AtomicLong(0);
+ Mockito.when(connectContext.nextInstanceId()).thenAnswer(
+ invocation -> new TUniqueId(0, instanceIdCounter.incrementAndGet()));
+
+ // Mock static ConnectContext.get() for SystemInfoService.getBackendsNumber
+ connectContextMockedStatic = Mockito.mockStatic(ConnectContext.class);
+ connectContextMockedStatic.when(ConnectContext::get).thenReturn(connectContext);
+
Review Comment:
The tests always mock `ConnectContext.get()` to return a non-null context.
Since production code can run with a non-null
`statementContext.getConnectContext()` while `ConnectContext.get()`
(thread-local) is null (see `EnvFactory#createCoordinator`), consider adding a
regression test that exercises the query-cache limiting path with
`ConnectContext.get()` returning null to ensure no NPE and correct limiting
behavior.
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/distribute/worker/job/UnassignedShuffleJob.java:
##########
@@ -79,14 +84,26 @@ public List<AssignedJob> computeAssignedJobs(
}
}
- protected int degreeOfParallelism() {
+ protected int degreeOfParallelism(int childInstanceNum, ListMultimap<ExchangeNode, AssignedJob> inputJobs) {
// TODO: check we use nested loop join do right outer / semi / anti join,
// we should add an exchange node with gather distribute under the nested loop join
int expectInstanceNum = -1;
ConnectContext connectContext = statementContext.getConnectContext();
if (connectContext != null && connectContext.getSessionVariable() != null) {
expectInstanceNum = connectContext.getSessionVariable().getExchangeInstanceParallel();
}
+ // If child fragment uses query cache, limit instance num to avoid too many instances
+ if (childInstanceNum > 0 && connectContext != null) {
+ boolean childHasQueryCacheParam = inputJobs.values().stream()
Review Comment:
`childHasQueryCacheParam` currently scans `inputJobs.values()` (i.e. all
child *instances*) to decide whether any child fragment has `queryCacheParam`.
When instance counts are large (the scenario this change targets), this adds an
avoidable O(total-instances) pass. Prefer checking at fragment/exchange
granularity (e.g., one representative `AssignedJob` per exchange, or the
`exchangeToChildJob` fragments) so the detection cost stays
O(#exchanges/#fragments).
```suggestion
boolean childHasQueryCacheParam = inputJobs.asMap().values().stream()
        .map(java.util.Collection::iterator)
        .filter(java.util.Iterator::hasNext)
        .map(java.util.Iterator::next)
```
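The suggestion above replaces only the first line of the statement, so the existing `.anyMatch(...)` call is expected to follow it. As a self-contained illustration of the full pattern, the sketch below checks one representative element per key of a grouped map, the shape that `asMap()` exposes. `RepresentativeCheckSketch` and `anyFirstPerKeyMatches` are hypothetical names, and the approach assumes every element under one key would give the same answer (here, that all instances under one exchange share a fragment).

```java
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Predicate;

final class RepresentativeCheckSketch {
    /**
     * Tests only the first element of each group, so the cost is proportional
     * to the number of keys rather than the total number of values.
     * Empty groups are skipped by the hasNext filter.
     */
    static <K, V> boolean anyFirstPerKeyMatches(Map<K, Collection<V>> grouped,
                                                Predicate<V> test) {
        return grouped.values().stream()
                .map(Collection::iterator)
                .filter(Iterator::hasNext)
                .map(Iterator::next)
                .anyMatch(test);
    }
}
```

If elements under the same key can differ in the property being tested, this shortcut can miss a match and the full scan is still needed.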
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]