This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f1966d9fa0 [Multi-stage] Support partition based colocated join (#10886)
f1966d9fa0 is described below

commit f1966d9fa01040774ff2e153ab69f9c69903fcbe
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Thu Jun 15 12:28:28 2023 -0700

    [Multi-stage] Support partition based colocated join (#10886)
---
 .../pinot/broker/routing/BrokerRoutingManager.java |   6 +-
 .../SegmentPartitionMetadataManager.java           |  16 +-
 .../SegmentPartitionMetadataManagerTest.java       |   1 +
 .../apache/pinot/core/routing/RoutingManager.java  |   9 +
 .../pinot/core/routing}/TablePartitionInfo.java    |  11 +-
 .../rel/rules/PinotJoinExchangeNodeInsertRule.java |  10 +-
 .../planner/logical/RelToPlanNodeConverter.java    |  15 +-
 .../planner/physical/MailboxAssignmentVisitor.java | 104 ++---
 .../planner/physical/PinotDispatchPlanner.java     |  36 +-
 .../pinot/query/planner/plannode/JoinNode.java     |   9 +-
 .../apache/pinot/query/routing/WorkerManager.java  | 453 ++++++++++++++++-----
 .../query/testutils/MockRoutingManagerFactory.java |   8 +
 .../operator/exchange/SingletonExchange.java       |  26 +-
 .../runtime/operator/HashJoinOperatorTest.java     |  32 +-
 .../operator/exchange/SingletonExchangeTest.java   |  43 +-
 .../pinot/tools/ColocatedJoinEngineQuickStart.java |   5 +-
 .../userAttributes_offline_table_config.json       |  44 +-
 .../userGroups_offline_table_config.json           |  11 +-
 18 files changed, 536 insertions(+), 303 deletions(-)

diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
index eae1977a28..5ea2ea5373 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
@@ -46,7 +46,6 @@ import 
org.apache.pinot.broker.routing.instanceselector.InstanceSelectorFactory;
 import 
org.apache.pinot.broker.routing.segmentmetadata.SegmentZkMetadataFetchListener;
 import 
org.apache.pinot.broker.routing.segmentmetadata.SegmentZkMetadataFetcher;
 import 
org.apache.pinot.broker.routing.segmentpartition.SegmentPartitionMetadataManager;
-import org.apache.pinot.broker.routing.segmentpartition.TablePartitionInfo;
 import org.apache.pinot.broker.routing.segmentpreselector.SegmentPreSelector;
 import 
org.apache.pinot.broker.routing.segmentpreselector.SegmentPreSelectorFactory;
 import org.apache.pinot.broker.routing.segmentpruner.SegmentPruner;
@@ -63,6 +62,7 @@ import org.apache.pinot.common.utils.config.TagNameUtils;
 import org.apache.pinot.common.utils.helix.HelixHelper;
 import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.routing.RoutingTable;
+import org.apache.pinot.core.routing.TablePartitionInfo;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.core.transport.ServerInstance;
 import 
org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager;
@@ -651,8 +651,7 @@ public class BrokerRoutingManager implements 
RoutingManager, ClusterChangeHandle
 
   @Override
   public Map<String, ServerInstance> getEnabledServersForTableTenant(String 
tableNameWithType) {
-    return _tableTenantServersMap.containsKey(tableNameWithType) ? 
_tableTenantServersMap.get(tableNameWithType)
-        : new HashMap<String, ServerInstance>();
+    return _tableTenantServersMap.getOrDefault(tableNameWithType, 
Collections.emptyMap());
   }
 
   private String getIdealStatePath(String tableNameWithType) {
@@ -680,6 +679,7 @@ public class BrokerRoutingManager implements 
RoutingManager, ClusterChangeHandle
   }
 
   @Nullable
+  @Override
   public TablePartitionInfo getTablePartitionInfo(String tableNameWithType) {
     RoutingEntry routingEntry = _routingEntryMap.get(tableNameWithType);
     if (routingEntry == null) {
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManager.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManager.java
index 199cecaedd..0ed89225f3 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManager.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManager.java
@@ -30,6 +30,8 @@ import org.apache.helix.model.ExternalView;
 import org.apache.helix.model.IdealState;
 import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import 
org.apache.pinot.broker.routing.segmentmetadata.SegmentZkMetadataFetchListener;
+import org.apache.pinot.core.routing.TablePartitionInfo;
+import org.apache.pinot.core.routing.TablePartitionInfo.PartitionInfo;
 import org.apache.pinot.segment.spi.partition.PartitionFunction;
 import 
org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel;
 import org.slf4j.Logger;
@@ -120,7 +122,7 @@ public class SegmentPartitionMetadataManager implements 
SegmentZkMetadataFetchLi
   }
 
   private void computeTablePartitionInfo() {
-    TablePartitionInfo.PartitionInfo[] partitionInfoMap = new 
TablePartitionInfo.PartitionInfo[_numPartitions];
+    PartitionInfo[] partitionInfoMap = new PartitionInfo[_numPartitions];
     Set<String> segmentsWithInvalidPartition = new HashSet<>();
     for (Map.Entry<String, SegmentInfo> entry : _segmentInfoMap.entrySet()) {
       String segment = entry.getKey();
@@ -131,16 +133,16 @@ public class SegmentPartitionMetadataManager implements 
SegmentZkMetadataFetchLi
         segmentsWithInvalidPartition.add(segment);
         continue;
       }
-      TablePartitionInfo.PartitionInfo partitionInfo = 
partitionInfoMap[partitionId];
+      PartitionInfo partitionInfo = partitionInfoMap[partitionId];
       if (partitionInfo == null) {
-        partitionInfo = new TablePartitionInfo.PartitionInfo();
-        partitionInfo._segments = new ArrayList<>();
-        partitionInfo._segments.add(segment);
-        partitionInfo._fullyReplicatedServers = new HashSet<>(onlineServers);
+        Set<String> fullyReplicatedServers = new HashSet<>(onlineServers);
+        List<String> segments = new ArrayList<>();
+        segments.add(segment);
+        partitionInfo = new PartitionInfo(fullyReplicatedServers, segments);
         partitionInfoMap[partitionId] = partitionInfo;
       } else {
-        partitionInfo._segments.add(segment);
         partitionInfo._fullyReplicatedServers.retainAll(onlineServers);
+        partitionInfo._segments.add(segment);
       }
     }
     if (!segmentsWithInvalidPartition.isEmpty()) {
diff --git 
a/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManagerTest.java
 
b/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManagerTest.java
index 5cc80bed60..6d5999707b 100644
--- 
a/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManagerTest.java
+++ 
b/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpartition/SegmentPartitionMetadataManagerTest.java
@@ -37,6 +37,7 @@ import org.apache.pinot.common.metadata.ZKMetadataProvider;
 import org.apache.pinot.common.metadata.segment.SegmentPartitionMetadata;
 import org.apache.pinot.common.metadata.segment.SegmentZKMetadata;
 import org.apache.pinot.controller.helix.ControllerTest;
+import org.apache.pinot.core.routing.TablePartitionInfo;
 import org.apache.pinot.segment.spi.partition.metadata.ColumnPartitionMetadata;
 import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
 import org.apache.pinot.spi.config.table.SegmentPartitionConfig;
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java 
b/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java
index 2d6ad0a8ac..5232addfca 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.core.routing;
 
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.spi.annotations.InterfaceAudience;
@@ -50,6 +51,7 @@ public interface RoutingManager {
    * @param brokerRequest the broker request constructed from a query.
    * @return the route table.
    */
+  @Nullable
   RoutingTable getRoutingTable(BrokerRequest brokerRequest, long requestId);
 
   /**
@@ -66,8 +68,15 @@ public interface RoutingManager {
    * @param offlineTableName offline table name
    * @return time boundary info.
    */
+  @Nullable
   TimeBoundaryInfo getTimeBoundaryInfo(String offlineTableName);
 
+  /**
+   * Returns the {@link TablePartitionInfo} for a given table.
+   */
+  @Nullable
+  TablePartitionInfo getTablePartitionInfo(String tableNameWithType);
+
   /**
    * Returns all enabled server instances for a given table's server tenant.
    *
diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/TablePartitionInfo.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/routing/TablePartitionInfo.java
similarity index 87%
rename from 
pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/TablePartitionInfo.java
rename to 
pinot-core/src/main/java/org/apache/pinot/core/routing/TablePartitionInfo.java
index c2d95f108a..1faef75c77 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpartition/TablePartitionInfo.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/routing/TablePartitionInfo.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.broker.routing.segmentpartition;
+package org.apache.pinot.core.routing;
 
 import java.util.List;
 import java.util.Set;
@@ -65,7 +65,12 @@ public class TablePartitionInfo {
   }
 
   public static class PartitionInfo {
-    List<String> _segments;
-    Set<String> _fullyReplicatedServers;
+    public final Set<String> _fullyReplicatedServers;
+    public final List<String> _segments;
+
+    public PartitionInfo(Set<String> fullyReplicatedServers, List<String> 
segments) {
+      _fullyReplicatedServers = fullyReplicatedServers;
+      _segments = segments;
+    }
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
index 3dcad6998f..692a02d1e9 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
@@ -65,8 +65,9 @@ public class PinotJoinExchangeNodeInsertRule extends 
RelOptRule {
     RelNode rightExchange;
     JoinInfo joinInfo = join.analyzeCondition();
 
-    boolean isColocatedJoin = 
PinotHintStrategyTable.containsHintOption(join.getHints(),
-        PinotHintOptions.JOIN_HINT_OPTIONS, 
PinotHintOptions.JoinHintOptions.IS_COLOCATED_BY_JOIN_KEYS);
+    boolean isColocatedJoin =
+        PinotHintStrategyTable.containsHintOption(join.getHints(), 
PinotHintOptions.JOIN_HINT_OPTIONS,
+            PinotHintOptions.JoinHintOptions.IS_COLOCATED_BY_JOIN_KEYS);
     if (isColocatedJoin) {
       // join exchange are colocated, we should directly pass through via join 
key
       leftExchange = PinotLogicalExchange.create(leftInput, 
RelDistributions.SINGLETON);
@@ -82,10 +83,9 @@ public class PinotJoinExchangeNodeInsertRule extends 
RelOptRule {
     }
 
     RelNode newJoinNode =
-        new LogicalJoin(join.getCluster(), join.getTraitSet(), leftExchange, 
rightExchange, join.getCondition(),
-            join.getVariablesSet(), join.getJoinType(), join.isSemiJoinDone(),
+        new LogicalJoin(join.getCluster(), join.getTraitSet(), 
join.getHints(), leftExchange, rightExchange,
+            join.getCondition(), join.getVariablesSet(), join.getJoinType(), 
join.isSemiJoinDone(),
             ImmutableList.copyOf(join.getSystemFieldList()));
-
     call.transformTo(newJoinNode);
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
index 63da99df08..6b1dc25642 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
@@ -31,6 +31,8 @@ import org.apache.calcite.rel.core.JoinInfo;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.SetOp;
 import org.apache.calcite.rel.core.SortExchange;
+import org.apache.calcite.rel.hint.PinotHintOptions;
+import org.apache.calcite.rel.hint.PinotHintStrategyTable;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalJoin;
@@ -178,12 +180,15 @@ public final class RelToPlanNodeConverter {
 
     // Parse out all equality JOIN conditions
     JoinInfo joinInfo = node.analyzeCondition();
-    FieldSelectionKeySelector leftFieldSelectionKeySelector = new 
FieldSelectionKeySelector(joinInfo.leftKeys);
-    FieldSelectionKeySelector rightFieldSelectionKeySelector = new 
FieldSelectionKeySelector(joinInfo.rightKeys);
+    JoinNode.JoinKeys joinKeys = new JoinNode.JoinKeys(new 
FieldSelectionKeySelector(joinInfo.leftKeys),
+        new FieldSelectionKeySelector(joinInfo.rightKeys));
+    List<RexExpression> joinClause =
+        
joinInfo.nonEquiConditions.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
+    boolean isColocatedJoin =
+        PinotHintStrategyTable.containsHintOption(node.getHints(), 
PinotHintOptions.JOIN_HINT_OPTIONS,
+            PinotHintOptions.JoinHintOptions.IS_COLOCATED_BY_JOIN_KEYS);
     return new JoinNode(currentStageId, toDataSchema(node.getRowType()), 
toDataSchema(node.getLeft().getRowType()),
-        toDataSchema(node.getRight().getRowType()), joinType,
-        new JoinNode.JoinKeys(leftFieldSelectionKeySelector, 
rightFieldSelectionKeySelector),
-        
joinInfo.nonEquiConditions.stream().map(RexExpression::toRexExpression).collect(Collectors.toList()));
+        toDataSchema(node.getRight().getRowType()), joinType, joinKeys, 
joinClause, isColocatedJoin);
   }
 
   private static DataSchema toDataSchema(RelDataType rowType) {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
index 5b3040fd35..180f5a413a 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
@@ -18,11 +18,12 @@
  */
 package org.apache.pinot.query.planner.physical;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.apache.calcite.rel.RelDistribution;
 import 
org.apache.pinot.query.planner.plannode.DefaultPostOrderTraversalVisitor;
-import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
 import org.apache.pinot.query.planner.plannode.MailboxSendNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
 import org.apache.pinot.query.routing.MailboxMetadata;
@@ -35,57 +36,58 @@ public class MailboxAssignmentVisitor extends 
DefaultPostOrderTraversalVisitor<V
 
   @Override
   public Void process(PlanNode node, DispatchablePlanContext context) {
-    if (node instanceof MailboxSendNode || node instanceof MailboxReceiveNode) 
{
-      int receiverStageId =
-          isMailboxReceiveNode(node) ? node.getPlanFragmentId() : 
((MailboxSendNode) node).getReceiverStageId();
-      int senderStageId =
-          isMailboxReceiveNode(node) ? ((MailboxReceiveNode) 
node).getSenderStageId() : node.getPlanFragmentId();
-      DispatchablePlanMetadata receiverStagePlanMetadata =
-          context.getDispatchablePlanMetadataMap().get(receiverStageId);
-      DispatchablePlanMetadata senderStagePlanMetadata = 
context.getDispatchablePlanMetadataMap().get(senderStageId);
-      
receiverStagePlanMetadata.getServerInstanceToWorkerIdMap().entrySet().stream().forEach(receiverEntry
 -> {
-        QueryServerInstance receiverServerInstance = receiverEntry.getKey();
-        List<Integer> receiverWorkerIds = receiverEntry.getValue();
-        for (int receiverWorkerId : receiverWorkerIds) {
-          
receiverStagePlanMetadata.getWorkerIdToMailBoxIdsMap().putIfAbsent(receiverWorkerId,
 new HashMap<>());
-          
senderStagePlanMetadata.getServerInstanceToWorkerIdMap().entrySet().stream().forEach(senderEntry
 -> {
-            QueryServerInstance senderServerInstance = senderEntry.getKey();
-            List<Integer> senderWorkerIds = senderEntry.getValue();
-            for (int senderWorkerId : senderWorkerIds) {
-              MailboxMetadata mailboxMetadata =
-                  isMailboxReceiveNode(node)
-                      ? getMailboxMetadata(receiverStagePlanMetadata, 
senderStageId, receiverWorkerId)
-                      : getMailboxMetadata(senderStagePlanMetadata, 
receiverStageId, senderWorkerId);
-              mailboxMetadata.getMailBoxIdList().add(
-                  MailboxIdUtils.toPlanMailboxId(senderStageId, 
senderWorkerId, receiverStageId, receiverWorkerId));
-              VirtualServerAddress virtualServerAddress =
-                  isMailboxReceiveNode(node)
-                      ? new VirtualServerAddress(senderServerInstance, 
senderWorkerId)
-                      : new VirtualServerAddress(receiverServerInstance, 
receiverWorkerId);
-              
mailboxMetadata.getVirtualAddressList().add(virtualServerAddress);
-            }
-          });
-        }
-      });
-    }
-    return null;
-  }
+    if (node instanceof MailboxSendNode) {
+      MailboxSendNode sendNode = (MailboxSendNode) node;
+      int senderFragmentId = sendNode.getPlanFragmentId();
+      int receiverFragmentId = sendNode.getReceiverStageId();
+      Map<Integer, DispatchablePlanMetadata> metadataMap = 
context.getDispatchablePlanMetadataMap();
+      DispatchablePlanMetadata senderMetadata = 
metadataMap.get(senderFragmentId);
+      DispatchablePlanMetadata receiverMetadata = 
metadataMap.get(receiverFragmentId);
+      Map<QueryServerInstance, List<Integer>> senderWorkerIdsMap = 
senderMetadata.getServerInstanceToWorkerIdMap();
+      Map<QueryServerInstance, List<Integer>> receiverWorkerIdsMap = 
receiverMetadata.getServerInstanceToWorkerIdMap();
+      Map<Integer, Map<Integer, MailboxMetadata>> senderMailboxesMap = 
senderMetadata.getWorkerIdToMailBoxIdsMap();
+      Map<Integer, Map<Integer, MailboxMetadata>> receiverMailboxesMap = 
receiverMetadata.getWorkerIdToMailBoxIdsMap();
 
-  private static boolean isMailboxReceiveNode(PlanNode node) {
-    return node instanceof MailboxReceiveNode;
-  }
-
-  private MailboxMetadata getMailboxMetadata(DispatchablePlanMetadata 
dispatchablePlanMetadata, int planFragmentId,
-      int workerId) {
-    Map<Integer, Map<Integer, MailboxMetadata>> workerIdToMailBoxIdsMap =
-        dispatchablePlanMetadata.getWorkerIdToMailBoxIdsMap();
-    if (!workerIdToMailBoxIdsMap.containsKey(workerId)) {
-      workerIdToMailBoxIdsMap.put(workerId, new HashMap<>());
-    }
-    Map<Integer, MailboxMetadata> planFragmentToMailboxMetadataMap = 
workerIdToMailBoxIdsMap.get(workerId);
-    if (!planFragmentToMailboxMetadataMap.containsKey(planFragmentId)) {
-      planFragmentToMailboxMetadataMap.put(planFragmentId, new 
MailboxMetadata());
+      if (sendNode.getDistributionType() == RelDistribution.Type.SINGLETON) {
+        // For SINGLETON exchange type, send the data to the same instance 
(same worker id)
+        senderWorkerIdsMap.forEach((serverInstance, workerIds) -> {
+          for (int workerId : workerIds) {
+            MailboxMetadata mailboxMetadata = new 
MailboxMetadata(Collections.singletonList(
+                MailboxIdUtils.toPlanMailboxId(senderFragmentId, workerId, 
receiverFragmentId, workerId)),
+                Collections.singletonList(new 
VirtualServerAddress(serverInstance, workerId)), Collections.emptyMap());
+            senderMailboxesMap.computeIfAbsent(workerId, k -> new 
HashMap<>()).put(receiverFragmentId, mailboxMetadata);
+            receiverMailboxesMap.computeIfAbsent(workerId, k -> new 
HashMap<>()).put(senderFragmentId, mailboxMetadata);
+          }
+        });
+      } else {
+        // For other exchange types, send the data to all the instances in the 
receiver fragment
+        // TODO: Add support for more exchange types
+        senderWorkerIdsMap.forEach((senderServerInstance, senderWorkerIds) -> {
+          for (int senderWorkerId : senderWorkerIds) {
+            Map<Integer, MailboxMetadata> senderMailboxMetadataMap =
+                senderMailboxesMap.computeIfAbsent(senderWorkerId, k -> new 
HashMap<>());
+            receiverWorkerIdsMap.forEach((receiverServerInstance, 
receiverWorkerIds) -> {
+              for (int receiverWorkerId : receiverWorkerIds) {
+                Map<Integer, MailboxMetadata> receiverMailboxMetadataMap =
+                    receiverMailboxesMap.computeIfAbsent(receiverWorkerId, k 
-> new HashMap<>());
+                String mailboxId = 
MailboxIdUtils.toPlanMailboxId(senderFragmentId, senderWorkerId, 
receiverFragmentId,
+                    receiverWorkerId);
+                MailboxMetadata senderMailboxMetadata =
+                    
senderMailboxMetadataMap.computeIfAbsent(receiverFragmentId, k -> new 
MailboxMetadata());
+                senderMailboxMetadata.getMailBoxIdList().add(mailboxId);
+                senderMailboxMetadata.getVirtualAddressList()
+                    .add(new VirtualServerAddress(receiverServerInstance, 
receiverWorkerId));
+                MailboxMetadata receiverMailboxMetadata =
+                    
receiverMailboxMetadataMap.computeIfAbsent(senderFragmentId, k -> new 
MailboxMetadata());
+                receiverMailboxMetadata.getMailBoxIdList().add(mailboxId);
+                receiverMailboxMetadata.getVirtualAddressList()
+                    .add(new VirtualServerAddress(senderServerInstance, 
senderWorkerId));
+              }
+            });
+          }
+        });
+      }
     }
-    return planFragmentToMailboxMetadataMap.get(planFragmentId);
+    return null;
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
index bded25f200..521a99f39c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
@@ -50,28 +50,24 @@ public class PinotDispatchPlanner {
    */
   public DispatchableSubPlan createDispatchableSubPlan(SubPlan subPlan) {
     // perform physical plan conversion and assign workers to each stage.
-    DispatchablePlanContext dispatchablePlanContext = new 
DispatchablePlanContext(_workerManager, _requestId,
-        _plannerContext, subPlan.getSubPlanMetadata().getFields(), 
subPlan.getSubPlanMetadata().getTableNames());
-    PlanNode subPlanRoot = subPlan.getSubPlanRoot().getFragmentRoot();
+    DispatchablePlanContext context = new 
DispatchablePlanContext(_workerManager, _requestId, _plannerContext,
+        subPlan.getSubPlanMetadata().getFields(), 
subPlan.getSubPlanMetadata().getTableNames());
+    PlanFragment rootFragment = subPlan.getSubPlanRoot();
+    PlanNode rootNode = rootFragment.getFragmentRoot();
     // 1. start by visiting the sub plan fragment root.
-    subPlanRoot.visit(DispatchablePlanVisitor.INSTANCE, 
dispatchablePlanContext);
+    rootNode.visit(DispatchablePlanVisitor.INSTANCE, context);
     // 2. add a special stage for the global mailbox receive, this runs on the 
dispatcher.
-    dispatchablePlanContext.getDispatchablePlanStageRootMap().put(0, 
subPlanRoot);
+    context.getDispatchablePlanStageRootMap().put(0, rootNode);
     // 3. add worker assignment after the dispatchable plan context is 
fulfilled after the visit.
-    computeWorkerAssignment(subPlan.getSubPlanRoot(), dispatchablePlanContext);
+    context.getWorkerManager().assignWorkers(rootFragment, context);
     // 4. compute the mailbox assignment for each stage.
     // TODO: refactor this to be a pluggable interface.
-    computeMailboxAssignment(dispatchablePlanContext);
+    rootNode.visit(MailboxAssignmentVisitor.INSTANCE, context);
     // 5. Run physical optimizations
-    runPhysicalOptimizers(subPlanRoot, dispatchablePlanContext, _tableCache);
+    runPhysicalOptimizers(rootNode, context, _tableCache);
     // 6. convert it into query plan.
     // TODO: refactor this to be a pluggable interface.
-    return finalizeDispatchableSubPlan(subPlan.getSubPlanRoot(), 
dispatchablePlanContext);
-  }
-
-  private void computeMailboxAssignment(DispatchablePlanContext 
dispatchablePlanContext) {
-    
dispatchablePlanContext.getDispatchablePlanStageRootMap().get(0).visit(MailboxAssignmentVisitor.INSTANCE,
-        dispatchablePlanContext);
+    return finalizeDispatchableSubPlan(rootFragment, context);
   }
 
   // TODO: Switch to Worker SPI to avoid multiple-places where workers are 
assigned.
@@ -90,16 +86,4 @@ public class PinotDispatchPlanner {
         
dispatchablePlanContext.constructDispatchablePlanFragmentList(subPlanRoot),
         dispatchablePlanContext.getTableNames());
   }
-
-  private static void computeWorkerAssignment(PlanFragment planFragment, 
DispatchablePlanContext context) {
-    computeWorkerAssignment(planFragment.getFragmentRoot(), context);
-    planFragment.getChildren().forEach(child -> computeWorkerAssignment(child, 
context));
-  }
-
-  private static void computeWorkerAssignment(PlanNode node, 
DispatchablePlanContext context) {
-    int planFragmentId = node.getPlanFragmentId();
-    context.getWorkerManager()
-        .assignWorkerToStage(planFragmentId, 
context.getDispatchablePlanMetadataMap().get(planFragmentId),
-            context.getRequestId(), context.getPlannerContext().getOptions(), 
context.getTableNames());
-  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/JoinNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/JoinNode.java
index 6d089c6239..b0f576258c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/JoinNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/JoinNode.java
@@ -36,6 +36,8 @@ public class JoinNode extends AbstractPlanNode {
   @ProtoProperties
   private List<RexExpression> _joinClause;
   @ProtoProperties
+  private boolean _isColocatedJoin;
+  @ProtoProperties
   private List<String> _leftColumnNames;
   @ProtoProperties
   private List<String> _rightColumnNames;
@@ -45,13 +47,14 @@ public class JoinNode extends AbstractPlanNode {
   }
 
   public JoinNode(int planFragmentId, DataSchema dataSchema, DataSchema 
leftSchema, DataSchema rightSchema,
-      JoinRelType joinRelType, JoinKeys joinKeys, List<RexExpression> 
joinClause) {
+      JoinRelType joinRelType, JoinKeys joinKeys, List<RexExpression> 
joinClause, boolean isColocatedJoin) {
     super(planFragmentId, dataSchema);
     _leftColumnNames = Arrays.asList(leftSchema.getColumnNames());
     _rightColumnNames = Arrays.asList(rightSchema.getColumnNames());
     _joinRelType = joinRelType;
     _joinKeys = joinKeys;
     _joinClause = joinClause;
+    _isColocatedJoin = isColocatedJoin;
   }
 
   public JoinRelType getJoinRelType() {
@@ -66,6 +69,10 @@ public class JoinNode extends AbstractPlanNode {
     return _joinClause;
   }
 
+  public boolean isColocatedJoin() {
+    return _isColocatedJoin;
+  }
+
   public List<String> getLeftColumnNames() {
     return _leftColumnNames;
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
index 7ebc6e96d8..76e596a110 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
@@ -27,17 +27,25 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
-import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.routing.RoutingTable;
+import org.apache.pinot.core.routing.TablePartitionInfo;
+import org.apache.pinot.core.routing.TablePartitionInfo.PartitionInfo;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.query.planner.PlannerUtils;
+import org.apache.pinot.query.planner.PlanFragment;
+import org.apache.pinot.query.planner.physical.DispatchablePlanContext;
 import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
+import org.apache.pinot.query.planner.plannode.JoinNode;
+import org.apache.pinot.query.planner.plannode.PlanNode;
 import org.apache.pinot.spi.config.table.TableType;
-import org.apache.pinot.spi.utils.CommonConstants;
+import 
org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
 import org.apache.pinot.sql.parsers.CalciteSqlCompiler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -45,11 +53,9 @@ import org.apache.pinot.sql.parsers.CalciteSqlCompiler;
  *
  * <p>It contains the logic to assign worker to a particular stages. If it is 
a leaf stage the logic fallback to
  * how Pinot server assigned server and server-segment mapping.
- *
- * TODO: Currently it is implemented by wrapping routing manager from Pinot 
Broker. however we can abstract out
- * the worker manager later when we split out the query-spi layer.
  */
 public class WorkerManager {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(WorkerManager.class);
   private static final Random RANDOM = new Random();
 
   private final String _hostName;
@@ -62,42 +68,48 @@ public class WorkerManager {
     _routingManager = routingManager;
   }
 
-  public void assignWorkerToStage(int planFragmentId, DispatchablePlanMetadata 
dispatchablePlanMetadata, long requestId,
-      Map<String, String> options, Set<String> tableNames) {
-    if (PlannerUtils.isRootPlanFragment(planFragmentId)) {
-      // --- ROOT STAGE / BROKER REDUCE STAGE ---
-      // ROOT stage doesn't have a QueryServer as it is strictly only reducing 
results.
-      // here we simply assign the worker instance with identical 
server/mailbox port number.
-      
dispatchablePlanMetadata.setServerInstanceToWorkerIdMap(Collections.singletonMap(
-          new QueryServerInstance(_hostName, _port, _port), 
Collections.singletonList(0)));
-      dispatchablePlanMetadata.setTotalWorkerCount(1);
-    } else if (isLeafStage(dispatchablePlanMetadata)) {
-      // --- LEAF STAGE ---
-      assignWorkerToLeafStage(requestId, dispatchablePlanMetadata);
+  public void assignWorkers(PlanFragment rootFragment, DispatchablePlanContext 
context) {
+    // ROOT stage doesn't have a QueryServer as it is strictly only reducing 
results, so here we simply assign the
+    // worker instance with identical server/mailbox port number.
+    DispatchablePlanMetadata metadata = 
context.getDispatchablePlanMetadataMap().get(0);
+    metadata.setServerInstanceToWorkerIdMap(
+        Collections.singletonMap(new QueryServerInstance(_hostName, _port, 
_port), Collections.singletonList(0)));
+    metadata.setTotalWorkerCount(1);
+    for (PlanFragment child : rootFragment.getChildren()) {
+      assignWorkersToNonRootFragment(child, context);
+    }
+  }
+
+  private void assignWorkersToNonRootFragment(PlanFragment fragment, 
DispatchablePlanContext context) {
+    if 
(isLeafPlan(context.getDispatchablePlanMetadataMap().get(fragment.getFragmentId())))
 {
+      assignWorkersToLeafFragment(fragment, context);
     } else {
-      // --- INTERMEDIATE STAGES ---
-      // If the query has more than one table, it is possible that the tables 
could be hosted on different tenants.
-      // The intermediate stage will be processed on servers randomly picked 
from the tenants belonging to either or
-      // all of the tables in the query.
-      // TODO: actually make assignment strategy decisions for intermediate 
stages
-      assignWorkerToIntermediateStage(dispatchablePlanMetadata, tableNames, 
options);
+      assignWorkersToIntermediateFragment(fragment, context);
     }
   }
 
-  private void assignWorkerToLeafStage(long requestId, 
DispatchablePlanMetadata dispatchablePlanMetadata) {
+  // TODO: Find a better way to determine whether a stage is leaf stage or 
intermediary. We could have query plans that
+  //       process table data even in intermediary stages.
+  private static boolean isLeafPlan(DispatchablePlanMetadata metadata) {
+    return metadata.getScannedTables().size() == 1;
+  }
+
+  private void assignWorkersToLeafFragment(PlanFragment fragment, 
DispatchablePlanContext context) {
+    DispatchablePlanMetadata metadata = 
context.getDispatchablePlanMetadataMap().get(fragment.getFragmentId());
     // table scan stage, need to attach server as well as segment info for 
each physical table type.
-    List<String> scannedTables = dispatchablePlanMetadata.getScannedTables();
+    List<String> scannedTables = metadata.getScannedTables();
     String logicalTableName = scannedTables.get(0);
-    Map<String, RoutingTable> routingTableMap = 
getRoutingTable(logicalTableName, requestId);
+    Map<String, RoutingTable> routingTableMap = 
getRoutingTable(logicalTableName, context.getRequestId());
     if (routingTableMap.size() == 0) {
       throw new IllegalArgumentException("Unable to find routing entries for 
table: " + logicalTableName);
     }
     // acquire time boundary info if it is a hybrid table.
     if (routingTableMap.size() > 1) {
-      TimeBoundaryInfo timeBoundaryInfo = 
_routingManager.getTimeBoundaryInfo(TableNameBuilder
-          
.forType(TableType.OFFLINE).tableNameWithType(TableNameBuilder.extractRawTableName(logicalTableName)));
+      TimeBoundaryInfo timeBoundaryInfo = _routingManager.getTimeBoundaryInfo(
+          TableNameBuilder.forType(TableType.OFFLINE)
+              
.tableNameWithType(TableNameBuilder.extractRawTableName(logicalTableName)));
       if (timeBoundaryInfo != null) {
-        dispatchablePlanMetadata.setTimeBoundaryInfo(timeBoundaryInfo);
+        metadata.setTimeBoundaryInfo(timeBoundaryInfo);
       } else {
         // remove offline table routing if no time boundary info is acquired.
         routingTableMap.remove(TableType.OFFLINE.name());
@@ -110,8 +122,8 @@ public class WorkerManager {
       String tableType = routingEntry.getKey();
       RoutingTable routingTable = routingEntry.getValue();
       // for each server instance, attach all table types and their associated 
segment list.
-      for (Map.Entry<ServerInstance, List<String>> serverEntry
-          : routingTable.getServerInstanceToSegmentsMap().entrySet()) {
+      for (Map.Entry<ServerInstance, List<String>> serverEntry : 
routingTable.getServerInstanceToSegmentsMap()
+          .entrySet()) {
         serverInstanceToSegmentsMap.putIfAbsent(serverEntry.getKey(), new 
HashMap<>());
         Map<String, List<String>> tableTypeToSegmentListMap = 
serverInstanceToSegmentsMap.get(serverEntry.getKey());
         Preconditions.checkState(tableTypeToSegmentListMap.put(tableType, 
serverEntry.getValue()) == null,
@@ -127,53 +139,14 @@ public class WorkerManager {
       workerIdToSegmentsMap.put(globalIdx, entry.getValue());
       globalIdx++;
     }
-    
dispatchablePlanMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
-    dispatchablePlanMetadata.setWorkerIdToSegmentsMap(workerIdToSegmentsMap);
-    dispatchablePlanMetadata.setTotalWorkerCount(globalIdx);
-  }
+    metadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
+    metadata.setWorkerIdToSegmentsMap(workerIdToSegmentsMap);
+    metadata.setTotalWorkerCount(globalIdx);
 
-  private void assignWorkerToIntermediateStage(DispatchablePlanMetadata 
dispatchablePlanMetadata,
-      Set<String> tableNames, Map<String, String> options) {
-    // If the query has more than one table, it is possible that the tables 
could be hosted on different tenants.
-    // The intermediate stage will be processed on servers randomly picked 
from the tenants belonging to either or
-    // all of the tables in the query.
-    // TODO: actually make assignment strategy decisions for intermediate 
stages
-    Set<ServerInstance> serverInstances = new HashSet<>();
-    if (tableNames.size() == 0) {
-      // This could be the case from queries that don't actually fetch values 
from the tables. In such cases the
-      // routing need not be tenant aware.
-      // Eg: SELECT 1 AS one FROM select_having_expression_test_test_having 
HAVING 1 > 2;
-      serverInstances = 
_routingManager.getEnabledServerInstanceMap().values().stream().collect(Collectors.toSet());
-    } else {
-      serverInstances = fetchServersForIntermediateStage(tableNames);
+    // NOTE: For pipeline breaker, leaf fragment can also have children
+    for (PlanFragment child : fragment.getChildren()) {
+      assignWorkersToNonRootFragment(child, context);
     }
-    assignServers(dispatchablePlanMetadata, serverInstances, 
dispatchablePlanMetadata.isRequiresSingletonInstance(),
-        options);
-  }
-
-  private static void assignServers(DispatchablePlanMetadata 
dispatchablePlanMetadata, Set<ServerInstance> servers,
-      boolean requiresSingletonInstance, Map<String, String> options) {
-    int stageParallelism = Integer.parseInt(
-        
options.getOrDefault(CommonConstants.Broker.Request.QueryOptionKey.STAGE_PARALLELISM,
 "1"));
-    List<ServerInstance> serverInstances = new ArrayList<>(servers);
-    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = new 
HashMap<>();
-    if (requiresSingletonInstance) {
-      // require singleton should return a single global worker ID with 0;
-      ServerInstance serverInstance = 
serverInstances.get(RANDOM.nextInt(serverInstances.size()));
-      serverInstanceToWorkerIdMap.put(new QueryServerInstance(serverInstance), 
Collections.singletonList(0));
-      dispatchablePlanMetadata.setTotalWorkerCount(1);
-    } else {
-      int globalIdx = 0;
-      for (ServerInstance server : servers) {
-        List<Integer> workerIdList = new ArrayList<>();
-        for (int virtualId = 0; virtualId < stageParallelism; virtualId++) {
-          workerIdList.add(globalIdx++);
-        }
-        serverInstanceToWorkerIdMap.put(new QueryServerInstance(server), 
workerIdList);
-      }
-      dispatchablePlanMetadata.setTotalWorkerCount(globalIdx);
-    }
-    
dispatchablePlanMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
   }
 
   /**
@@ -206,41 +179,321 @@ public class WorkerManager {
   }
 
   private RoutingTable getRoutingTable(String tableName, TableType tableType, 
long requestId) {
-    String tableNameWithType = 
TableNameBuilder.forType(tableType).tableNameWithType(
-        TableNameBuilder.extractRawTableName(tableName));
+    String tableNameWithType =
+        
TableNameBuilder.forType(tableType).tableNameWithType(TableNameBuilder.extractRawTableName(tableName));
     return _routingManager.getRoutingTable(
         CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + 
tableNameWithType), requestId);
   }
 
-  // TODO: Find a better way to determine whether a stage is leaf stage or 
intermediary. We could have query plans that
-  //       process table data even in intermediary stages.
-  private boolean isLeafStage(DispatchablePlanMetadata 
dispatchablePlanMetadata) {
-    return dispatchablePlanMetadata.getScannedTables().size() == 1;
+  private void assignWorkersToIntermediateFragment(PlanFragment fragment, 
DispatchablePlanContext context) {
+    if (isColocatedJoin(fragment.getFragmentRoot())) {
+      // TODO: Make it more general so that it can be used for other 
partitioned cases (e.g. group-by, window function)
+      try {
+        assignWorkersForColocatedJoin(fragment, context);
+        return;
+      } catch (Exception e) {
+        LOGGER.warn("[RequestId: {}] Caught exception while assigning workers 
for colocated join, "
+            + "falling back to regular worker assignment", 
context.getRequestId(), e);
+      }
+    }
+
+    // If the query has more than one table, it is possible that the tables 
could be hosted on different tenants.
+    // The intermediate stage will be processed on servers randomly picked 
from the tenants belonging to either or
+    // all of the tables in the query.
+    // TODO: actually make assignment strategy decisions for intermediate 
stages
+    List<ServerInstance> serverInstances;
+    Set<String> tableNames = context.getTableNames();
+    if (tableNames.size() == 0) {
+      // TODO: Short circuit it when no table needs to be scanned
+      // This could be the case from queries that don't actually fetch values 
from the tables. In such cases the
+      // routing need not be tenant aware.
+      // Eg: SELECT 1 AS one FROM select_having_expression_test_test_having 
HAVING 1 > 2;
+      serverInstances = new 
ArrayList<>(_routingManager.getEnabledServerInstanceMap().values());
+    } else {
+      serverInstances = fetchServersForIntermediateStage(tableNames);
+    }
+    DispatchablePlanMetadata metadata = 
context.getDispatchablePlanMetadataMap().get(fragment.getFragmentId());
+    Map<String, String> options = context.getPlannerContext().getOptions();
+    int stageParallelism = 
Integer.parseInt(options.getOrDefault(QueryOptionKey.STAGE_PARALLELISM, "1"));
+    if (metadata.isRequiresSingletonInstance()) {
+      // require singleton should return a single global worker ID with 0;
+      ServerInstance serverInstance = 
serverInstances.get(RANDOM.nextInt(serverInstances.size()));
+      metadata.setServerInstanceToWorkerIdMap(
+          Collections.singletonMap(new QueryServerInstance(serverInstance), 
Collections.singletonList(0)));
+      metadata.setTotalWorkerCount(1);
+    } else {
+      Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = 
new HashMap<>();
+      int nextWorkerId = 0;
+      for (ServerInstance serverInstance : serverInstances) {
+        List<Integer> workerIds = new ArrayList<>();
+        for (int i = 0; i < stageParallelism; i++) {
+          workerIds.add(nextWorkerId++);
+        }
+        serverInstanceToWorkerIdMap.put(new 
QueryServerInstance(serverInstance), workerIds);
+      }
+      metadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
+      metadata.setTotalWorkerCount(nextWorkerId);
+    }
+
+    for (PlanFragment child : fragment.getChildren()) {
+      assignWorkersToNonRootFragment(child, context);
+    }
   }
 
-  private Set<ServerInstance> fetchServersForIntermediateStage(Set<String> 
tableNames) {
-    Set<ServerInstance> serverInstances = new HashSet<>();
+  private boolean isColocatedJoin(PlanNode planNode) {
+    if (planNode instanceof JoinNode) {
+      return ((JoinNode) planNode).isColocatedJoin();
+    }
+    for (PlanNode child : planNode.getInputs()) {
+      if (isColocatedJoin(child)) {
+        return true;
+      }
+    }
+    return false;
+  }
 
-    for (String table : tableNames) {
-      String rawTableName = TableNameBuilder.extractRawTableName(table);
-      TableType tableType = TableNameBuilder.getTableTypeFromTableName(table);
-      if (tableType == null) {
-        String offlineTable = 
TableNameBuilder.forType(TableType.OFFLINE).tableNameWithType(rawTableName);
-        String realtimeTable = 
TableNameBuilder.forType(TableType.REALTIME).tableNameWithType(rawTableName);
+  private void assignWorkersForColocatedJoin(PlanFragment fragment, 
DispatchablePlanContext context) {
+    List<PlanFragment> children = fragment.getChildren();
+    Preconditions.checkArgument(children.size() == 2, "Expecting 2 children, 
find: %s", children.size());
+    PlanFragment leftFragment = children.get(0);
+    PlanFragment rightFragment = children.get(1);
+    Map<Integer, DispatchablePlanMetadata> metadataMap = 
context.getDispatchablePlanMetadataMap();
+    // TODO: Support multi-level colocated join (more than 2 tables colocated)
+    DispatchablePlanMetadata leftMetadata = 
metadataMap.get(leftFragment.getFragmentId());
+    Preconditions.checkArgument(isLeafPlan(leftMetadata), "Left side is not 
leaf");
+    DispatchablePlanMetadata rightMetadata = 
metadataMap.get(rightFragment.getFragmentId());
+    Preconditions.checkArgument(isLeafPlan(rightMetadata), "Right side is not 
leaf");
+
+    String leftTable = leftMetadata.getScannedTables().get(0);
+    String rightTable = rightMetadata.getScannedTables().get(0);
+    ColocatedTableInfo leftColocatedTableInfo = 
getColocatedTableInfo(leftTable);
+    ColocatedTableInfo rightColocatedTableInfo = 
getColocatedTableInfo(rightTable);
+    ColocatedPartitionInfo[] leftPartitionInfoMap = 
leftColocatedTableInfo._partitionInfoMap;
+    ColocatedPartitionInfo[] rightPartitionInfoMap = 
rightColocatedTableInfo._partitionInfoMap;
+    // TODO: Support colocated join when both side have different number of 
partitions (e.g. left: 8, right: 16)
+    int numPartitions = leftPartitionInfoMap.length;
+    Preconditions.checkState(numPartitions == rightPartitionInfoMap.length,
+        "Got different number of partitions in left table: %s (%s) and right 
table: %s (%s)", leftTable, numPartitions,
+        rightTable, rightPartitionInfoMap.length);
+
+    // Pick one server per partition
+    int nextWorkerId = 0;
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = new 
HashMap<>();
+    Map<Integer, Map<String, List<String>>> leftWorkerIdToSegmentsMap = new 
HashMap<>();
+    Map<Integer, Map<String, List<String>>> rightWorkerIdToSegmentsMap = new 
HashMap<>();
+    Map<String, ServerInstance> enabledServerInstanceMap = 
_routingManager.getEnabledServerInstanceMap();
+    for (int i = 0; i < numPartitions; i++) {
+      ColocatedPartitionInfo leftPartitionInfo = leftPartitionInfoMap[i];
+      ColocatedPartitionInfo rightPartitionInfo = rightPartitionInfoMap[i];
+      if (leftPartitionInfo == null && rightPartitionInfo == null) {
+        continue;
+      }
+      // TODO: Currently we don't support the case when for a partition only 
one side has segments. The reason is that
+      //       the leaf stage won't be able to directly return empty response.
+      Preconditions.checkState(leftPartitionInfo != null && rightPartitionInfo 
!= null,
+          "One side doesn't have any segment for partition: %s", i);
+      Set<String> candidates = new 
HashSet<>(leftPartitionInfo._fullyReplicatedServers);
+      candidates.retainAll(rightPartitionInfo._fullyReplicatedServers);
+      ServerInstance serverInstance = pickRandomEnabledServer(candidates, 
enabledServerInstanceMap);
+      Preconditions.checkState(serverInstance != null,
+          "Failed to find enabled fully replicated server for partition: %s in 
table: %s and %s", i, leftTable,
+          rightTable);
+      QueryServerInstance queryServerInstance = new 
QueryServerInstance(serverInstance);
+      int workerId = nextWorkerId++;
+      serverInstanceToWorkerIdMap.computeIfAbsent(queryServerInstance, k -> 
new ArrayList<>()).add(workerId);
+      leftWorkerIdToSegmentsMap.put(workerId, 
getSegmentsMap(leftPartitionInfo));
+      rightWorkerIdToSegmentsMap.put(workerId, 
getSegmentsMap(rightPartitionInfo));
+    }
+
+    DispatchablePlanMetadata joinMetadata = 
metadataMap.get(fragment.getFragmentId());
+    joinMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
+    joinMetadata.setTotalWorkerCount(nextWorkerId);
+
+    leftMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
+    leftMetadata.setWorkerIdToSegmentsMap(leftWorkerIdToSegmentsMap);
+    leftMetadata.setTotalWorkerCount(nextWorkerId);
+    leftMetadata.setTimeBoundaryInfo(leftColocatedTableInfo._timeBoundaryInfo);
 
-        // Servers in the offline table's tenant.
-        Map<String, ServerInstance> servers = 
_routingManager.getEnabledServersForTableTenant(offlineTable);
-        serverInstances.addAll(servers.values());
+    rightMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
+    rightMetadata.setWorkerIdToSegmentsMap(rightWorkerIdToSegmentsMap);
+    rightMetadata.setTotalWorkerCount(nextWorkerId);
+    
rightMetadata.setTimeBoundaryInfo(rightColocatedTableInfo._timeBoundaryInfo);
 
-        // Servers in the online table's tenant.
-        servers = 
_routingManager.getEnabledServersForTableTenant(realtimeTable);
-        serverInstances.addAll(servers.values());
+    // NOTE: For pipeline breaker, leaf fragment can also have children
+    for (PlanFragment child : leftFragment.getChildren()) {
+      assignWorkersToNonRootFragment(child, context);
+    }
+    for (PlanFragment child : rightFragment.getChildren()) {
+      assignWorkersToNonRootFragment(child, context);
+    }
+  }
+
+  private ColocatedTableInfo getColocatedTableInfo(String tableName) {
+    TableType tableType = 
TableNameBuilder.getTableTypeFromTableName(tableName);
+    if (tableType == null) {
+      String offlineTableName = 
TableNameBuilder.OFFLINE.tableNameWithType(tableName);
+      String realtimeTableName = 
TableNameBuilder.REALTIME.tableNameWithType(tableName);
+      boolean offlineRoutingExists = 
_routingManager.routingExists(offlineTableName);
+      boolean realtimeRoutingExists = 
_routingManager.routingExists(realtimeTableName);
+      Preconditions.checkState(offlineRoutingExists || realtimeRoutingExists, 
"Routing doesn't exist for table: %s",
+          tableName);
+      if (offlineRoutingExists && realtimeRoutingExists) {
+        // For hybrid table, find the common servers for each partition
+        TimeBoundaryInfo timeBoundaryInfo = 
_routingManager.getTimeBoundaryInfo(offlineTableName);
+        // Ignore OFFLINE side when time boundary info is unavailable
+        if (timeBoundaryInfo == null) {
+          return getRealtimeColocatedTableInfo(realtimeTableName);
+        }
+        PartitionInfo[] offlinePartitionInfoMap = 
getTablePartitionInfo(offlineTableName).getPartitionInfoMap();
+        PartitionInfo[] realtimePartitionInfoMap = 
getTablePartitionInfo(realtimeTableName).getPartitionInfoMap();
+        int numPartitions = offlinePartitionInfoMap.length;
+        Preconditions.checkState(numPartitions == 
realtimePartitionInfoMap.length,
+            "Got different number of partitions in OFFLINE side: %s and 
REALTIME side: %s of table: %s", numPartitions,
+            realtimePartitionInfoMap.length, tableName);
+        ColocatedPartitionInfo[] colocatedPartitionInfoMap = new 
ColocatedPartitionInfo[numPartitions];
+        for (int i = 0; i < numPartitions; i++) {
+          PartitionInfo offlinePartitionInfo = offlinePartitionInfoMap[i];
+          PartitionInfo realtimePartitionInfo = realtimePartitionInfoMap[i];
+          if (offlinePartitionInfo == null && realtimePartitionInfo == null) {
+            continue;
+          }
+          if (offlinePartitionInfo == null) {
+            colocatedPartitionInfoMap[i] =
+                new 
ColocatedPartitionInfo(realtimePartitionInfo._fullyReplicatedServers, null,
+                    realtimePartitionInfo._segments);
+            continue;
+          }
+          if (realtimePartitionInfo == null) {
+            colocatedPartitionInfoMap[i] =
+                new 
ColocatedPartitionInfo(offlinePartitionInfo._fullyReplicatedServers, 
offlinePartitionInfo._segments,
+                    null);
+            continue;
+          }
+          Set<String> fullyReplicatedServers = new 
HashSet<>(offlinePartitionInfo._fullyReplicatedServers);
+          
fullyReplicatedServers.retainAll(realtimePartitionInfo._fullyReplicatedServers);
+          Preconditions.checkState(!fullyReplicatedServers.isEmpty(),
+              "Failed to find fully replicated server for partition: %s in 
hybrid table: %s", i, tableName);
+          colocatedPartitionInfoMap[i] =
+              new ColocatedPartitionInfo(fullyReplicatedServers, 
offlinePartitionInfo._segments,
+                  realtimePartitionInfo._segments);
+        }
+        return new ColocatedTableInfo(colocatedPartitionInfoMap, 
timeBoundaryInfo);
+      } else if (offlineRoutingExists) {
+        return getOfflineColocatedTableInfo(offlineTableName);
+      } else {
+        return getRealtimeColocatedTableInfo(realtimeTableName);
+      }
+    } else {
+      if (tableType == TableType.OFFLINE) {
+        return getOfflineColocatedTableInfo(tableName);
       } else {
-        Map<String, ServerInstance> servers = 
_routingManager.getEnabledServersForTableTenant(table);
-        serverInstances.addAll(servers.values());
+        return getRealtimeColocatedTableInfo(tableName);
+      }
+    }
+  }
+
+  private TablePartitionInfo getTablePartitionInfo(String tableNameWithType) {
+    TablePartitionInfo tablePartitionInfo = 
_routingManager.getTablePartitionInfo(tableNameWithType);
+    Preconditions.checkState(tablePartitionInfo != null, "Failed to find table 
partition info for table: %s",
+        tableNameWithType);
+    
Preconditions.checkState(tablePartitionInfo.getSegmentsWithInvalidPartition().isEmpty(),
+        "Find %s segments with invalid partition for table: %s",
+        tablePartitionInfo.getSegmentsWithInvalidPartition().size(), 
tableNameWithType);
+    return tablePartitionInfo;
+  }
+
+  private ColocatedTableInfo getOfflineColocatedTableInfo(String 
offlineTableName) {
+    PartitionInfo[] partitionInfoMap = 
getTablePartitionInfo(offlineTableName).getPartitionInfoMap();
+    int numPartitions = partitionInfoMap.length;
+    ColocatedPartitionInfo[] colocatedPartitionInfoMap = new 
ColocatedPartitionInfo[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      PartitionInfo partitionInfo = partitionInfoMap[i];
+      if (partitionInfo != null) {
+        colocatedPartitionInfoMap[i] =
+            new ColocatedPartitionInfo(partitionInfo._fullyReplicatedServers, 
partitionInfo._segments, null);
       }
     }
+    return new ColocatedTableInfo(colocatedPartitionInfoMap, null);
+  }
+
+  private ColocatedTableInfo getRealtimeColocatedTableInfo(String 
realtimeTableName) {
+    PartitionInfo[] partitionInfoMap = 
getTablePartitionInfo(realtimeTableName).getPartitionInfoMap();
+    int numPartitions = partitionInfoMap.length;
+    ColocatedPartitionInfo[] colocatedPartitionInfoMap = new 
ColocatedPartitionInfo[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      PartitionInfo partitionInfo = partitionInfoMap[i];
+      if (partitionInfo != null) {
+        colocatedPartitionInfoMap[i] =
+            new ColocatedPartitionInfo(partitionInfo._fullyReplicatedServers, 
null, partitionInfo._segments);
+      }
+    }
+    return new ColocatedTableInfo(colocatedPartitionInfoMap, null);
+  }
+
+  private static class ColocatedTableInfo {
+    final ColocatedPartitionInfo[] _partitionInfoMap;
+    final TimeBoundaryInfo _timeBoundaryInfo;
+
+    ColocatedTableInfo(ColocatedPartitionInfo[] partitionInfoMap, @Nullable 
TimeBoundaryInfo timeBoundaryInfo) {
+      _partitionInfoMap = partitionInfoMap;
+      _timeBoundaryInfo = timeBoundaryInfo;
+    }
+  }
 
-    return serverInstances;
+  private static class ColocatedPartitionInfo {
+    final Set<String> _fullyReplicatedServers;
+    final List<String> _offlineSegments;
+    final List<String> _realtimeSegments;
+
+    public ColocatedPartitionInfo(Set<String> fullyReplicatedServers, 
@Nullable List<String> offlineSegments,
+        @Nullable List<String> realtimeSegments) {
+      _fullyReplicatedServers = fullyReplicatedServers;
+      _offlineSegments = offlineSegments;
+      _realtimeSegments = realtimeSegments;
+    }
+  }
+
+  @Nullable
+  private static ServerInstance pickRandomEnabledServer(Set<String> candidates,
+      Map<String, ServerInstance> enabledServerInstanceMap) {
+    if (candidates.isEmpty()) {
+      return null;
+    }
+    String[] servers = candidates.toArray(new String[0]);
+    ArrayUtils.shuffle(servers, RANDOM);
+    for (String server : servers) {
+      ServerInstance serverInstance = enabledServerInstanceMap.get(server);
+      if (serverInstance != null) {
+        return serverInstance;
+      }
+    }
+    return null;
+  }
+
+  private static Map<String, List<String>> 
getSegmentsMap(ColocatedPartitionInfo partitionInfo) {
+    Map<String, List<String>> segmentsMap = new HashMap<>();
+    if (partitionInfo._offlineSegments != null) {
+      segmentsMap.put(TableType.OFFLINE.name(), 
partitionInfo._offlineSegments);
+    }
+    if (partitionInfo._realtimeSegments != null) {
+      segmentsMap.put(TableType.REALTIME.name(), 
partitionInfo._realtimeSegments);
+    }
+    return segmentsMap;
+  }
+
+  private List<ServerInstance> fetchServersForIntermediateStage(Set<String> 
tableNames) {
+    Set<ServerInstance> serverInstances = new HashSet<>();
+    for (String tableName : tableNames) {
+      TableType tableType = 
TableNameBuilder.getTableTypeFromTableName(tableName);
+      if (tableType == null) {
+        String offlineTableName = 
TableNameBuilder.forType(TableType.OFFLINE).tableNameWithType(tableName);
+        
serverInstances.addAll(_routingManager.getEnabledServersForTableTenant(offlineTableName).values());
+        String realtimeTableName = 
TableNameBuilder.forType(TableType.REALTIME).tableNameWithType(tableName);
+        
serverInstances.addAll(_routingManager.getEnabledServersForTableTenant(realtimeTableName).values());
+      } else {
+        
serverInstances.addAll(_routingManager.getEnabledServersForTableTenant(tableName).values());
+      }
+    }
+    return new ArrayList<>(serverInstances);
   }
 }
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
index 18850dbccc..58db9c7f41 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
@@ -24,12 +24,14 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
 import org.apache.helix.model.InstanceConfig;
 import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.routing.RoutingTable;
+import org.apache.pinot.core.routing.TablePartitionInfo;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.spi.config.table.TableType;
@@ -178,6 +180,12 @@ public class MockRoutingManagerFactory {
           String.valueOf(System.currentTimeMillis() - 
TimeUnit.DAYS.toMillis(1))) : null;
     }
 
+    @Nullable
+    @Override
+    public TablePartitionInfo getTablePartitionInfo(String tableNameWithType) {
+      return null;
+    }
+
     @Override
     public Map<String, ServerInstance> getEnabledServersForTableTenant(String 
tableNameWithType) {
       return _serverInstances;
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchange.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchange.java
index b8365ebf2d..4634eb09f2 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchange.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchange.java
@@ -18,7 +18,7 @@
  */
 package org.apache.pinot.query.runtime.operator.exchange;
 
-import java.io.IOException;
+import com.google.common.base.Preconditions;
 import java.util.List;
 import java.util.function.Consumer;
 import org.apache.pinot.query.mailbox.InMemorySendingMailbox;
@@ -37,28 +37,14 @@ class SingletonExchange extends BlockExchange {
   SingletonExchange(OpChainId opChainId, List<SendingMailbox> 
sendingMailboxes, BlockSplitter splitter,
       Consumer<OpChainId> callback, long deadlineMs) {
     super(opChainId, sendingMailboxes, splitter, callback, deadlineMs);
+    Preconditions.checkArgument(
+        sendingMailboxes.size() == 1 && sendingMailboxes.get(0) instanceof 
InMemorySendingMailbox,
+        "Expect single InMemorySendingMailbox for SingletonExchange");
   }
 
   @Override
-  protected void route(List<SendingMailbox> mailbox, TransferableBlock block)
+  protected void route(List<SendingMailbox> sendingMailboxes, 
TransferableBlock block)
       throws Exception {
-    boolean isLocalExchangeSent = false;
-    for (SendingMailbox sendingMailbox : mailbox) {
-      if (isLocal(sendingMailbox)) {
-        if (!isLocalExchangeSent) {
-          sendBlock(sendingMailbox, block);
-          isLocalExchangeSent = true;
-        } else {
-          throw new IOException("Local exchange has already been sent for 
singleton exchange!");
-        }
-      }
-    }
-    if (!isLocalExchangeSent) {
-      throw new IOException("Local exchange has not been sent successfully!");
-    }
-  }
-
-  private static boolean isLocal(SendingMailbox sendingMailbox) {
-    return sendingMailbox instanceof InMemorySendingMailbox;
+    sendBlock(sendingMailboxes.get(0), block);
   }
 }
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
index cf691b2f09..fc78863778 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
@@ -94,7 +94,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, false);
     HashJoinOperator joinOnString =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -132,7 +132,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator joinOnInt =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = joinOnInt.nextBlock();
@@ -167,7 +167,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses);
+        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses, false);
     HashJoinOperator joinOnInt =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = joinOnInt.nextBlock();
@@ -209,7 +209,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.LEFT,
-        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -244,7 +244,7 @@ public class HashJoinOperatorTest {
         });
     List<RexExpression> joinClauses = new ArrayList<>();
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -276,7 +276,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.LEFT,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -312,7 +312,7 @@ public class HashJoinOperatorTest {
         });
 
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -351,7 +351,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses);
+        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = join.nextBlock();
@@ -390,7 +390,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses);
+        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = join.nextBlock();
@@ -425,7 +425,7 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.STRING
     });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.RIGHT,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator joinOnNum =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = joinOnNum.nextBlock();
@@ -475,7 +475,7 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.STRING
     });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.SEMI,
-        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = join.nextBlock();
@@ -515,7 +515,7 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.STRING
     });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.FULL,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = join.nextBlock();
@@ -568,7 +568,7 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.STRING
     });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.ANTI,
-        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
     TransferableBlock result = join.nextBlock();
@@ -607,7 +607,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -641,7 +641,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
@@ -678,7 +678,7 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
-        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses, false);
     HashJoinOperator join =
         new HashJoinOperator(OperatorTestUtil.getDefaultContext(), _leftOperator, _rightOperator, leftSchema, node);
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchangeTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchangeTest.java
index 1143d38b56..42638e4b8c 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchangeTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/exchange/SingletonExchangeTest.java
@@ -19,7 +19,6 @@
 package org.apache.pinot.query.runtime.operator.exchange;
 
 import com.google.common.collect.ImmutableList;
-import java.io.IOException;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.query.mailbox.GrpcSendingMailbox;
 import org.apache.pinot.query.mailbox.InMemorySendingMailbox;
@@ -68,8 +67,8 @@ public class SingletonExchangeTest {
     ImmutableList<SendingMailbox> destinations = ImmutableList.of(_mailbox1);
 
     // When:
-    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock,
-        (opChainId) -> { }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
+    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock, (opChainId) -> {
+    }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
 
     // Then:
     ArgumentCaptor<TransferableBlock> captor = ArgumentCaptor.forClass(TransferableBlock.class);
@@ -78,43 +77,25 @@ public class SingletonExchangeTest {
     Assert.assertEquals(captor.getValue(), _block);
   }
 
-  @Test
-  public void shouldRouteSingletonWithExtraNonLocalMailbox()
+  @Test(expectedExceptions = IllegalArgumentException.class)
+  public void shouldThrowWhenSingletonWithNonLocalMailbox()
       throws Exception {
     // Given:
-    ImmutableList<SendingMailbox> destinations = ImmutableList.of(_mailbox1, _mailbox2);
-
-    // When:
-    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock,
-        (opChainId) -> { }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
-
-    // Then:
-    ArgumentCaptor<TransferableBlock> captor = ArgumentCaptor.forClass(TransferableBlock.class);
-    // Then:
-    Mockito.verify(_mailbox1, Mockito.times(1)).send(captor.capture());
-    Mockito.verify(_mailbox2, Mockito.times(0)).send(captor.capture());
-    Assert.assertEquals(captor.getValue(), _block);
-  }
-
-  @Test(expectedExceptions = IOException.class, expectedExceptionsMessageRegExp = ".*has already been sent.*")
-  public void shouldThrowWhenSingletonWithMultipleLocalMailbox()
-      throws Exception {
-    // Given:
-    ImmutableList<SendingMailbox> destinations = ImmutableList.of(_mailbox1, _mailbox2, _mailbox3);
+    ImmutableList<SendingMailbox> destinations = ImmutableList.of(_mailbox2);
 
     // When:
-    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock,
-        (opChainId) -> { }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
+    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock, (opChainId) -> {
+    }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
   }
 
-  @Test(expectedExceptions = IOException.class, expectedExceptionsMessageRegExp = ".*has not been sent.*")
-  public void shouldThrowWhenSingletonWithNoLocalMailbox()
+  @Test(expectedExceptions = IllegalArgumentException.class)
+  public void shouldThrowWhenSingletonWithMultipleMailboxes()
       throws Exception {
     // Given:
-    ImmutableList<SendingMailbox> destinations = ImmutableList.of(_mailbox2);
+    ImmutableList<SendingMailbox> destinations = ImmutableList.of(_mailbox1, _mailbox3);
 
     // When:
-    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock,
-        (opChainId) -> { }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
+    new SingletonExchange(new OpChainId(1, 2, 3), destinations, TransferableBlockUtils::splitBlock, (opChainId) -> {
+    }, System.currentTimeMillis() + 10_000L).route(destinations, _block);
   }
 }
diff --git a/pinot-tools/src/main/java/org/apache/pinot/tools/ColocatedJoinEngineQuickStart.java b/pinot-tools/src/main/java/org/apache/pinot/tools/ColocatedJoinEngineQuickStart.java
index c59603d2fd..0241168b81 100644
--- a/pinot-tools/src/main/java/org/apache/pinot/tools/ColocatedJoinEngineQuickStart.java
+++ b/pinot-tools/src/main/java/org/apache/pinot/tools/ColocatedJoinEngineQuickStart.java
@@ -38,10 +38,7 @@ public class ColocatedJoinEngineQuickStart extends MultistageEngineQuickStart {
 
   @Override
   public String[] getDefaultBatchTableDirectories() {
-    List<String> colocatedTableDirs = new ArrayList<>(Arrays.asList(COLOCATED_JOIN_DIRECTORIES));
-    String[] multiStageTableDirs = super.getDefaultBatchTableDirectories();
-    colocatedTableDirs.addAll(Arrays.asList(multiStageTableDirs));
-    return colocatedTableDirs.toArray(new String[0]);
+    return COLOCATED_JOIN_DIRECTORIES;
   }
 
   @Override
diff --git a/pinot-tools/src/main/resources/examples/batch/colocated/userAttributes/userAttributes_offline_table_config.json b/pinot-tools/src/main/resources/examples/batch/colocated/userAttributes/userAttributes_offline_table_config.json
index 4f0599d942..6a5aecb868 100644
--- a/pinot-tools/src/main/resources/examples/batch/colocated/userAttributes/userAttributes_offline_table_config.json
+++ b/pinot-tools/src/main/resources/examples/batch/colocated/userAttributes/userAttributes_offline_table_config.json
@@ -3,49 +3,43 @@
   "tableType": "OFFLINE",
   "segmentsConfig": {
     "segmentPushType": "APPEND",
-    "segmentAssignmentStrategy": "BalanceNumSegmentAssignmentStrategy",
-    "schemaName": "userAttributes",
     "replication": 2
   },
+  "tenants": {
+    "broker": "DefaultTenant",
+    "server": "DefaultTenant"
+  },
+  "tableIndexConfig": {
+    "invertedIndexColumns": [
+      "userUUID"
+    ],
+    "segmentPartitionConfig": {
+      "columnPartitionMap": {
+        "userUUID": {
+          "functionName": "Murmur",
+          "numPartitions": 4
+        }
+      }
+    }
+  },
   "instanceAssignmentConfigMap": {
     "OFFLINE": {
       "tagPoolConfig": {
-        "tag": "DefaultTenant_OFFLINE",
-        "poolBased": false,
-        "numPools": 0
+        "tag": "DefaultTenant_OFFLINE"
       },
       "replicaGroupPartitionConfig": {
         "replicaGroupBased": true,
-        "numInstances": 0,
         "numReplicaGroups": 2,
         "numInstancesPerReplicaGroup": 2,
         "numPartitions": 2,
         "numInstancesPerPartition": 1,
-        "minimizeDataMovement": false,
         "partitionColumn": "userUUID"
-      },
-      "partitionSelector": "INSTANCE_REPLICA_GROUP_PARTITION_SELECTOR"
+      }
     }
   },
   "routing": {
     "instanceSelectorType": "multiStageReplicaGroup"
   },
-  "tenants": {
-  },
-  "tableIndexConfig": {
-    "loadMode": "HEAP",
-    "invertedIndexColumns": [
-      "userUUID"
-    ],
-    "segmentPartitionConfig": {
-      "columnPartitionMap": {
-        "userUUID": {
-          "functionName": "Murmur",
-          "numPartitions": 4
-        }
-      }
-    }
-  },
   "metadata": {
     "customConfigs": {
     }
diff --git a/pinot-tools/src/main/resources/examples/batch/colocated/userGroups/userGroups_offline_table_config.json b/pinot-tools/src/main/resources/examples/batch/colocated/userGroups/userGroups_offline_table_config.json
index 8940b06b8d..9f7d3fd045 100644
--- a/pinot-tools/src/main/resources/examples/batch/colocated/userGroups/userGroups_offline_table_config.json
+++ b/pinot-tools/src/main/resources/examples/batch/colocated/userGroups/userGroups_offline_table_config.json
@@ -3,21 +3,17 @@
   "tableType": "OFFLINE",
   "segmentsConfig": {
     "segmentPushType": "APPEND",
-    "segmentAssignmentStrategy": "BalanceNumSegmentAssignmentStrategy",
-    "schemaName": "userGroups",
     "replication": "2",
     "replicaGroupStrategyConfig": {
       "partitionColumn": "userUUID",
       "numInstancesPerPartition": 2
     }
   },
-  "instancePartitionsMap": {
-    "OFFLINE": "userAttributes_OFFLINE"
-  },
   "tenants": {
+    "broker": "DefaultTenant",
+    "server": "DefaultTenant"
   },
   "tableIndexConfig": {
-    "loadMode": "HEAP",
     "invertedIndexColumns": [
       "userUUID",
       "groupUUID"
@@ -31,6 +27,9 @@
       }
     }
   },
+  "instancePartitionsMap": {
+    "OFFLINE": "userAttributes_OFFLINE"
+  },
   "routing": {
     "instanceSelectorType": "multiStageReplicaGroup"
   },


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org

Reply via email to