This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new c9a82c40a2 [Multi-stage] Optimize query dispatch (#12358)
c9a82c40a2 is described below

commit c9a82c40a2c8bed5e86d8278e0bb57bfc5bee86f
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Sat Feb 3 11:58:58 2024 -0800

    [Multi-stage] Optimize query dispatch (#12358)
---
 .../runtime/plan/serde/QueryPlanSerDeUtils.java    |  48 ++++------
 .../dispatch/AsyncQueryDispatchResponse.java       |  17 ++--
 .../query/service/dispatch/DispatchClient.java     |  22 +----
 .../query/service/dispatch/DispatchObserver.java   |  17 ++--
 .../query/service/dispatch/QueryDispatcher.java    | 104 +++++++++++++++------
 .../query/service/server/QueryServerTest.java      |  25 +++--
 6 files changed, 130 insertions(+), 103 deletions(-)

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index c4bded9373..91bbcc2010 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -27,7 +27,6 @@ import java.util.regex.Pattern;
 import org.apache.commons.lang.StringUtils;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
-import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
 import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.MailboxMetadata;
@@ -42,8 +41,8 @@ import org.apache.pinot.query.runtime.plan.StageMetadata;
  * This utility class serialize/deserialize between {@link Worker.StagePlan} 
elements to Planner elements.
  */
 public class QueryPlanSerDeUtils {
-  private static final Pattern VIRTUAL_SERVER_PATTERN = Pattern.compile(
-      "(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");
+  private static final Pattern VIRTUAL_SERVER_PATTERN =
+      Pattern.compile("(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");
 
   private QueryPlanSerDeUtils() {
     // do not instantiate.
@@ -57,18 +56,6 @@ public class QueryPlanSerDeUtils {
     return distributedStagePlans;
   }
 
-  public static Worker.StagePlan serialize(DispatchableSubPlan 
dispatchableSubPlan, int stageId,
-      QueryServerInstance queryServerInstance, List<Integer> workerIds) {
-    return Worker.StagePlan.newBuilder()
-        .setStageId(stageId)
-        .setStageRoot(StageNodeSerDeUtils.serializeStageNode(
-            (AbstractPlanNode) 
dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
-                .getFragmentRoot()))
-        .setStageMetadata(
-            
toProtoStageMetadata(dispatchableSubPlan.getQueryStageList().get(stageId), 
queryServerInstance, workerIds))
-        .build();
-  }
-
   public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
     Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
     if (!matcher.matches()) {
@@ -78,8 +65,8 @@ public class QueryPlanSerDeUtils {
     }
 
     // Skipped netty and grpc port as they are not used in worker instance.
-    return new VirtualServerAddress(matcher.group("host"),
-        Integer.parseInt(matcher.group("port")), 
Integer.parseInt(matcher.group("virtualid")));
+    return new VirtualServerAddress(matcher.group("host"), 
Integer.parseInt(matcher.group("port")),
+        Integer.parseInt(matcher.group("virtualid")));
   }
 
   public static String addressToProto(VirtualServerAddress serverAddress) {
@@ -145,17 +132,21 @@ public class QueryPlanSerDeUtils {
     return mailboxMetadata;
   }
 
-  private static Worker.StageMetadata 
toProtoStageMetadata(DispatchablePlanFragment planFragment,
-      QueryServerInstance queryServerInstance, List<Integer> workerIds) {
-    Worker.StageMetadata.Builder builder = Worker.StageMetadata.newBuilder();
-    for (WorkerMetadata workerMetadata : planFragment.getWorkerMetadataList()) 
{
-      builder.addWorkerMetadata(toProtoWorkerMetadata(workerMetadata));
+  public static Worker.StageMetadata 
toProtoStageMetadata(List<Worker.WorkerMetadata> workerMetadataList,
+      Map<String, String> customProperties, QueryServerInstance 
serverInstance, List<Integer> workerIds) {
+    return 
Worker.StageMetadata.newBuilder().addAllWorkerMetadata(workerMetadataList)
+        .putAllCustomProperty(customProperties)
+        .setServerAddress(String.format("%s:%d", serverInstance.getHostname(), 
serverInstance.getQueryMailboxPort()))
+        .addAllWorkerIds(workerIds).build();
+  }
+
+  public static List<Worker.WorkerMetadata> 
toProtoWorkerMetadataList(DispatchablePlanFragment planFragment) {
+    List<WorkerMetadata> workerMetadataList = 
planFragment.getWorkerMetadataList();
+    List<Worker.WorkerMetadata> protoWorkerMetadataList = new 
ArrayList<>(workerMetadataList.size());
+    for (WorkerMetadata workerMetadata : workerMetadataList) {
+      protoWorkerMetadataList.add(toProtoWorkerMetadata(workerMetadata));
     }
-    builder.putAllCustomProperty(planFragment.getCustomProperties());
-    builder.setServerAddress(String.format("%s:%d", 
queryServerInstance.getHostname(),
-        queryServerInstance.getQueryMailboxPort()));
-    builder.addAllWorkerIds(workerIds);
-    return builder.build();
+    return protoWorkerMetadataList;
   }
 
   private static Worker.WorkerMetadata toProtoWorkerMetadata(WorkerMetadata 
workerMetadata) {
@@ -166,8 +157,7 @@ public class QueryPlanSerDeUtils {
     return builder.build();
   }
 
-  private static Map<Integer, Worker.MailboxMetadata> toProtoMailboxMap(
-      Map<Integer, MailboxMetadata> mailBoxInfosMap) {
+  private static Map<Integer, Worker.MailboxMetadata> 
toProtoMailboxMap(Map<Integer, MailboxMetadata> mailBoxInfosMap) {
     Map<Integer, Worker.MailboxMetadata> mailboxMetadataMap = new HashMap<>();
     for (Map.Entry<Integer, MailboxMetadata> entry : 
mailBoxInfosMap.entrySet()) {
       mailboxMetadataMap.put(entry.getKey(), toProtoMailbox(entry.getValue()));
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
index 185ba4f607..076d8ce221 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
@@ -30,27 +30,22 @@ import org.apache.pinot.query.routing.QueryServerInstance;
  * {@link #getThrowable()} to check if it is null.
  */
 class AsyncQueryDispatchResponse {
-  private final QueryServerInstance _virtualServer;
-  private final int _stageId;
+  private final QueryServerInstance _serverInstance;
   private final Worker.QueryResponse _queryResponse;
   private final Throwable _throwable;
 
-  public AsyncQueryDispatchResponse(QueryServerInstance virtualServer, int 
stageId, Worker.QueryResponse queryResponse,
+  public AsyncQueryDispatchResponse(QueryServerInstance serverInstance, 
@Nullable Worker.QueryResponse queryResponse,
       @Nullable Throwable throwable) {
-    _virtualServer = virtualServer;
-    _stageId = stageId;
+    _serverInstance = serverInstance;
     _queryResponse = queryResponse;
     _throwable = throwable;
   }
 
-  public QueryServerInstance getVirtualServer() {
-    return _virtualServer;
-  }
-
-  public int getStageId() {
-    return _stageId;
+  public QueryServerInstance getServerInstance() {
+    return _serverInstance;
   }
 
+  @Nullable
   public Worker.QueryResponse getQueryResponse() {
     return _queryResponse;
   }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
index 03861a436e..5b036930ce 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
@@ -26,8 +26,6 @@ import java.util.function.Consumer;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.query.routing.QueryServerInstance;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 
 /**
@@ -37,8 +35,8 @@ import org.slf4j.LoggerFactory;
  *       let that take care of pooling. (2) Create a DispatchClient interface 
and implement pooled/non-pooled versions.
  */
 class DispatchClient {
-  private static final Logger LOGGER = 
LoggerFactory.getLogger(DispatchClient.class);
   private static final StreamObserver<Worker.CancelResponse> 
NO_OP_CANCEL_STREAM_OBSERVER = new CancelObserver();
+
   private final ManagedChannel _channel;
   private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub;
 
@@ -51,23 +49,13 @@ class DispatchClient {
     return _channel;
   }
 
-  public void submit(Worker.QueryRequest request, int stageId, 
QueryServerInstance virtualServer, Deadline deadline,
+  public void submit(Worker.QueryRequest request, QueryServerInstance 
virtualServer, Deadline deadline,
       Consumer<AsyncQueryDispatchResponse> callback) {
-    try {
-      _dispatchStub.withDeadline(deadline).submit(request, new 
DispatchObserver(stageId, virtualServer, callback));
-    } catch (Exception e) {
-      LOGGER.error("Query Dispatch failed at client-side", e);
-      callback.accept(new AsyncQueryDispatchResponse(
-          virtualServer, stageId, Worker.QueryResponse.getDefaultInstance(), 
e));
-    }
+    _dispatchStub.withDeadline(deadline).submit(request, new 
DispatchObserver(virtualServer, callback));
   }
 
   public void cancel(long requestId) {
-    try {
-      Worker.CancelRequest cancelRequest = 
Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
-      _dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
-    } catch (Exception e) {
-      LOGGER.error("Query Cancellation failed at client-side", e);
-    }
+    Worker.CancelRequest cancelRequest = 
Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
+    _dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
index 2a7425dd99..9b99691655 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
@@ -28,15 +28,13 @@ import org.apache.pinot.query.routing.QueryServerInstance;
  * A {@link StreamObserver} used by {@link DispatchClient} to subscribe to the 
response of a async Query Dispatch call.
  */
 class DispatchObserver implements StreamObserver<Worker.QueryResponse> {
-  private int _stageId;
-  private QueryServerInstance _virtualServer;
-  private Consumer<AsyncQueryDispatchResponse> _callback;
+  private final QueryServerInstance _serverInstance;
+  private final Consumer<AsyncQueryDispatchResponse> _callback;
+
   private Worker.QueryResponse _queryResponse;
 
-  public DispatchObserver(int stageId, QueryServerInstance virtualServer,
-      Consumer<AsyncQueryDispatchResponse> callback) {
-    _stageId = stageId;
-    _virtualServer = virtualServer;
+  public DispatchObserver(QueryServerInstance serverInstance, 
Consumer<AsyncQueryDispatchResponse> callback) {
+    _serverInstance = serverInstance;
     _callback = callback;
   }
 
@@ -48,12 +46,11 @@ class DispatchObserver implements StreamObserver<Worker.QueryResponse> {
   @Override
   public void onError(Throwable throwable) {
     _callback.accept(
-        new AsyncQueryDispatchResponse(_virtualServer, _stageId, 
Worker.QueryResponse.getDefaultInstance(),
-            throwable));
+        new AsyncQueryDispatchResponse(_serverInstance, 
Worker.QueryResponse.getDefaultInstance(), throwable));
   }
 
   @Override
   public void onCompleted() {
-    _callback.accept(new AsyncQueryDispatchResponse(_virtualServer, _stageId, 
_queryResponse, null));
+    _callback.accept(new AsyncQueryDispatchResponse(_serverInstance, 
_queryResponse, null));
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 3f1f43c1eb..2029e31a6f 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -27,16 +27,17 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import javax.annotation.Nullable;
 import org.apache.calcite.util.Pair;
 import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
@@ -48,8 +49,10 @@ import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.PlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
+import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -107,50 +110,76 @@ public class QueryDispatcher {
   void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long 
timeoutMs, Map<String, String> queryOptions)
       throws Exception {
     Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
-    BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new 
LinkedBlockingQueue<>();
     List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
     int numStages = stagePlans.size();
-    int numDispatchCalls = 0;
-    // Do not submit the reduce stage (stage 0)
+    Set<QueryServerInstance> serverInstances = new HashSet<>();
+    // TODO: If serialization is slow, consider serializing each stage in 
parallel
+    StageInfo[] stageInfoMap = new StageInfo[numStages];
+    // Ignore the reduce stage (stage 0)
     for (int stageId = 1; stageId < numStages; stageId++) {
-      for (Map.Entry<QueryServerInstance, List<Integer>> entry : 
stagePlans.get(stageId)
-          .getServerInstanceToWorkerIdMap().entrySet()) {
-        QueryServerInstance queryServerInstance = entry.getKey();
-        Worker.QueryRequest.Builder queryRequestBuilder = 
Worker.QueryRequest.newBuilder();
-        queryRequestBuilder.addStagePlan(
-            QueryPlanSerDeUtils.serialize(dispatchableSubPlan, stageId, 
queryServerInstance, entry.getValue()));
-        Worker.QueryRequest queryRequest =
-            
queryRequestBuilder.putMetadata(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID,
-                    String.valueOf(requestId))
-                
.putMetadata(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, 
String.valueOf(timeoutMs))
-                .putAllMetadata(queryOptions).build();
-        DispatchClient client = getOrCreateDispatchClient(queryServerInstance);
-        int finalStageId = stageId;
-        _executorService.submit(
-            () -> client.submit(queryRequest, finalStageId, 
queryServerInstance, deadline, dispatchCallbacks::offer));
-        numDispatchCalls++;
-      }
+      DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+      
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
+      Plan.StageNode rootNode =
+          StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) 
stagePlan.getPlanFragment().getFragmentRoot());
+      List<Worker.WorkerMetadata> workerMetadataList = 
QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
+      stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, 
stagePlan.getCustomProperties());
+    }
+    Map<String, String> requestMetadata = new HashMap<>();
+    requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, 
Long.toString(requestId));
+    
requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, 
Long.toString(timeoutMs));
+    requestMetadata.putAll(queryOptions);
+
+    // Submit the query plan to all servers in parallel
+    int numServers = serverInstances.size();
+    BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new 
ArrayBlockingQueue<>(numServers);
+    for (QueryServerInstance serverInstance : serverInstances) {
+      _executorService.submit(() -> {
+        try {
+          Worker.QueryRequest.Builder requestBuilder = 
Worker.QueryRequest.newBuilder();
+          for (int stageId = 1; stageId < numStages; stageId++) {
+            List<Integer> workerIds = 
stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
+            if (workerIds != null) {
+              StageInfo stageInfo = stageInfoMap[stageId];
+              Worker.StageMetadata stageMetadata =
+                  
QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, 
stageInfo._customProperties,
+                      serverInstance, workerIds);
+              Worker.StagePlan stagePlan =
+                  
Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
+                      .setStageMetadata(stageMetadata).build();
+              requestBuilder.addStagePlan(stagePlan);
+            }
+          }
+          requestBuilder.putAllMetadata(requestMetadata);
+          
getOrCreateDispatchClient(serverInstance).submit(requestBuilder.build(), 
serverInstance, deadline,
+              dispatchCallbacks::offer);
+        } catch (Throwable t) {
+          LOGGER.warn("Caught exception while dispatching query: {} to server: 
{}", requestId, serverInstance, t);
+          dispatchCallbacks.offer(new 
AsyncQueryDispatchResponse(serverInstance, null, t));
+        }
+      });
     }
-    int successfulDispatchCalls = 0;
+
+    int numSuccessCalls = 0;
     // TODO: Cancel all dispatched requests if one of the dispatch errors out 
or deadline is breached.
-    while (!deadline.isExpired() && successfulDispatchCalls < 
numDispatchCalls) {
+    while (!deadline.isExpired() && numSuccessCalls < numServers) {
       AsyncQueryDispatchResponse resp =
           
dispatchCallbacks.poll(deadline.timeRemaining(TimeUnit.MILLISECONDS), 
TimeUnit.MILLISECONDS);
       if (resp != null) {
         if (resp.getThrowable() != null) {
           throw new RuntimeException(
-              String.format("Error dispatching query to server=%s stage=%s", 
resp.getVirtualServer(),
-                  resp.getStageId()), resp.getThrowable());
+              String.format("Error dispatching query: %d to server: %s", 
requestId, resp.getServerInstance()),
+              resp.getThrowable());
         } else {
           Worker.QueryResponse response = resp.getQueryResponse();
+          assert response != null;
           if 
(response.containsMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR))
 {
             throw new RuntimeException(
-                String.format("Unable to execute query plan at stage %s on 
server %s: ERROR: %s", resp.getStageId(),
-                    resp.getVirtualServer(),
+                String.format("Unable to execute query plan for request: %d on 
server: %s, ERROR: %s", requestId,
+                    resp.getServerInstance(),
                     
response.getMetadataOrDefault(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
                         "null")));
           }
-          successfulDispatchCalls++;
+          numSuccessCalls++;
         }
       }
     }
@@ -159,6 +188,19 @@ public class QueryDispatcher {
     }
   }
 
+  private static class StageInfo {
+    final Plan.StageNode _rootNode;
+    final List<Worker.WorkerMetadata> _workerMetadataList;
+    final Map<String, String> _customProperties;
+
+    StageInfo(Plan.StageNode rootNode, List<Worker.WorkerMetadata> 
workerMetadataList,
+        Map<String, String> customProperties) {
+      _rootNode = rootNode;
+      _workerMetadataList = workerMetadataList;
+      _customProperties = customProperties;
+    }
+  }
+
   private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) 
{
     List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
     int numStages = stagePlans.size();
@@ -168,7 +210,11 @@ public class QueryDispatcher {
       
serversToCancel.addAll(stagePlans.get(stageId).getServerInstanceToWorkerIdMap().keySet());
     }
     for (QueryServerInstance queryServerInstance : serversToCancel) {
-      getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
+      try {
+        getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
+      } catch (Throwable t) {
+        LOGGER.warn("Caught exception while cancelling query: {} on server: 
{}", requestId, queryServerInstance, t);
+      }
     }
   }
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
index 140851f666..4e5a003427 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.service.server;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import io.grpc.Deadline;
 import io.grpc.ManagedChannel;
@@ -30,6 +29,7 @@ import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.TimeUnit;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
+import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.query.QueryEnvironment;
@@ -37,7 +37,9 @@ import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
+import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.QueryRunner;
@@ -228,15 +230,24 @@ public class QueryServerTest extends QueryTestSet {
   }
 
   private Worker.QueryRequest getQueryRequest(DispatchableSubPlan 
dispatchableSubPlan, int stageId) {
-    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
-        
dispatchableSubPlan.getQueryStageList().get(stageId).getServerInstanceToWorkerIdMap();
+    DispatchablePlanFragment planFragment = 
dispatchableSubPlan.getQueryStageList().get(stageId);
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = 
planFragment.getServerInstanceToWorkerIdMap();
     // this particular test set requires the request to have a single 
QueryServerInstance to dispatch to
     // as it is not testing the multi-tenancy dispatch (which is in the 
QueryDispatcherTest)
-    QueryServerInstance serverInstance = 
serverInstanceToWorkerIdMap.keySet().iterator().next();
-    int workerId = serverInstanceToWorkerIdMap.get(serverInstance).get(0);
+    Map.Entry<QueryServerInstance, List<Integer>> entry = 
serverInstanceToWorkerIdMap.entrySet().iterator().next();
+    QueryServerInstance serverInstance = entry.getKey();
+    List<Integer> workerIds = entry.getValue();
+    Plan.StageNode stageRoot =
+        StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) 
planFragment.getPlanFragment().getFragmentRoot());
+    List<Worker.WorkerMetadata> protoWorkerMetadataList = 
QueryPlanSerDeUtils.toProtoWorkerMetadataList(planFragment);
+    Worker.StageMetadata stageMetadata =
+        QueryPlanSerDeUtils.toProtoStageMetadata(protoWorkerMetadataList, 
planFragment.getCustomProperties(),
+            serverInstance, workerIds);
+    Worker.StagePlan stagePlan =
+        
Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageRoot).setStageMetadata(stageMetadata)
+            .build();
 
-    return Worker.QueryRequest.newBuilder().addStagePlan(
-            QueryPlanSerDeUtils.serialize(dispatchableSubPlan, stageId, 
serverInstance, ImmutableList.of(workerId)))
+    return Worker.QueryRequest.newBuilder().addStagePlan(stagePlan)
         // the default configurations that must exist.
         .putMetadata(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID,
             String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to