This is an automated email from the ASF dual-hosted git repository.

ankitsultana pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 04c8a4f6ee [multistage] Add Leaf Stage Worker Assignment / Boundary / Agg Rules (#15481)
04c8a4f6ee is described below

commit 04c8a4f6ee80e471f7f8e6bd23048d6992eea8f0
Author: Ankit Sultana <ankitsult...@uber.com>
AuthorDate: Mon Apr 14 16:03:00 2025 -0500

    [multistage] Add Leaf Stage Worker Assignment / Boundary / Agg Rules (#15481)
---
 .../pinot/calcite/rel/traits/TraitAssignment.java  |  37 +-
 .../planner/physical/v2/HashDistributionDesc.java  |   8 +-
 .../pinot/query/planner/physical/v2/PRelNode.java  |  21 +
 .../planner/physical/v2/PinotDataDistribution.java |   8 +-
 .../planner/physical/v2/RelToPRelConverter.java    | 136 +++++++
 .../physical/v2/mapping/DistMappingGenerator.java  | 109 ++++++
 .../physical/v2/mapping/PinotDistMapping.java      |  82 ++++
 .../physical/v2/nodes/PhysicalAggregate.java       |  10 +
 .../planner/physical/v2/nodes/PhysicalFilter.java  |   9 +
 .../planner/physical/v2/nodes/PhysicalProject.java |   9 +
 .../planner/physical/v2/nodes/PhysicalSort.java    |   9 +
 .../physical/v2/nodes/PhysicalTableScan.java       |  12 +-
 .../physical/v2/opt/PhysicalOptRuleSet.java        |  42 ++
 .../planner/physical/v2/opt/RuleExecutors.java     |  43 +++
 .../v2/opt/rules/LeafStageAggregateRule.java       |  96 +++++
 .../v2/opt/rules/LeafStageBoundaryRule.java        |  90 +++++
 .../opt/rules/LeafStageWorkerAssignmentRule.java   | 424 +++++++++++++++++++++
 .../physical/v2/mapping/PinotDistMappingTest.java  | 115 ++++++
 .../rules/LeafStageWorkerAssignmentRuleTest.java   | 207 ++++++++++
 19 files changed, 1441 insertions(+), 26 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
index c6e0cacd5c..f2ba392129 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
@@ -36,6 +36,7 @@ import org.apache.calcite.rel.core.Window;
 import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
 import org.apache.pinot.calcite.rel.rules.PinotRuleUtils;
 import org.apache.pinot.query.context.PhysicalPlannerContext;
+import org.apache.pinot.query.planner.physical.v2.PRelNode;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAggregate;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalJoin;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalProject;
@@ -52,33 +53,35 @@ import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalWindow;
 public class TraitAssignment {
   private final Supplier<Integer> _planIdGenerator;
 
-  public TraitAssignment(Supplier<Integer> planIdGenerator) {
+  private TraitAssignment(Supplier<Integer> planIdGenerator) {
     _planIdGenerator = planIdGenerator;
   }
 
-  public static RelNode assign(RelNode relNode, PhysicalPlannerContext physicalPlannerContext) {
+  public static PRelNode assign(PRelNode pRelNode, PhysicalPlannerContext physicalPlannerContext) {
     TraitAssignment traitAssignment = new TraitAssignment(physicalPlannerContext.getNodeIdGenerator());
-    return traitAssignment.assign(relNode);
+    return traitAssignment.assign(pRelNode);
   }
 
-  public RelNode assign(RelNode node) {
+  @VisibleForTesting
+  PRelNode assign(PRelNode pRelNode) {
     // Process inputs first.
+    RelNode relNode = pRelNode.unwrap();
     List<RelNode> newInputs = new ArrayList<>();
-    for (RelNode input : node.getInputs()) {
-      newInputs.add(assign(input));
+    for (RelNode input : relNode.getInputs()) {
+      newInputs.add(assign((PRelNode) input).unwrap());
     }
-    node = node.copy(node.getTraitSet(), newInputs);
-    // Process current node.
-    if (node instanceof PhysicalSort) {
-      return assignSort((PhysicalSort) node);
-    } else if (node instanceof PhysicalJoin) {
-      return assignJoin((PhysicalJoin) node);
-    } else if (node instanceof PhysicalAggregate) {
-      return assignAggregate((PhysicalAggregate) node);
-    } else if (node instanceof PhysicalWindow) {
-      return assignWindow((PhysicalWindow) node);
+    relNode = relNode.copy(relNode.getTraitSet(), newInputs);
+    // Process current relNode.
+    if (relNode instanceof PhysicalSort) {
+      return (PRelNode) assignSort((PhysicalSort) relNode);
+    } else if (relNode instanceof PhysicalJoin) {
+      return (PRelNode) assignJoin((PhysicalJoin) relNode);
+    } else if (relNode instanceof PhysicalAggregate) {
+      return (PRelNode) assignAggregate((PhysicalAggregate) relNode);
+    } else if (relNode instanceof PhysicalWindow) {
+      return (PRelNode) assignWindow((PhysicalWindow) relNode);
     }
-    return node;
+    return (PRelNode) relNode;
   }
 
   /**
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
index 3c15b6c656..7c2284cce5 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
@@ -22,7 +22,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import javax.annotation.Nullable;
-import org.apache.calcite.util.mapping.Mappings;
+import org.apache.pinot.query.planner.physical.v2.mapping.PinotDistMapping;
 
 
 /**
@@ -63,15 +63,15 @@ public class HashDistributionDesc {
    * partitioning info.
    */
   @Nullable
-  public HashDistributionDesc apply(Mappings.TargetMapping targetMapping) {
+  public HashDistributionDesc apply(PinotDistMapping mapping) {
     for (Integer currentKey : _keys) {
-      if (currentKey >= targetMapping.getSourceCount() || targetMapping.getTargetOpt(currentKey) == -1) {
+      if (currentKey >= mapping.getSourceCount() || mapping.getTarget(currentKey) == -1) {
         return null;
       }
     }
     List<Integer> newKey = new ArrayList<>();
     for (int currentKey : _keys) {
-      newKey.add(targetMapping.getTargetOpt(currentKey));
+      newKey.add(mapping.getTarget(currentKey));
     }
     return new HashDistributionDesc(newKey, _hashFunction, _numPartitions);
   }
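
The PinotDistMapping-based apply() is easiest to see on a concrete projection. A minimal sketch, using only PinotDistMapping from this patch (the indices are hypothetical):

    // Source has 2 fields; a Project emits them as (field1, field0).
    PinotDistMapping mapping = new PinotDistMapping(2);
    mapping.set(0, 1); // input field 0 surfaces as output field 1
    mapping.set(1, 0); // input field 1 surfaces as output field 0
    // desc.apply(mapping) rewrites a desc keyed on [1] into one keyed on [0].
    // A key that maps to -1 (i.e. dropped by the projection) makes apply()
    // return null, so the stale hash distribution is discarded rather than
    // mis-reported downstream.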
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
index 8f53d0b067..177870cbd7 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
@@ -21,6 +21,8 @@ package org.apache.pinot.query.planner.physical.v2;
 import java.util.List;
 import java.util.Objects;
 import javax.annotation.Nullable;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.TableScan;
 
@@ -81,6 +83,21 @@ public interface PRelNode {
     return null;
   }
 
+  /**
+   * TODO(mse-physical): This does not check PinotExecStrategyTrait. We should revisit whether exec strategy should be
+   *   a trait or not.
+   */
+  default boolean areTraitsSatisfied() {
+    RelNode relNode = unwrap();
+    RelDistribution distribution = relNode.getTraitSet().getDistribution();
+    PinotDataDistribution dataDistribution = getPinotDataDistributionOrThrow();
+    if (dataDistribution.satisfies(distribution)) {
+      RelCollation collation = relNode.getTraitSet().getCollation();
+      return dataDistribution.satisfies(collation);
+    }
+    return false;
+  }
+
   PRelNode with(int newNodeId, List<PRelNode> newInputs, PinotDataDistribution newDistribution);
 
   default PRelNode with(List<PRelNode> newInputs, PinotDataDistribution newDistribution) {
@@ -90,4 +107,8 @@ public interface PRelNode {
   default PRelNode with(List<PRelNode> newInputs) {
     return with(getNodeId(), newInputs, getPinotDataDistributionOrThrow());
   }
+
+  default PRelNode asLeafStage() {
+    throw new UnsupportedOperationException(String.format("Cannot make %s a leaf stage node", unwrap()));
+  }
 }
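
The two new default methods are the hooks the leaf-stage rules below build on. A hedged sketch of the intended call pattern, assuming some PRelNode variable pRelInput (LeafStageAggregateRule further down is the real call site):

    // areTraitsSatisfied(): true iff the node's PinotDataDistribution satisfies
    // both the RelDistribution and the RelCollation traits on its RelNode.
    if (pRelInput.areTraitsSatisfied()) {
      // The parent can stay in the same stage: no Exchange is needed.
    }
    // asLeafStage() returns a copy flagged as leaf-stage; the default throws,
    // so only the node types that override it below (scan, filter, project,
    // sort, aggregate) can be pulled into the leaf stage.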
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
index 36ea5ca3cc..36002cb32f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
@@ -28,7 +28,7 @@ import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelDistribution;
-import org.apache.calcite.util.mapping.Mappings;
+import org.apache.pinot.query.planner.physical.v2.mapping.PinotDistMapping;
 
 
 /**
@@ -157,13 +157,13 @@ public class PinotDataDistribution {
     return _collation.satisfies(relCollation);
   }
 
-  public PinotDataDistribution apply(@Nullable Mappings.TargetMapping targetMapping) {
-    if (targetMapping == null) {
+  public PinotDataDistribution apply(@Nullable PinotDistMapping mapping) {
+    if (mapping == null) {
+      return new PinotDataDistribution(RelDistribution.Type.ANY, _workers, _workerHash, null, null);
     }
     Set<HashDistributionDesc> newHashDesc = new HashSet<>();
     for (HashDistributionDesc desc : _hashDistributionDesc) {
-      HashDistributionDesc newDescs = desc.apply(targetMapping);
+      HashDistributionDesc newDescs = desc.apply(mapping);
       if (newDescs != null) {
         newHashDesc.add(newDescs);
       }
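
As the hunk above shows, apply(null) deliberately degrades the distribution type to ANY while keeping the worker list and worker hash (the two trailing nulls clear the remaining state), so a caller that cannot compute a field mapping still gets a usable, if weaker, distribution:

    // Hypothetical caller:
    PinotDataDistribution relaxed = distribution.apply(null);
    // type == ANY; _workers and _workerHash preserved; the hash-distribution
    // descriptors and collation are nulled out per the constructor call above.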
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java
new file mode 100644
index 0000000000..6b80fd4d38
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.Minus;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.core.TableScan;
+import org.apache.calcite.rel.core.Union;
+import org.apache.calcite.rel.core.Values;
+import org.apache.calcite.rel.core.Window;
+import org.apache.pinot.calcite.rel.logical.PinotLogicalAggregate;
+import org.apache.pinot.calcite.rel.traits.TraitAssignment;
+import org.apache.pinot.common.config.provider.TableCache;
+import org.apache.pinot.query.context.PhysicalPlannerContext;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAggregate;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalFilter;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalJoin;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalProject;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalSort;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalTableScan;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalUnion;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalValues;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalWindow;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRule;
+import org.apache.pinot.query.planner.physical.v2.opt.PhysicalOptRuleSet;
+import org.apache.pinot.query.planner.physical.v2.opt.RuleExecutor;
+import org.apache.pinot.query.planner.physical.v2.opt.RuleExecutors;
+import org.apache.pinot.query.planner.plannode.AggregateNode;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Converts a tree of RelNode to a tree of PRelNode, running the configured Physical Optimizers in the process.
+ */
+public class RelToPRelConverter {
+  private static final Logger LOGGER = LoggerFactory.getLogger(RelToPRelConverter.class);
+
+  private RelToPRelConverter() {
+  }
+
+  public static PRelNode toPRelNode(RelNode relNode, PhysicalPlannerContext context, TableCache tableCache) {
+    // Step-1: Convert each RelNode to a PRelNode
+    PRelNode rootPRelNode = create(relNode, context.getNodeIdGenerator());
+    // Step-2: Assign traits
+    rootPRelNode = TraitAssignment.assign(rootPRelNode, context);
+    // Step-3: Run physical optimizer rules.
+    var ruleAndExecutorList = PhysicalOptRuleSet.create(context, tableCache);
+    for (var ruleAndExecutor : ruleAndExecutorList) {
+      PRelOptRule rule = ruleAndExecutor.getLeft();
+      RuleExecutor executor = RuleExecutors.create(ruleAndExecutor.getRight(), rule, context);
+      rootPRelNode = executor.execute(rootPRelNode);
+    }
+    return rootPRelNode;
+  }
+
+  public static PRelNode create(RelNode relNode, Supplier<Integer> nodeIdGenerator) {
+    List<PRelNode> inputs = new ArrayList<>();
+    for (RelNode input : relNode.getInputs()) {
+      inputs.add(create(input, nodeIdGenerator));
+    }
+    if (relNode instanceof TableScan) {
+      Preconditions.checkState(inputs.isEmpty(), "Expected no inputs to table scan. Found: %s", inputs);
+      return new PhysicalTableScan((TableScan) relNode, nodeIdGenerator.get(), null, null);
+    } else if (relNode instanceof Filter) {
+      Preconditions.checkState(inputs.size() == 1, "Expected exactly 1 input of filter. Found: %s", inputs);
+      Filter filter = (Filter) relNode;
+      return new PhysicalFilter(filter.getCluster(), filter.getTraitSet(), filter.getHints(), filter.getCondition(),
+          nodeIdGenerator.get(), inputs.get(0), null, false);
+    } else if (relNode instanceof Project) {
+      Preconditions.checkState(inputs.size() == 1, "Expected exactly 1 input of project. Found: %s", inputs);
+      Project project = (Project) relNode;
+      return new PhysicalProject(project.getCluster(), project.getTraitSet(), project.getHints(), project.getProjects(),
+          project.getRowType(), project.getVariablesSet(), nodeIdGenerator.get(), inputs.get(0), null, false);
+    } else if (relNode instanceof PinotLogicalAggregate) {
+      Preconditions.checkState(inputs.size() == 1, "Expected exactly 1 input of agg. Found: %s", inputs);
+      PinotLogicalAggregate aggRel = (PinotLogicalAggregate) relNode;
+      return new PhysicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), aggRel.getGroupSet(),
+          aggRel.getGroupSets(), aggRel.getAggCallList(), nodeIdGenerator.get(), inputs.get(0), null, false,
+          AggregateNode.AggType.DIRECT, false, List.of(), 0);
+    } else if (relNode instanceof Join) {
+      Preconditions.checkState(relNode.getInputs().size() == 2, "Expected exactly 2 inputs to join. Found: %s", inputs);
+      Join join = (Join) relNode;
+      return new PhysicalJoin(join.getCluster(), join.getTraitSet(), join.getHints(), join.getCondition(),
+          join.getVariablesSet(), join.getJoinType(), nodeIdGenerator.get(), inputs.get(0), inputs.get(1), null);
+    } else if (relNode instanceof Union) {
+      Union union = (Union) relNode;
+      return new PhysicalUnion(union.getCluster(), union.getTraitSet(), union.getHints(), union.all, inputs,
+          nodeIdGenerator.get(), null);
+    } else if (relNode instanceof Minus) {
+      Minus minus = (Minus) relNode;
+      return new PhysicalUnion(minus.getCluster(), minus.getTraitSet(), minus.getHints(), minus.all, inputs,
+          nodeIdGenerator.get(), null);
+    } else if (relNode instanceof Sort) {
+      Preconditions.checkState(inputs.size() == 1, "Expected exactly 1 input of sort. Found: %s", inputs);
+      Sort sort = (Sort) relNode;
+      return new PhysicalSort(sort.getCluster(), sort.getTraitSet(), sort.getHints(), sort.getCollation(), sort.offset,
+          sort.fetch, inputs.get(0), nodeIdGenerator.get(), null, false);
+    } else if (relNode instanceof Values) {
+      Preconditions.checkState(inputs.isEmpty(), "Expected no inputs to values. Found: %s", inputs);
+      Values values = (Values) relNode;
+      return new PhysicalValues(values.getCluster(), values.getHints(), values.getRowType(), values.getTuples(),
+          values.getTraitSet(), nodeIdGenerator.get(), null);
+    } else if (relNode instanceof Window) {
+      Preconditions.checkState(inputs.size() == 1, "Expected exactly 1 input of window. Found: %s", inputs);
+      Window window = (Window) relNode;
+      return new PhysicalWindow(window.getCluster(), window.getTraitSet(), window.getHints(), window.getConstants(),
+          window.getRowType(), window.groups, nodeIdGenerator.get(), inputs.get(0), null);
+    }
+    throw new IllegalStateException("Unexpected relNode type: " + relNode.getClass().getName());
+  }
+}
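
End to end, the new entry point turns a logical RelNode tree into a PRelNode tree. A hedged usage sketch; logicalRoot, context and tableCache are illustrative names for the logical plan root, the PhysicalPlannerContext and the TableCache:

    PRelNode physicalRoot = RelToPRelConverter.toPRelNode(logicalRoot, context, tableCache);
    // Step-1 wrapped every RelNode, Step-2 assigned traits, and Step-3 ran the
    // leaf-stage boundary and worker-assignment rules from PhysicalOptRuleSet.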
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java
new file mode 100644
index 0000000000..39413a0268
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.mapping;
+
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.SetOp;
+import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.core.TableScan;
+import org.apache.calcite.rel.core.Values;
+import org.apache.calcite.rel.core.Window;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.commons.collections4.CollectionUtils;
+
+
+/**
+ * Generates {@link PinotDistMapping} for a given source and destination RelNode.
+ */
+public class DistMappingGenerator {
+  private DistMappingGenerator() {
+  }
+
+  /**
+   * Source to destination mapping.
+   */
+  public static PinotDistMapping compute(RelNode source, RelNode destination,
+      @Nullable List<RelNode> leadingSiblings) {
+    if (destination instanceof Project) {
+      Project project = (Project) destination;
+      return computeExactInputRefMapping(project, source.getRowType().getFieldCount());
+    } else if (destination instanceof Window) {
+      // Window preserves input fields, and appends a field for each window expr.
+      Window window = (Window) destination;
+      return PinotDistMapping.identity(window.getInput().getRowType().getFieldCount());
+    } else if (destination instanceof Aggregate) {
+      Aggregate aggregate = (Aggregate) destination;
+      PinotDistMapping mapping = new PinotDistMapping(source.getRowType().getFieldCount());
+      List<Integer> groupSet = aggregate.getGroupSet().asList();
+      for (int j = 0; j < groupSet.size(); j++) {
+        mapping.set(groupSet.get(j), j);
+      }
+      return mapping;
+    } else if (destination instanceof Join) {
+      if (CollectionUtils.isEmpty(leadingSiblings)) {
+        return PinotDistMapping.identity(source.getRowType().getFieldCount());
+      }
+      int leftFieldCount = 0;
+      for (RelNode sibling : leadingSiblings) {
+        leftFieldCount += sibling.getRowType().getFieldCount();
+      }
+      PinotDistMapping mapping = new PinotDistMapping(source.getRowType().getFieldCount());
+      for (int i = 0; i < mapping.getSourceCount(); i++) {
+        mapping.set(i, i + leftFieldCount);
+      }
+      return mapping;
+    } else if (destination instanceof Filter) {
+      return PinotDistMapping.identity(source.getRowType().getFieldCount());
+    } else if (destination instanceof TableScan) {
+      throw new IllegalStateException("Found destination as TableScan in MappingGenerator");
+    } else if (destination instanceof Values) {
+      throw new IllegalStateException("Found destination as Values in MappingGenerator");
+    } else if (destination instanceof Sort) {
+      return PinotDistMapping.identity(source.getRowType().getFieldCount());
+    } else if (destination instanceof SetOp) {
+      SetOp setOp = (SetOp) destination;
+      if (setOp.isHomogeneous(true)) {
+        return PinotDistMapping.identity(source.getRowType().getFieldCount());
+      }
+      // TODO(mse-physical): Handle heterogeneous set ops. Currently we drop all mapping refs.
+      return new PinotDistMapping(source.getRowType().getFieldCount());
+    }
+    throw new IllegalStateException("Unknown node type: " + destination.getClass());
+  }
+
+  private static PinotDistMapping computeExactInputRefMapping(Project project, int sourceCount) {
+    PinotDistMapping mapping = new PinotDistMapping(sourceCount);
+    int indexInCurrentRelNode = 0;
+    for (RexNode rexNode : project.getProjects()) {
+      if (rexNode instanceof RexInputRef) {
+        RexInputRef rexInputRef = (RexInputRef) rexNode;
+        mapping.set(rexInputRef.getIndex(), indexInCurrentRelNode);
+      }
+      indexInCurrentRelNode++;
+    }
+    return mapping;
+  }
+}
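
A worked example of the Project branch: for SELECT c2, c0 over an input with fields (c0, c1, c2), computeExactInputRefMapping yields 0→1, 1→-1, 2→0. Built by hand with the mapping class from this patch:

    PinotDistMapping mapping = new PinotDistMapping(3); // all targets start at -1
    mapping.set(2, 0); // c2 becomes output field 0
    mapping.set(0, 1); // c0 becomes output field 1
    // c1 stays at -1: it is projected away, so a distribution keyed on c1
    // cannot be propagated past this Project.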
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java
new file mode 100644
index 0000000000..1a083e0314
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.mapping;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+
+/**
+ * Mapping specifically for Pinot Data Distribution and trait mapping. A mapping is defined for a source / destination
+ * RelNode pair and is used to track how input fields are mapped to output fields.
+ */
+public class PinotDistMapping {
+  private static final int DEFAULT_MAPPING_VALUE = -1;
+  private final int _sourceCount;
+  private final Map<Integer, Integer> _sourceToTargetMapping = new HashMap<>();
+
+  public PinotDistMapping(int sourceCount) {
+    _sourceCount = sourceCount;
+    for (int i = 0; i < sourceCount; i++) {
+      _sourceToTargetMapping.put(i, DEFAULT_MAPPING_VALUE);
+    }
+  }
+
+  public static PinotDistMapping identity(int sourceCount) {
+    PinotDistMapping mapping = new PinotDistMapping(sourceCount);
+    for (int i = 0; i < sourceCount; i++) {
+      mapping.set(i, i);
+    }
+    return mapping;
+  }
+
+  public int getSourceCount() {
+    return _sourceCount;
+  }
+
+  public int getTarget(int source) {
+    Preconditions.checkArgument(source >= 0 && source < _sourceCount, "Invalid source index: %s", source);
+    Integer target = _sourceToTargetMapping.get(source);
+    return target == null ? DEFAULT_MAPPING_VALUE : target;
+  }
+
+  public void set(int source, int target) {
+    Preconditions.checkArgument(source >= 0 && source < _sourceCount, "Invalid source index: %s", source);
+    _sourceToTargetMapping.put(source, target);
+  }
+
+  public List<Integer> getMappedKeys(List<Integer> existingKeys) {
+    List<Integer> result = new ArrayList<>(existingKeys.size());
+    for (int key : existingKeys) {
+      Integer mappedKey = _sourceToTargetMapping.get(key);
+      Preconditions.checkArgument(mappedKey != null,
+          "Key %s not found in mapping with source count: %s", key, 
_sourceCount);
+      if (mappedKey != DEFAULT_MAPPING_VALUE) {
+        result.add(mappedKey);
+      } else {
+        return Collections.emptyList();
+      }
+    }
+    return result;
+  }
+}
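
The all-or-nothing contract of getMappedKeys() is what keeps the distribution rewrites above safe. A small runnable illustration, assuming PinotDistMapping is on the classpath (same package; the example class name is made up):

    import java.util.List;

    public class PinotDistMappingExample {
      public static void main(String[] args) {
        PinotDistMapping identity = PinotDistMapping.identity(4);
        System.out.println(identity.getMappedKeys(List.of(1, 3))); // [1, 3]

        PinotDistMapping dropAll = new PinotDistMapping(4); // every target is -1
        // One dropped key empties the whole result, by design:
        System.out.println(dropAll.getMappedKeys(List.of(1))); // []
      }
    }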
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
index 9a76e9e31d..4f7d026996 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
@@ -103,4 +103,14 @@ public class PhysicalAggregate extends Aggregate implements PRelNode {
         getAggCallList(), newNodeId, newInputs.get(0), newDistribution, _leafStage, _aggType, _leafReturnFinalResult,
         _collations, _limit);
   }
+
+  @Override
+  public PhysicalAggregate asLeafStage() {
+    if (isLeafStage()) {
+      return this;
+    }
+    return new PhysicalAggregate(getCluster(), getTraitSet(), getHints(), getGroupSet(), getGroupSets(),
+        getAggCallList(), _nodeId, _pRelInputs.get(0), _pinotDataDistribution, true, _aggType, _leafReturnFinalResult,
+        _collations, _limit);
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalFilter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalFilter.java
index d483382394..1f5ecee7d4 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalFilter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalFilter.java
@@ -92,4 +92,13 @@ public class PhysicalFilter extends Filter implements PRelNode {
     return new PhysicalFilter(getCluster(), getTraitSet(), getHints(), condition, newNodeId, newInputs.get(0),
         newDistribution, _leafStage);
   }
+
+  @Override
+  public PRelNode asLeafStage() {
+    if (isLeafStage()) {
+      return this;
+    }
+    return new PhysicalFilter(getCluster(), getTraitSet(), getHints(), condition, _nodeId, _pRelInputs.get(0),
+        _pinotDataDistribution, true);
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalProject.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalProject.java
index a8b7fde656..63065e38ff 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalProject.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalProject.java
@@ -90,4 +90,13 @@ public class PhysicalProject extends Project implements PRelNode {
     return new PhysicalProject(getCluster(), getTraitSet(), getHints(), getProjects(), getRowType(), getVariablesSet(),
         newNodeId, newInputs.get(0), newDistribution, _leafStage);
   }
+
+  @Override
+  public PRelNode asLeafStage() {
+    if (isLeafStage()) {
+      return this;
+    }
+    return new PhysicalProject(getCluster(), getTraitSet(), getHints(), getProjects(), getRowType(),
+        getVariablesSet(), _nodeId, _pRelInputs.get(0), _pinotDataDistribution, true);
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalSort.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalSort.java
index 486f1d9d0f..25bad2b107 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalSort.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalSort.java
@@ -87,4 +87,13 @@ public class PhysicalSort extends Sort implements PRelNode {
     return new PhysicalSort(getCluster(), getTraitSet(), getHints(), getCollation(), offset, fetch, newInputs.get(0),
         newNodeId, newDistribution, _leafStage);
   }
+
+  @Override
+  public PRelNode asLeafStage() {
+    if (isLeafStage()) {
+      return this;
+    }
+    return new PhysicalSort(getCluster(), getTraitSet(), getHints(), getCollation(), offset, fetch, _pRelInputs.get(0),
+        _nodeId, _pinotDataDistribution, true);
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalTableScan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalTableScan.java
index cdbf90db41..18d72fd37d 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalTableScan.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalTableScan.java
@@ -39,7 +39,7 @@ public class PhysicalTableScan extends TableScan implements PRelNode {
   @Nullable
   private final TableScanMetadata _tableScanMetadata;
 
-  public PhysicalTableScan(TableScan tableScan, int nodeId, PinotDataDistribution pinotDataDistribution,
+  public PhysicalTableScan(TableScan tableScan, int nodeId, @Nullable PinotDataDistribution pinotDataDistribution,
       @Nullable TableScanMetadata tableScanMetadata) {
     this(tableScan.getCluster(), tableScan.getTraitSet(), tableScan.getHints(), tableScan.getTable(), nodeId,
         pinotDataDistribution, tableScanMetadata);
@@ -99,4 +99,14 @@ public class PhysicalTableScan extends TableScan implements PRelNode {
     return new PhysicalTableScan(getCluster(), getTraitSet(), getHints(), getTable(), newNodeId,
         newDistribution, _tableScanMetadata);
   }
+
+  @Override
+  public PRelNode asLeafStage() {
+    return this;
+  }
+
+  public PhysicalTableScan with(PinotDataDistribution pinotDataDistribution, TableScanMetadata metadata) {
+    return new PhysicalTableScan(getCluster(), getTraitSet(), getHints(), getTable(), _nodeId,
+        pinotDataDistribution, metadata);
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java
new file mode 100644
index 0000000000..9db842fabf
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.opt;
+
+import java.util.List;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.pinot.common.config.provider.TableCache;
+import org.apache.pinot.query.context.PhysicalPlannerContext;
+import org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageBoundaryRule;
+import org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageWorkerAssignmentRule;
+
+
+public class PhysicalOptRuleSet {
+  private PhysicalOptRuleSet() {
+  }
+
+  public static List<Pair<PRelOptRule, RuleExecutors.Type>> create(PhysicalPlannerContext context,
+      TableCache tableCache) {
+    return List.of(
+        Pair.of(LeafStageBoundaryRule.INSTANCE, RuleExecutors.Type.POST_ORDER),
+        Pair.of(new LeafStageWorkerAssignmentRule(context, tableCache), RuleExecutors.Type.POST_ORDER));
+        // Pair.of(new WorkerExchangeAssignmentRule(context), RuleExecutors.Type.IN_ORDER),
+        // Pair.of(AggregatePushdownRule.INSTANCE, RuleExecutors.Type.POST_ORDER),
+        // Pair.of(SortPushdownRule.INSTANCE, RuleExecutors.Type.POST_ORDER));
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/RuleExecutors.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/RuleExecutors.java
new file mode 100644
index 0000000000..b8c08b6d80
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/RuleExecutors.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.opt;
+
+import org.apache.pinot.query.context.PhysicalPlannerContext;
+
+
+public class RuleExecutors {
+  private RuleExecutors() {
+  }
+
+  public static RuleExecutor create(Type type, PRelOptRule rule, PhysicalPlannerContext context) {
+    switch (type) {
+      case POST_ORDER:
+        return new PostOrderRuleExecutor(rule, context);
+      case IN_ORDER:
+        return new LeftInputFirstRuleExecutor(rule, context);
+      default:
+        throw new IllegalStateException(String.format("Unrecognized rule executor type: %s", type));
+    }
+  }
+
+  public enum Type {
+    POST_ORDER,
+    IN_ORDER
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java
new file mode 100644
index 0000000000..e8121b063b
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.opt.rules;
+
+import com.google.common.base.Preconditions;
+import java.util.Map;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.TableScan;
+import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
+import org.apache.pinot.calcite.rel.hint.PinotHintStrategyTable;
+import org.apache.pinot.query.context.PhysicalPlannerContext;
+import org.apache.pinot.query.planner.physical.v2.PRelNode;
+import org.apache.pinot.query.planner.physical.v2.PinotDataDistribution;
+import org.apache.pinot.query.planner.physical.v2.mapping.DistMappingGenerator;
+import org.apache.pinot.query.planner.physical.v2.mapping.PinotDistMapping;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAggregate;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRule;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRuleCall;
+
+
+/**
+ * Often it might be possible to promote an aggregate on top of a leaf stage to be part of the leaf stage. This rule
+ * handles that case. This is different from aggregate pushdown because pushdown is related to taking a decision about
+ * whether we should split the aggregate over an exchange into two, whereas this rule is able to avoid the Exchange
+ * altogether.
+ */
+public class LeafStageAggregateRule extends PRelOptRule {
+  private final PhysicalPlannerContext _physicalPlannerContext;
+
+  public LeafStageAggregateRule(PhysicalPlannerContext physicalPlannerContext) {
+    _physicalPlannerContext = physicalPlannerContext;
+  }
+
+  @Override
+  public boolean matches(PRelOptRuleCall call) {
+    if (call._currentNode.isLeafStage()) {
+      return false;
+    }
+    PRelNode currentNode = call._currentNode;
+    if (!(currentNode.unwrap() instanceof Aggregate)) {
+      return false;
+    }
+    if (!isProjectFilterOrScan(currentNode.getPRelInput(0).unwrap())) {
+      return false;
+    }
+    // ==> We have: "aggregate (non-leaf stage) > project|filter|table-scan (leaf-stage)"
+    PhysicalAggregate aggRel = (PhysicalAggregate) currentNode.unwrap();
+    PRelNode pRelInput = aggRel.getPRelInput(0);
+    if (isPartitionedByHintPresent(aggRel)) {
+      Preconditions.checkState(aggRel.getInput().getTraitSet().getCollation() == null,
+          "Aggregate input has sort constraint, but partition-by hint is forcing to skip exchange");
+      return true;
+    }
+    return pRelInput.areTraitsSatisfied();
+  }
+
+  @Override
+  public PRelNode onMatch(PRelOptRuleCall call) {
+    PhysicalAggregate currentNode = (PhysicalAggregate) call._currentNode;
+    PinotDistMapping mapping = DistMappingGenerator.compute(currentNode.getPRelInput(0).unwrap(),
+        currentNode.unwrap(), null);
+    PinotDataDistribution derivedDistribution = currentNode.getPRelInput(0).getPinotDataDistributionOrThrow()
+        .apply(mapping);
+    return currentNode.with(currentNode.getPRelInputs(), derivedDistribution);
+  }
+
+  private static boolean isPartitionedByHintPresent(PhysicalAggregate aggRel) {
+    Map<String, String> hintOptions =
+        PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS);
+    hintOptions = hintOptions == null ? Map.of() : hintOptions;
+    return Boolean.parseBoolean(hintOptions.get(PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS));
+  }
+
+  private static boolean isProjectFilterOrScan(RelNode relNode) {
+    return relNode instanceof TableScan || relNode instanceof Project || relNode instanceof Filter;
+  }
+}
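
For reference, the forced path in matches() is driven by Pinot's aggregate hint (the option key behind PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS). A sketch of a qualifying query; the table and column names are made up:

    // With the hint, the rule checks there is no collation constraint on the
    // input and keeps the aggregate in the leaf stage, skipping the Exchange.
    String query = "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */"
        + " userId, COUNT(*) FROM userEvents GROUP BY userId";

Without the hint, the aggregate is absorbed only when the input's detected partitioning already satisfies the aggregate's traits (pRelInput.areTraitsSatisfied()).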
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageBoundaryRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageBoundaryRule.java
new file mode 100644
index 0000000000..42b7cf9c65
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageBoundaryRule.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.opt.rules;
+
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.TableScan;
+import org.apache.pinot.query.planner.physical.v2.PRelNode;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRule;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRuleCall;
+
+
+/**
+ * The leaf stage consists of a table-scan and an optional project and/or filter. The filter and project nodes
+ * may be in any order. We don't include sort or aggregate in the leaf stage in this rule, because they will be made
+ * part of the leaf stage (if appropriate) as part of the Aggregate and Sort pushdown rules.
+ * <p>
+ *  The idea is that you can and should always make filter and project part of the leaf stage and compute them locally
+ *  on the server where the table-scan is computed. Whether it makes sense to run the aggregate or sort in the leaf
+ *  stage depends on a few conditions, and hence it is handled as part of their respective pushdown rules.
+ * </p>
+ */
+public class LeafStageBoundaryRule extends PRelOptRule {
+  public static final LeafStageBoundaryRule INSTANCE = new LeafStageBoundaryRule();
+
+  private LeafStageBoundaryRule() {
+  }
+
+  @Override
+  public boolean matches(PRelOptRuleCall call) {
+    RelNode currentRel = call._currentNode.unwrap();
+    if (currentRel instanceof TableScan) {
+      return true;
+    }
+    if (!isProjectOrFilter(currentRel)) {
+      return false;
+    }
+    if (isTableScan(currentRel.getInput(0))) {
+      // (Project|Filter) > Table Scan
+      return true;
+    }
+    if (isProject(currentRel) && isFilter(currentRel.getInput(0))
+        && isTableScan(currentRel.getInput(0).getInput(0))) {
+      // Project > Filter > Table Scan
+      return true;
+    }
+    // Filter > Project > Table Scan
+    return isFilter(currentRel) && isProject(currentRel.getInput(0))
+        && isTableScan(currentRel.getInput(0).getInput(0));
+  }
+
+  @Override
+  public PRelNode onMatch(PRelOptRuleCall call) {
+    PRelNode currentNode = call._currentNode;
+    return currentNode.asLeafStage();
+  }
+
+  private boolean isProjectOrFilter(RelNode relNode) {
+    return isProject(relNode) || isFilter(relNode);
+  }
+
+  private boolean isProject(RelNode relNode) {
+    return relNode instanceof Project;
+  }
+
+  private boolean isFilter(RelNode relNode) {
+    return relNode instanceof Filter;
+  }
+
+  private boolean isTableScan(RelNode relNode) {
+    return relNode instanceof TableScan;
+  }
+}
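
Spelled out, matches() above admits exactly these shapes (root listed first); since the rule runs post-order, each node in a matching chain is marked via asLeafStage() in turn:

    TableScan
    Project > TableScan
    Filter > TableScan
    Project > Filter > TableScan
    Filter > Project > TableScan

Deeper stacks such as Project > Project > TableScan are left untouched by this rule.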
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java
new file mode 100644
index 0000000000..3295e2cf37
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java
@@ -0,0 +1,424 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.opt.rules;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.core.TableScan;
+import org.apache.calcite.rel.hint.RelHint;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
+import org.apache.pinot.common.config.provider.TableCache;
+import org.apache.pinot.common.utils.DatabaseUtils;
+import org.apache.pinot.core.routing.RoutingManager;
+import org.apache.pinot.core.routing.RoutingTable;
+import org.apache.pinot.core.routing.ServerRouteInfo;
+import org.apache.pinot.core.routing.TablePartitionInfo;
+import org.apache.pinot.core.routing.TimeBoundaryInfo;
+import org.apache.pinot.core.transport.ServerInstance;
+import org.apache.pinot.query.context.PhysicalPlannerContext;
+import org.apache.pinot.query.planner.physical.v2.HashDistributionDesc;
+import org.apache.pinot.query.planner.physical.v2.PRelNode;
+import org.apache.pinot.query.planner.physical.v2.PinotDataDistribution;
+import org.apache.pinot.query.planner.physical.v2.TableScanMetadata;
+import org.apache.pinot.query.planner.physical.v2.mapping.DistMappingGenerator;
+import org.apache.pinot.query.planner.physical.v2.mapping.PinotDistMapping;
+import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalTableScan;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRule;
+import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRuleCall;
+import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.utils.builder.TableNameBuilder;
+import org.apache.pinot.sql.parsers.CalciteSqlCompiler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * <h1>Overview</h1>
+ * Assigns workers to all PRelNodes that are part of the leaf stage as determined by {@link PRelNode#isLeafStage()}.
+ * The workers are mainly determined by the Table Scan, unless filter based server pruning is enabled.
+ * <h1>Current Features</h1>
+ * <ul>
+ *   <li>
+ *     Automatically detects partitioning and adds that information to PinotDataDistribution. This will be used
+ *     in subsequent worker assignment steps to simplify Exchange.
+ *   </li>
+ * </ul>
+ * <h1>Planned / Upcoming Features</h1>
+ * <ul>
+ *   <li>Support for look-up join.</li>
+ *   <li>Support for partition parallelism and the colocated join hint. See F2 in #15455.</li>
+ *   <li>Support for Hybrid Tables for automatic partitioning inference.</li>
+ *   <li>Server pruning based on filter predicates.</li>
+ * </ul>
+ */
+public class LeafStageWorkerAssignmentRule extends PRelOptRule {
+  private static final Logger LOGGER = LoggerFactory.getLogger(LeafStageWorkerAssignmentRule.class);
+  private final TableCache _tableCache;
+  private final RoutingManager _routingManager;
+  private final PhysicalPlannerContext _physicalPlannerContext;
+
+  public LeafStageWorkerAssignmentRule(PhysicalPlannerContext physicalPlannerContext, TableCache tableCache) {
+    _routingManager = physicalPlannerContext.getRoutingManager();
+    _physicalPlannerContext = physicalPlannerContext;
+    _tableCache = tableCache;
+  }
+
+  @Override
+  public boolean matches(PRelOptRuleCall call) {
+    return call._currentNode.isLeafStage();
+  }
+
+  @Override
+  public PRelNode onMatch(PRelOptRuleCall call) {
+    if (call._currentNode.unwrap() instanceof TableScan) {
+      return assignTableScan((PhysicalTableScan) call._currentNode, _physicalPlannerContext.getRequestId());
+    }
+    PRelNode currentNode = call._currentNode;
+    Preconditions.checkState(currentNode.isLeafStage(), "Leaf stage worker assignment called for non-leaf stage node:"
+        + " %s", currentNode);
+    PinotDistMapping mapping = DistMappingGenerator.compute(currentNode.getPRelInput(0).unwrap(),
+        currentNode.unwrap(), null);
+    PinotDataDistribution derivedDistribution = currentNode.getPRelInput(0).getPinotDataDistributionOrThrow()
+        .apply(mapping);
+    return currentNode.with(currentNode.getPRelInputs(), derivedDistribution);
+  }
+
+  private PhysicalTableScan assignTableScan(PhysicalTableScan tableScan, long requestId) {
+    // Step-1: Init tableName, table options, routing table and time boundary info.
+    String tableName = Objects.requireNonNull(getActualTableName(tableScan), "Table not found");
+    Map<String, String> tableOptions = getTableOptions(tableScan.getHints());
+    Map<String, RoutingTable> routingTableMap = getRoutingTable(tableName, requestId);
+    Preconditions.checkState(!routingTableMap.isEmpty(), "Unable to find routing entries for table: %s", tableName);
+    // acquire time boundary info if it is a hybrid table.
+    TimeBoundaryInfo timeBoundaryInfo = null;
+    if (routingTableMap.size() > 1) {
+      timeBoundaryInfo = _routingManager.getTimeBoundaryInfo(
+          TableNameBuilder.forType(TableType.OFFLINE)
+              .tableNameWithType(TableNameBuilder.extractRawTableName(tableName)));
+      if (timeBoundaryInfo == null) {
+        // remove offline table routing if no time boundary info is acquired.
+        routingTableMap.remove(TableType.OFFLINE.name());
+      }
+    }
+    // Step-2: Compute instance to segments map and unavailable segments.
+    Map<String, Set<String>> segmentUnavailableMap = new HashMap<>();
+    InstanceIdToSegments instanceIdToSegments = new InstanceIdToSegments();
+    for (Map.Entry<String, RoutingTable> routingEntry : routingTableMap.entrySet()) {
+      String tableType = routingEntry.getKey();
+      RoutingTable routingTable = routingEntry.getValue();
+      Map<String, List<String>> currentSegmentsMap = new HashMap<>();
+      Map<ServerInstance, ServerRouteInfo> tmp = routingTable.getServerInstanceToSegmentsMap();
+      for (Map.Entry<ServerInstance, ServerRouteInfo> serverEntry : tmp.entrySet()) {
+        String instanceId = serverEntry.getKey().getInstanceId();
+        Preconditions.checkState(currentSegmentsMap.put(instanceId, serverEntry.getValue().getSegments()) == null,
+            "Entry for server %s and table type: %s already exist!", serverEntry.getKey(), tableType);
+        _physicalPlannerContext.getInstanceIdToQueryServerInstance().computeIfAbsent(instanceId,
+            (ignore) -> new QueryServerInstance(serverEntry.getKey()));
+      }
+      if (tableType.equalsIgnoreCase(TableType.OFFLINE.name())) {
+        instanceIdToSegments._offlineTableSegmentsMap = currentSegmentsMap;
+      } else {
+        instanceIdToSegments._realtimeTableSegmentsMap = currentSegmentsMap;
+      }
+      if (!routingTable.getUnavailableSegments().isEmpty()) {
+        // Set unavailable segments in context, keyed by PRelNode ID.
+        segmentUnavailableMap.put(TableNameBuilder.forType(TableType.valueOf(tableName)).tableNameWithType(tableName),
+            new HashSet<>(routingTable.getUnavailableSegments()));
+      }
+    }
+    List<String> fieldNames = tableScan.getRowType().getFieldNames();
+    Map<String, TablePartitionInfo> tablePartitionInfoMap = calculateTablePartitionInfo(tableName,
+        routingTableMap.keySet());
+    TableScanWorkerAssignmentResult workerAssignmentResult = assignTableScan(tableName, fieldNames,
+        instanceIdToSegments, tablePartitionInfoMap);
+    TableScanMetadata metadata = new TableScanMetadata(Set.of(tableName), workerAssignmentResult._workerIdToSegmentsMap,
+        tableOptions, segmentUnavailableMap, timeBoundaryInfo);
+    return tableScan.with(workerAssignmentResult._pinotDataDistribution, metadata);
+  }
+
+  /**
+   * Assigns workers for the table-scan node, automatically detecting table partitioning whenever possible. The
+   * arguments to this method are minimal to facilitate unit-testing.
+   */
+  @VisibleForTesting
+  static TableScanWorkerAssignmentResult assignTableScan(String tableName, 
List<String> fieldNames,
+      InstanceIdToSegments instanceIdToSegments, Map<String, 
TablePartitionInfo> tpiMap) {
+    Set<String> tableTypes = instanceIdToSegments.getActiveTableTypes();
+    Set<String> partitionedTableTypes = 
tableTypes.stream().filter(tpiMap::containsKey).collect(Collectors.toSet());
+    Preconditions.checkState(!tableTypes.isEmpty(), "No routing entry for 
offline or realtime type");
+    if (tableTypes.equals(partitionedTableTypes)) {
+      if (partitionedTableTypes.size() == 1) {
+        // Attempt partitioned distribution
+        String tableType = partitionedTableTypes.iterator().next();
+        String tableNameWithType = TableNameBuilder.forType(TableType.valueOf(tableType)).tableNameWithType(tableName);
+        TableScanWorkerAssignmentResult assignmentResult = attemptPartitionedDistribution(tableNameWithType,
+            fieldNames, instanceIdToSegments.getSegmentsMap(TableType.valueOf(tableType)), tpiMap.get(tableType));
+        if (assignmentResult != null) {
+          return assignmentResult;
+        }
+      } else {
+        // TODO(mse-physical): Support automatic partitioned dist for hybrid tables.
+        LOGGER.warn("Automatic Partitioned Distribution not supported for Hybrid Tables yet");
+      }
+    }
+    // For each server, we want to know the segments for each table-type.
+    Map<String, Map<String, List<String>>> instanceIdToTableTypeToSegmentsMap = new HashMap<>();
+    for (String tableType : tableTypes) {
+      Map<String, List<String>> segmentsMap = instanceIdToSegments.getSegmentsMap(TableType.valueOf(tableType));
+      Preconditions.checkNotNull(segmentsMap, "Unexpected null segments map in leaf worker assignment");
+      for (Map.Entry<String, List<String>> entry : segmentsMap.entrySet()) {
+        String instanceId = entry.getKey();
+        List<String> segments = entry.getValue();
+        instanceIdToTableTypeToSegmentsMap.computeIfAbsent(instanceId, k -> new HashMap<>())
+            .put(tableType, segments);
+      }
+    }
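+    // Each worker is encoded as "<workerId>@<instanceId>" (e.g. "0@instance-0").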
+    // For each server, assign one worker each.
+    Map<Integer, Map<String, List<String>>> workerIdToSegmentsMap = new HashMap<>();
+    List<String> workers = new ArrayList<>();
+    for (Map.Entry<String, Map<String, List<String>>> entry : instanceIdToTableTypeToSegmentsMap.entrySet()) {
+      String instanceId = entry.getKey();
+      for (var tableTypeAndSegments : entry.getValue().entrySet()) {
+        String tableType = tableTypeAndSegments.getKey();
+        List<String> segments = tableTypeAndSegments.getValue();
+        workerIdToSegmentsMap.computeIfAbsent(workers.size(), (x) -> new HashMap<>()).put(tableType, segments);
+      }
+      workers.add(String.format("%s@%s", workers.size(), instanceId));
+    }
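+    // No partitioning could be inferred, so the result is exposed as a random distribution.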
+    PinotDataDistribution pinotDataDistribution = new PinotDataDistribution(RelDistribution.Type.RANDOM_DISTRIBUTED,
+        workers, workers.hashCode(), null, null);
+    return new TableScanWorkerAssignmentResult(pinotDataDistribution, workerIdToSegmentsMap);
+  }
+
+  /**
+   * Tries to assign workers for the table-scan node to generate a partitioned data distribution. If this is not
+   * possible, we simply return null.
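+   * <p>For example, with 8 partitions and 2 selected servers, worker-0 serves partitions {0, 2, 4, 6} and worker-1
+   * serves {1, 3, 5, 7}, so each of those partition sets must live entirely on a single server.</p>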
+   */
+  @Nullable
+  @VisibleForTesting
+  static TableScanWorkerAssignmentResult attemptPartitionedDistribution(String tableNameWithType,
+      List<String> fieldNames, Map<String, List<String>> instanceIdToSegmentsMap,
+      TablePartitionInfo tablePartitionInfo) {
+    if (CollectionUtils.isNotEmpty(tablePartitionInfo.getSegmentsWithInvalidPartition())) {
+      LOGGER.warn("Table {} has {} segments with invalid partition info. Will assume un-partitioned distribution",
+          tableNameWithType, tablePartitionInfo.getSegmentsWithInvalidPartition().size());
+      return null;
+    }
+    String tableType = Objects.requireNonNull(TableNameBuilder.getTableTypeFromTableName(tableNameWithType),
+        "Illegal state: expected table name with type").toString();
+    int numPartitions = tablePartitionInfo.getNumPartitions();
+    int keyIndex = fieldNames.indexOf(tablePartitionInfo.getPartitionColumn());
+    String function = tablePartitionInfo.getPartitionFunctionName();
+    int numSelectedServers = instanceIdToSegmentsMap.size();
+    if (keyIndex == -1) {
+      LOGGER.warn("Unable to find partition column {} in table scan fields {}",
+          tablePartitionInfo.getPartitionColumn(), fieldNames);
+      return null;
+    } else if (numPartitions < numSelectedServers) {
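+      // With fewer partitions than selected servers, at least one worker would get no partition, so fall back to
+      // un-partitioned distribution.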
+      return null;
+    }
+    // Pre-compute segmentToServer map for quick lookup later.
+    Map<String, String> segmentToServer = new HashMap<>();
+    for (var entry : instanceIdToSegmentsMap.entrySet()) {
+      String instanceId = entry.getKey();
+      for (String segment : entry.getValue()) {
+        segmentToServer.put(segment, instanceId);
+      }
+    }
+    // For each partition, we expect at most 1 server, which will be stored in this array.
+    String[] partitionToServerMap = new String[tablePartitionInfo.getNumPartitions()];
+    TablePartitionInfo.PartitionInfo[] partitionInfos = tablePartitionInfo.getPartitionInfoMap();
+    Map<Integer, List<String>> segmentsByPartition = new HashMap<>();
+    // Ensure each partition is assigned to exactly 1 server.
+    for (int partitionNum = 0; partitionNum < numPartitions; partitionNum++) {
+      TablePartitionInfo.PartitionInfo info = partitionInfos[partitionNum];
+      List<String> selectedSegments = new ArrayList<>();
+      if (info != null) {
+        String chosenServer;
+        for (String segment : info._segments) {
+          chosenServer = segmentToServer.get(segment);
+          // segmentToServer may return null if TPI has a segment not present in instanceIdToSegmentsMap.
+          // This can happen when the segment was not selected for the query (due to pruning, for instance).
+          if (chosenServer != null) {
+            selectedSegments.add(segment);
+            if (partitionToServerMap[partitionNum] == null) {
+              partitionToServerMap[partitionNum] = chosenServer;
+            } else if (!partitionToServerMap[partitionNum].equals(chosenServer)) {
+              return null;
+            }
+          }
+        }
+      }
+      segmentsByPartition.put(partitionNum, selectedSegments);
+    }
+    // Initialize workers list. Initially each element is empty. We have 1 worker for each selected server.
+    List<String> workers = new ArrayList<>();
+    for (int i = 0; i < numSelectedServers; i++) {
+      workers.add("");
+    }
+    // Try to assign workers in such a way that partition P goes to worker = P % num-workers.
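+    // E.g. with 8 partitions and 4 workers, worker-0 serves partitions {0, 4}, worker-1 serves {1, 5}, and so on.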
+    for (int partitionNum = 0; partitionNum < numPartitions; partitionNum++) {
+      if (partitionToServerMap[partitionNum] != null) {
+        int workerId = partitionNum % workers.size();
+        if (workers.get(workerId).isEmpty()) {
+          workers.set(workerId, partitionToServerMap[partitionNum]);
+        } else if (!workers.get(workerId).equals(partitionToServerMap[partitionNum])) {
+          return null;
+        }
+      }
+    }
+    // Build the workerId to segments map.
+    Map<Integer, Map<String, List<String>>> workerIdToSegmentsMap = new HashMap<>();
+    for (int workerId = 0; workerId < workers.size(); workerId++) {
+      List<String> segmentsForWorker = new ArrayList<>();
+      for (int partitionNum = workerId; partitionNum < numPartitions; partitionNum += workers.size()) {
+        segmentsForWorker.addAll(segmentsByPartition.get(partitionNum));
+      }
+      workers.set(workerId, String.format("%s@%s", workerId, workers.get(workerId)));
+      workerIdToSegmentsMap.put(workerId, ImmutableMap.of(tableType, segmentsForWorker));
+    }
+    HashDistributionDesc desc = new HashDistributionDesc(ImmutableList.of(keyIndex), function, numPartitions);
+    PinotDataDistribution dataDistribution = new PinotDataDistribution(RelDistribution.Type.HASH_DISTRIBUTED,
+        workers, workers.hashCode(), ImmutableSet.of(desc), null);
+    return new TableScanWorkerAssignmentResult(dataDistribution, workerIdToSegmentsMap);
+  }
+
+  private Map<String, TablePartitionInfo> calculateTablePartitionInfo(String tableName, Set<String> tableTypes) {
+    Map<String, TablePartitionInfo> result = new HashMap<>();
+    if (tableTypes.contains("OFFLINE")) {
+      String offlineTableName = TableNameBuilder.OFFLINE.tableNameWithType(tableName);
+      TablePartitionInfo tablePartitionInfo = _routingManager.getTablePartitionInfo(offlineTableName);
+      if (tablePartitionInfo != null) {
+        result.put("OFFLINE", tablePartitionInfo);
+      }
+    }
+    if (tableTypes.contains("REALTIME")) {
+      String realtimeTableName = TableNameBuilder.REALTIME.tableNameWithType(tableName);
+      TablePartitionInfo tablePartitionInfo = _routingManager.getTablePartitionInfo(realtimeTableName);
+      if (tablePartitionInfo != null) {
+        result.put("REALTIME", tablePartitionInfo);
+      }
+    }
+    return result;
+  }
+
+  /**
+   * Acquire routing table for items listed in TableScanNode.
+   *
+   * @param tableName table name with or without type suffix.
+   * @return keyed-map from table type(s) to routing table(s).
+   */
+  private Map<String, RoutingTable> getRoutingTable(String tableName, long requestId) {
+    String rawTableName = TableNameBuilder.extractRawTableName(tableName);
+    TableType tableType = TableNameBuilder.getTableTypeFromTableName(tableName);
+    Map<String, RoutingTable> routingTableMap = new HashMap<>();
+    RoutingTable routingTable;
+    if (tableType == null) {
+      routingTable = getRoutingTable(rawTableName, TableType.OFFLINE, requestId);
+      if (routingTable != null) {
+        routingTableMap.put(TableType.OFFLINE.name(), routingTable);
+      }
+      routingTable = getRoutingTable(rawTableName, TableType.REALTIME, requestId);
+      if (routingTable != null) {
+        routingTableMap.put(TableType.REALTIME.name(), routingTable);
+      }
+    } else {
+      routingTable = getRoutingTable(tableName, tableType, requestId);
+      if (routingTable != null) {
+        routingTableMap.put(tableType.name(), routingTable);
+      }
+    }
+    return routingTableMap;
+  }
+
+  private RoutingTable getRoutingTable(String tableName, TableType tableType, long requestId) {
+    String tableNameWithType =
+        TableNameBuilder.forType(tableType).tableNameWithType(TableNameBuilder.extractRawTableName(tableName));
+    return _routingManager.getRoutingTable(
+        CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM \"" + tableNameWithType + "\""), requestId);
+  }
+
+  private Map<String, String> getTableOptions(List<RelHint> hints) {
+    Map<String, String> tableOptions = PlanNode.NodeHint.fromRelHints(hints).getHintOptions().get(
+        PinotHintOptions.TABLE_HINT_OPTIONS);
+    return tableOptions == null ? Map.of() : tableOptions;
+  }
+
+  private String getActualTableName(TableScan tableScan) {
+    RelOptTable table = tableScan.getTable();
+    List<String> qualifiedName = table.getQualifiedName();
+    String tableName = qualifiedName.size() == 1 ? qualifiedName.get(0)
+        : DatabaseUtils.constructFullyQualifiedTableName(qualifiedName.get(0), qualifiedName.get(1));
+    return _tableCache.getActualTableName(tableName);
+  }
+
+  static class TableScanWorkerAssignmentResult {
+    final PinotDataDistribution _pinotDataDistribution;
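+    // Maps workerId -> tableType (OFFLINE/REALTIME) -> segments assigned to that worker.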
+    final Map<Integer, Map<String, List<String>>> _workerIdToSegmentsMap;
+
+    TableScanWorkerAssignmentResult(PinotDataDistribution pinotDataDistribution,
+        Map<Integer, Map<String, List<String>>> workerIdToSegmentsMap) {
+      _pinotDataDistribution = pinotDataDistribution;
+      _workerIdToSegmentsMap = workerIdToSegmentsMap;
+    }
+  }
+
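+  /**
+   * Holds the instanceId to selected-segments map, tracked separately for the OFFLINE and REALTIME parts of the
+   * scanned table.
+   */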
+  static class InstanceIdToSegments {
+    @Nullable
+    Map<String, List<String>> _offlineTableSegmentsMap;
+    @Nullable
+    Map<String, List<String>> _realtimeTableSegmentsMap;
+
+    @Nullable
+    Map<String, List<String>> getSegmentsMap(TableType tableType) {
+      return tableType == TableType.OFFLINE ? _offlineTableSegmentsMap : _realtimeTableSegmentsMap;
+    }
+
+    Set<String> getActiveTableTypes() {
+      Set<String> tableTypes = new HashSet<>();
+      if (MapUtils.isNotEmpty(_offlineTableSegmentsMap)) {
+        tableTypes.add(TableType.OFFLINE.name());
+      }
+      if (MapUtils.isNotEmpty(_realtimeTableSegmentsMap)) {
+        tableTypes.add(TableType.REALTIME.name());
+      }
+      return tableTypes;
+    }
+  }
+}
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java
new file mode 100644
index 0000000000..bd537c4be0
--- /dev/null
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java
@@ -0,0 +1,115 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.mapping;
+
+import java.util.List;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertThrows;
+
+
+public class PinotDistMappingTest {
+  @Test
+  public void testIdentityMapping() {
+    // Test identity mapping returns the expected mapping.
+    PinotDistMapping mapping = PinotDistMapping.identity(10);
+    assertEquals(mapping.getSourceCount(), 10);
+    for (int i = 0; i < 10; i++) {
+      assertEquals(mapping.getTarget(i), i);
+    }
+    // Test getMappedKeys always returns the same values as the input.
+    List<List<Integer>> testKeys = List.of(
+        List.of(0),
+        List.of(9),
+        List.of(1, 3, 5),
+        List.of(0, 2, 4),
+        List.of(4, 2, 9)
+    );
+    for (List<Integer> keys : testKeys) {
+      assertEquals(mapping.getMappedKeys(keys), keys);
+    }
+  }
+
+  @Test
+  public void testOutOfBoundsSource() {
+    // When the passed source index is out of bounds wrt the sourceCount in the mapping, we should get an exception.
+    PinotDistMapping mapping = new PinotDistMapping(5);
+    assertThrows(IllegalArgumentException.class, () -> mapping.getTarget(-1));
+    assertThrows(IllegalArgumentException.class, () -> mapping.getTarget(5));
+    assertThrows(IllegalArgumentException.class, () -> mapping.set(-1, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.set(5, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.getMappedKeys(List.of(5)));
+  }
+
+  @Test
+  public void testSet() {
+    // Test setting a mapping value.
+    PinotDistMapping mapping = new PinotDistMapping(5);
+    mapping.set(0, 2);
+    assertEquals(mapping.getTarget(0), 2);
+    assertEquals(mapping.getTarget(1), -1);
+    assertEquals(mapping.getTarget(2), -1);
+    assertEquals(mapping.getTarget(3), -1);
+    assertEquals(mapping.getTarget(4), -1);
+
+    // Test setting multiple mapping values.
+    mapping.set(1, 3);
+    mapping.set(2, 4);
+    assertEquals(mapping.getTarget(0), 2);
+    assertEquals(mapping.getTarget(1), 3);
+    assertEquals(mapping.getTarget(2), 4);
+    assertEquals(mapping.getTarget(3), -1);
+    assertEquals(mapping.getTarget(4), -1);
+
+    // Test setting a mapping value to an invalid index.
+    assertThrows(IllegalArgumentException.class, () -> mapping.set(-1, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.set(5, 2));
+  }
+
+  @Test
+  public void testGetMappedKeys() {
+    {
+      // Test when all passed keys are mapped.
+      PinotDistMapping mapping = new PinotDistMapping(5);
+      mapping.set(0, 2);
+      mapping.set(1, 3);
+      mapping.set(2, 4);
+      List<Integer> keys = List.of(0, 1, 2);
+      List<Integer> expectedMappedKeys = List.of(2, 3, 4);
+      assertEquals(mapping.getMappedKeys(keys), expectedMappedKeys);
+    }
+    {
+      // Test when one of the keys is not mapped
+      PinotDistMapping mapping = new PinotDistMapping(5);
+      mapping.set(0, 2);
+      mapping.set(1, 3);
+      List<Integer> keys = List.of(0, 1, 2);
+      List<Integer> expectedMappedKeys = List.of();
+      assertEquals(mapping.getMappedKeys(keys), expectedMappedKeys);
+    }
+    {
+      // Test getting mapped keys with an invalid key.
+      PinotDistMapping mapping = new PinotDistMapping(5);
+      mapping.set(0, 2);
+      mapping.set(1, 3);
+      List<Integer> keys = List.of(0, 1, 5);
+      assertThrows(IllegalArgumentException.class, () -> mapping.getMappedKeys(keys));
+    }
+  }
+}
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRuleTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRuleTest.java
new file mode 100644
index 0000000000..eed658a96b
--- /dev/null
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRuleTest.java
@@ -0,0 +1,207 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical.v2.opt.rules;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelDistribution;
+import org.apache.pinot.core.routing.TablePartitionInfo;
+import org.apache.pinot.query.planner.physical.v2.HashDistributionDesc;
+import org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageWorkerAssignmentRule.InstanceIdToSegments;
+import org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageWorkerAssignmentRule.TableScanWorkerAssignmentResult;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.utils.builder.TableNameBuilder;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+public class LeafStageWorkerAssignmentRuleTest {
+  private static final String TABLE_NAME = "testTable";
+  private static final List<String> FIELDS_IN_SCAN = List.of("userId", "orderId", "orderAmount", "cityId", "cityName");
+  private static final String PARTITION_COLUMN = "userId";
+  private static final String PARTITION_FUNCTION = "murmur";
+  private static final int NUM_SERVERS = 4;
+  private static final int OFFLINE_NUM_PARTITIONS = 4;
+  private static final int REALTIME_NUM_PARTITIONS = 8;
+  private static final InstanceIdToSegments OFFLINE_INSTANCE_ID_TO_SEGMENTS;
+  private static final InstanceIdToSegments REALTIME_INSTANCE_ID_TO_SEGMENTS;
+  private static final InstanceIdToSegments HYBRID_INSTANCE_ID_TO_SEGMENTS;
+
+  static {
+    Map<String, List<String>> offlineSegmentsMap = createOfflineSegmentsMap();
+    Map<String, List<String>> realtimeSegmentsMap = createRealtimeSegmentsMap();
+    OFFLINE_INSTANCE_ID_TO_SEGMENTS = new InstanceIdToSegments();
+    OFFLINE_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap = offlineSegmentsMap;
+    REALTIME_INSTANCE_ID_TO_SEGMENTS = new InstanceIdToSegments();
+    REALTIME_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap = realtimeSegmentsMap;
+    HYBRID_INSTANCE_ID_TO_SEGMENTS = new InstanceIdToSegments();
+    HYBRID_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap = offlineSegmentsMap;
+    HYBRID_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap = realtimeSegmentsMap;
+  }
+
+  @Test
+  public void testAssignTableScanWithUnPartitionedOfflineTable() {
+    TableScanWorkerAssignmentResult result = LeafStageWorkerAssignmentRule.assignTableScan(TABLE_NAME, FIELDS_IN_SCAN,
+        OFFLINE_INSTANCE_ID_TO_SEGMENTS, Map.of());
+    assertEquals(result._pinotDataDistribution.getType(), RelDistribution.Type.RANDOM_DISTRIBUTED);
+    assertEquals(result._pinotDataDistribution.getWorkers().size(), 4);
+    assertEquals(result._pinotDataDistribution.getCollation(), RelCollations.EMPTY);
+    assertEquals(result._pinotDataDistribution.getHashDistributionDesc().size(), 0);
+    validateTableScanAssignment(result, OFFLINE_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap, "OFFLINE");
+  }
+
+  @Test
+  public void testAssignTableScanWithUnPartitionedRealtimeTable() {
+    TableScanWorkerAssignmentResult result = LeafStageWorkerAssignmentRule.assignTableScan(TABLE_NAME, FIELDS_IN_SCAN,
+        REALTIME_INSTANCE_ID_TO_SEGMENTS, Map.of());
+    assertEquals(result._pinotDataDistribution.getType(), RelDistribution.Type.RANDOM_DISTRIBUTED);
+    assertEquals(result._pinotDataDistribution.getWorkers().size(), 4);
+    assertEquals(result._pinotDataDistribution.getCollation(), RelCollations.EMPTY);
+    assertEquals(result._pinotDataDistribution.getHashDistributionDesc().size(), 0);
+    validateTableScanAssignment(result, REALTIME_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap, "REALTIME");
+  }
+
+  @Test
+  public void testAssignTableScanWithUnPartitionedHybridTable() {
+    TableScanWorkerAssignmentResult result = LeafStageWorkerAssignmentRule.assignTableScan(TABLE_NAME, FIELDS_IN_SCAN,
+        HYBRID_INSTANCE_ID_TO_SEGMENTS, Map.of());
+    assertEquals(result._pinotDataDistribution.getType(), RelDistribution.Type.RANDOM_DISTRIBUTED);
+    assertEquals(result._pinotDataDistribution.getWorkers().size(), 4);
+    assertEquals(result._pinotDataDistribution.getCollation(), RelCollations.EMPTY);
+    assertEquals(result._pinotDataDistribution.getHashDistributionDesc().size(), 0);
+    validateTableScanAssignment(result, HYBRID_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap, "OFFLINE");
+    validateTableScanAssignment(result, HYBRID_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap, "REALTIME");
+  }
+
+  @Test
+  public void testAssignTableScanPartitionedOfflineTable() {
+    TableScanWorkerAssignmentResult result = LeafStageWorkerAssignmentRule.assignTableScan(TABLE_NAME, FIELDS_IN_SCAN,
+        OFFLINE_INSTANCE_ID_TO_SEGMENTS, Map.of("OFFLINE", createOfflineTablePartitionInfo()));
+    // Basic checks.
+    assertEquals(result._pinotDataDistribution.getType(), RelDistribution.Type.HASH_DISTRIBUTED);
+    assertEquals(result._pinotDataDistribution.getWorkers().size(), 4);
+    assertEquals(result._pinotDataDistribution.getCollation(), RelCollations.EMPTY);
+    assertEquals(result._pinotDataDistribution.getHashDistributionDesc().size(), 1);
+    HashDistributionDesc desc = result._pinotDataDistribution.getHashDistributionDesc().iterator().next();
+    assertEquals(desc.getNumPartitions(), OFFLINE_NUM_PARTITIONS);
+    assertEquals(desc.getKeys(), List.of(FIELDS_IN_SCAN.indexOf(PARTITION_COLUMN)));
+    assertEquals(desc.getHashFunction(), PARTITION_FUNCTION);
+    validateTableScanAssignment(result, OFFLINE_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap, "OFFLINE");
+  }
+
+  @Test
+  public void testAssignTableScanPartitionedRealtimeTable() {
+    TableScanWorkerAssignmentResult result = LeafStageWorkerAssignmentRule.assignTableScan(TABLE_NAME, FIELDS_IN_SCAN,
+        REALTIME_INSTANCE_ID_TO_SEGMENTS, Map.of("REALTIME", createRealtimeTablePartitionInfo()));
+    // Basic checks.
+    assertEquals(result._pinotDataDistribution.getType(), RelDistribution.Type.HASH_DISTRIBUTED);
+    assertEquals(result._pinotDataDistribution.getWorkers().size(), 4);
+    assertEquals(result._pinotDataDistribution.getCollation(), RelCollations.EMPTY);
+    assertEquals(result._pinotDataDistribution.getHashDistributionDesc().size(), 1);
+    HashDistributionDesc desc = result._pinotDataDistribution.getHashDistributionDesc().iterator().next();
+    assertEquals(desc.getNumPartitions(), REALTIME_NUM_PARTITIONS);
+    assertEquals(desc.getKeys(), List.of(FIELDS_IN_SCAN.indexOf(PARTITION_COLUMN)));
+    assertEquals(desc.getHashFunction(), PARTITION_FUNCTION);
+    validateTableScanAssignment(result, REALTIME_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap, "REALTIME");
+  }
+
+  @Test
+  public void testAssignTableScanPartitionedHybridTable() {
+    TableScanWorkerAssignmentResult result = LeafStageWorkerAssignmentRule.assignTableScan(TABLE_NAME, FIELDS_IN_SCAN,
+        HYBRID_INSTANCE_ID_TO_SEGMENTS, Map.of("OFFLINE", createOfflineTablePartitionInfo(),
+            "REALTIME", createRealtimeTablePartitionInfo()));
+    assertEquals(result._pinotDataDistribution.getType(), RelDistribution.Type.RANDOM_DISTRIBUTED);
+    assertEquals(result._pinotDataDistribution.getWorkers().size(), 4);
+    assertEquals(result._pinotDataDistribution.getCollation(), RelCollations.EMPTY);
+    assertEquals(result._pinotDataDistribution.getHashDistributionDesc().size(), 0);
+    validateTableScanAssignment(result, HYBRID_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap, "OFFLINE");
+    validateTableScanAssignment(result, HYBRID_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap, "REALTIME");
+  }
+
+  private static void validateTableScanAssignment(TableScanWorkerAssignmentResult assignmentResult,
+      Map<String, List<String>> instanceIdToSegmentsMap, String tableType) {
+    Map<String, List<String>> actualInstanceIdToSegments = new HashMap<>();
+    for (var entry : assignmentResult._workerIdToSegmentsMap.entrySet()) {
+      int workerId = entry.getKey();
+      String fullWorkerId = assignmentResult._pinotDataDistribution.getWorkers().get(workerId);
+      String instanceId = fullWorkerId.split("@")[1];
+      actualInstanceIdToSegments.put(instanceId, entry.getValue().get(tableType));
+      assertEquals(Integer.parseInt(fullWorkerId.split("@")[0]), workerId);
+    }
+    assertEquals(actualInstanceIdToSegments, instanceIdToSegmentsMap);
+  }
+
+  private static Map<String, List<String>> createOfflineSegmentsMap() {
+    // assume 4 servers and 4 partitions.
+    Map<String, List<String>> result = new HashMap<>();
+    result.put("instance-0", List.of("segment1-0", "segment2-0", 
"segment3-0"));
+    result.put("instance-1", List.of("segment1-1", "segment2-1"));
+    result.put("instance-2", List.of("segment1-2"));
+    result.put("instance-3", List.of("segment1-3", "segment2-3", 
"segment3-3"));
+    return result;
+  }
+
+  private static Map<String, List<String>> createRealtimeSegmentsMap() {
+    // assume 4 servers and 8 partitions. assume partition-5 is missing.
+    Map<String, List<String>> result = new HashMap<>();
+    result.put("instance-0", List.of("segment1-0", "segment1-4", 
"segment2-4"));
+    result.put("instance-1", List.of("segment1-1", "segment2-1"));
+    result.put("instance-2", List.of("segment1-2", "segment1-6"));
+    result.put("instance-3", List.of("segment1-3", "segment2-3", "segment1-7", 
"segment2-7"));
+    return result;
+  }
+
+  private static TablePartitionInfo createOfflineTablePartitionInfo() {
+    TablePartitionInfo.PartitionInfo[] infos = new TablePartitionInfo.PartitionInfo[OFFLINE_NUM_PARTITIONS];
+    for (int partitionNum = 0; partitionNum < OFFLINE_NUM_PARTITIONS; partitionNum++) {
+      String selectedInstance = String.format("instance-%s", partitionNum % NUM_SERVERS);
+      String additionalInstance = String.format("instance-%s", NUM_SERVERS + partitionNum);
+      final String segmentSuffixForPartition = String.format("-%d", partitionNum);
+      List<String> segments = Objects.requireNonNull(OFFLINE_INSTANCE_ID_TO_SEGMENTS._offlineTableSegmentsMap)
+          .get(selectedInstance).stream().filter(segment -> segment.endsWith(segmentSuffixForPartition))
+          .collect(Collectors.toList());
+      infos[partitionNum] = new TablePartitionInfo.PartitionInfo(Set.of(selectedInstance, additionalInstance),
+          segments);
+    }
+    return new TablePartitionInfo(TableNameBuilder.forType(TableType.OFFLINE).tableNameWithType(TABLE_NAME),
+        PARTITION_COLUMN, PARTITION_FUNCTION, OFFLINE_NUM_PARTITIONS, infos, List.of());
+  }
+
+  private static TablePartitionInfo createRealtimeTablePartitionInfo() {
+    TablePartitionInfo.PartitionInfo[] infos = new TablePartitionInfo.PartitionInfo[REALTIME_NUM_PARTITIONS];
+    for (int partitionNum = 0; partitionNum < REALTIME_NUM_PARTITIONS; partitionNum++) {
+      String selectedInstance = String.format("instance-%s", partitionNum % NUM_SERVERS);
+      String additionalInstance = String.format("instance-%s", NUM_SERVERS + (partitionNum % NUM_SERVERS));
+      final String segmentSuffixForPartition = String.format("-%d", partitionNum);
+      List<String> segments = Objects.requireNonNull(REALTIME_INSTANCE_ID_TO_SEGMENTS._realtimeTableSegmentsMap)
+          .get(selectedInstance).stream().filter(segment -> segment.endsWith(segmentSuffixForPartition))
+          .collect(Collectors.toList());
+      infos[partitionNum] = new TablePartitionInfo.PartitionInfo(Set.of(selectedInstance, additionalInstance),
+          segments);
+    }
+    return new TablePartitionInfo(TableNameBuilder.forType(TableType.REALTIME).tableNameWithType(TABLE_NAME),
+        PARTITION_COLUMN, PARTITION_FUNCTION, REALTIME_NUM_PARTITIONS, infos, List.of());
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org
