This is an automated email from the ASF dual-hosted git repository.

ankitsultana pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 97aa5cacce [multistage] Bug Fixes and Refactorings Based on E2E Physical Optimizer Test (#15561)
97aa5cacce is described below

commit 97aa5cacce3cad05b4e3b917f4c1386942d6784c
Author: Ankit Sultana <ankitsult...@uber.com>
AuthorDate: Thu Apr 17 23:22:45 2025 -0500

    [multistage] Bug Fixes and Refactorings Based on E2E Physical Optimizer Test (#15561)
---
 .../pinot/calcite/rel/traits/TraitAssignment.java  | 21 ++++--
 .../planner/physical/v2/ExchangeStrategy.java      |  6 +-
 .../planner/physical/v2/HashDistributionDesc.java  | 22 +++---
 .../pinot/query/planner/physical/v2/PRelNode.java  |  2 +-
 .../planner/physical/v2/PinotDataDistribution.java |  6 +-
 .../planner/physical/v2/RelToPRelConverter.java    | 11 +--
 .../physical/v2/mapping/DistMappingGenerator.java  |  6 +-
 .../physical/v2/mapping/PinotDistMapping.java      | 74 ++++++++++++++------
 .../physical/v2/nodes/PhysicalAggregate.java       | 16 +++++
 .../physical/v2/nodes/PhysicalExchange.java        | 77 +++++++++++++++++---
 .../query/planner/physical/v2/opt/PRelOptRule.java |  6 ++
 .../physical/v2/opt/PhysicalOptRuleSet.java        | 25 ++++---
 .../v2/opt/rules/LeafStageAggregateRule.java       |  4 +-
 .../opt/rules/LeafStageWorkerAssignmentRule.java   | 19 +++--
 .../physical/v2/mapping/PinotDistMappingTest.java  | 81 ++++++++++++++--------
 15 files changed, 266 insertions(+), 110 deletions(-)

diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
index f2ba392129..1ebaf8186f 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/traits/TraitAssignment.java
@@ -34,7 +34,6 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.JoinInfo;
 import org.apache.calcite.rel.core.Window;
 import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
-import org.apache.pinot.calcite.rel.rules.PinotRuleUtils;
 import org.apache.pinot.query.context.PhysicalPlannerContext;
 import org.apache.pinot.query.planner.physical.v2.PRelNode;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAggregate;
@@ -113,16 +112,23 @@ public class TraitAssignment {
     }
     // Case-2: Handle dynamic filter for semi joins.
     JoinInfo joinInfo = join.analyzeCondition();
-    if (join.isSemiJoin() && joinInfo.nonEquiConditions.isEmpty() && 
joinInfo.leftKeys.size() == 1) {
+    /* if (join.isSemiJoin() && joinInfo.nonEquiConditions.isEmpty() && 
joinInfo.leftKeys.size() == 1) {
       if (PinotRuleUtils.canPushDynamicBroadcastToLeaf(join.getLeft())) {
         return assignDynamicFilterSemiJoin(join);
       }
-    }
+    } */
+    Preconditions.checkState(joinInfo.leftKeys.size() == 
joinInfo.rightKeys.size(),
+        "Always expect left and right keys to be same size. Found: %s and %s",
+        joinInfo.leftKeys, joinInfo.rightKeys);
     // Case-3: Default case.
-    RelDistribution leftDistribution = joinInfo.leftKeys.isEmpty() ? 
RelDistributions.RANDOM_DISTRIBUTED
-        : RelDistributions.hash(joinInfo.leftKeys);
-    RelDistribution rightDistribution = joinInfo.rightKeys.isEmpty() ? 
RelDistributions.BROADCAST_DISTRIBUTED
-        : RelDistributions.hash(joinInfo.rightKeys);
+    RelDistribution rightDistribution = joinInfo.isEqui() && 
!joinInfo.rightKeys.isEmpty()
+        ? RelDistributions.hash(joinInfo.rightKeys) : 
RelDistributions.BROADCAST_DISTRIBUTED;
+    RelDistribution leftDistribution;
+    if (joinInfo.leftKeys.isEmpty() || rightDistribution == 
RelDistributions.BROADCAST_DISTRIBUTED) {
+      leftDistribution = RelDistributions.RANDOM_DISTRIBUTED;
+    } else {
+      leftDistribution = RelDistributions.hash(joinInfo.leftKeys);
+    }
     // left-input
     RelNode leftInput = join.getInput(0);
     RelTraitSet leftTraitSet = leftInput.getTraitSet().plus(leftDistribution);
@@ -233,6 +239,7 @@ public class TraitAssignment {
     return join.copy(join.getTraitSet(), ImmutableList.of(leftInput, 
newProject));
   }
 
+  @SuppressWarnings("unused")
   private RelNode assignDynamicFilterSemiJoin(PhysicalJoin join) {
     /*
      * When dynamic broadcast is enabled, push broadcast trait to right input 
along with the pipeline breaker
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/ExchangeStrategy.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/ExchangeStrategy.java
index dc0e97fd8b..82441f47aa 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/ExchangeStrategy.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/ExchangeStrategy.java
@@ -59,5 +59,9 @@ public enum ExchangeStrategy {
   /**
    * Each stream will send data to all receiving streams.
    */
-  BROADCAST_EXCHANGE
+  BROADCAST_EXCHANGE,
+  /**
+   * Records are sent randomly from a given worker in the sender to some 
worker in the receiver.
+   */
+  RANDOM_EXCHANGE
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
index 7c2284cce5..dc373b4a88 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/HashDistributionDesc.java
@@ -18,10 +18,13 @@
  */
 package org.apache.pinot.query.planner.physical.v2;
 
+import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
-import javax.annotation.Nullable;
+import java.util.Set;
+import org.apache.commons.collections4.CollectionUtils;
 import org.apache.pinot.query.planner.physical.v2.mapping.PinotDistMapping;
 
 
@@ -62,18 +65,19 @@ public class HashDistributionDesc {
    * Returns the hash distribution descriptor for the given target mapping, or 
{@code null} if we can't preserve
    * partitioning info.
    */
-  @Nullable
-  public HashDistributionDesc apply(PinotDistMapping mapping) {
+  public Set<HashDistributionDesc> apply(PinotDistMapping mapping) {
     for (Integer currentKey : _keys) {
-      if (currentKey >= mapping.getSourceCount() || 
mapping.getTarget(currentKey) == -1) {
-        return null;
+      if (currentKey >= mapping.getSourceCount() || 
CollectionUtils.isEmpty(mapping.getTargets(currentKey))) {
+        return Set.of();
       }
     }
-    List<Integer> newKey = new ArrayList<>();
-    for (int currentKey : _keys) {
-      newKey.add(mapping.getTarget(currentKey));
+    List<List<Integer>> newKeys = new ArrayList<>();
+    PinotDistMapping.computeAllMappings(0, _keys, mapping, new ArrayDeque<>(), 
newKeys);
+    Set<HashDistributionDesc> newDescs = new HashSet<>();
+    for (List<Integer> newKey : newKeys) {
+      newDescs.add(new HashDistributionDesc(newKey, _hashFunction, 
_numPartitions));
     }
-    return new HashDistributionDesc(newKey, _hashFunction, _numPartitions);
+    return newDescs;
   }
 
   @Override
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
index 177870cbd7..12773ec055 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelNode.java
@@ -105,7 +105,7 @@ public interface PRelNode {
   }
 
   default PRelNode with(List<PRelNode> newInputs) {
-    return with(getNodeId(), newInputs, getPinotDataDistributionOrThrow());
+    return with(getNodeId(), newInputs, getPinotDataDistribution());
   }
 
   default PRelNode asLeafStage() {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
index 36002cb32f..cd1cd99a98 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PinotDataDistribution.java
@@ -163,10 +163,8 @@ public class PinotDataDistribution {
     }
     Set<HashDistributionDesc> newHashDesc = new HashSet<>();
     for (HashDistributionDesc desc : _hashDistributionDesc) {
-      HashDistributionDesc newDescs = desc.apply(mapping);
-      if (newDescs != null) {
-        newHashDesc.add(newDescs);
-      }
+      Set<HashDistributionDesc> newDescs = desc.apply(mapping);
+      newHashDesc.addAll(newDescs);
     }
     RelDistribution.Type newType = _type;
     if (newType == RelDistribution.Type.HASH_DISTRIBUTED && 
newHashDesc.isEmpty()) {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java
index 6b80fd4d38..67a24c75fe 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/RelToPRelConverter.java
@@ -45,10 +45,7 @@ import 
org.apache.pinot.query.planner.physical.v2.nodes.PhysicalTableScan;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalUnion;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalValues;
 import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalWindow;
-import org.apache.pinot.query.planner.physical.v2.opt.PRelOptRule;
 import org.apache.pinot.query.planner.physical.v2.opt.PhysicalOptRuleSet;
-import org.apache.pinot.query.planner.physical.v2.opt.RuleExecutor;
-import org.apache.pinot.query.planner.physical.v2.opt.RuleExecutors;
 import org.apache.pinot.query.planner.plannode.AggregateNode;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -69,11 +66,9 @@ public class RelToPRelConverter {
     // Step-2: Assign traits
     rootPRelNode = TraitAssignment.assign(rootPRelNode, context);
     // Step-3: Run physical optimizer rules.
-    var ruleAndExecutorList = PhysicalOptRuleSet.create(context, tableCache);
-    for (var ruleAndExecutor : ruleAndExecutorList) {
-      PRelOptRule rule = ruleAndExecutor.getLeft();
-      RuleExecutor executor = RuleExecutors.create(ruleAndExecutor.getRight(), 
rule, context);
-      rootPRelNode = executor.execute(rootPRelNode);
+    var pRelTransformers = PhysicalOptRuleSet.create(context, tableCache);
+    for (var pRelTransformer : pRelTransformers) {
+      rootPRelNode = pRelTransformer.execute(rootPRelNode);
     }
     return rootPRelNode;
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java
index 39413a0268..6e1195671f 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/DistMappingGenerator.java
@@ -59,7 +59,7 @@ public class DistMappingGenerator {
       PinotDistMapping mapping = new 
PinotDistMapping(source.getRowType().getFieldCount());
       List<Integer> groupSet = aggregate.getGroupSet().asList();
       for (int j = 0; j < groupSet.size(); j++) {
-        mapping.set(groupSet.get(j), j);
+        mapping.add(groupSet.get(j), j);
       }
       return mapping;
     } else if (destination instanceof Join) {
@@ -72,7 +72,7 @@ public class DistMappingGenerator {
       }
       PinotDistMapping mapping = new 
PinotDistMapping(source.getRowType().getFieldCount());
       for (int i = 0; i < mapping.getSourceCount(); i++) {
-        mapping.set(i, i + leftFieldCount);
+        mapping.add(i, i + leftFieldCount);
       }
       return mapping;
     } else if (destination instanceof Filter) {
@@ -100,7 +100,7 @@ public class DistMappingGenerator {
     for (RexNode rexNode : project.getProjects()) {
       if (rexNode instanceof RexInputRef) {
         RexInputRef rexInputRef = (RexInputRef) rexNode;
-        mapping.set(rexInputRef.getIndex(), indexInCurrentRelNode);
+        mapping.add(rexInputRef.getIndex(), indexInCurrentRelNode);
       }
       indexInCurrentRelNode++;
     }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java
index 1a083e0314..847795b5ad 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMapping.java
@@ -19,11 +19,16 @@
 package org.apache.pinot.query.planner.physical.v2.mapping;
 
 import com.google.common.base.Preconditions;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Collections;
+import java.util.Deque;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelFieldCollation;
+import org.apache.commons.collections4.CollectionUtils;
 
 
 /**
@@ -31,21 +36,20 @@ import java.util.Map;
  * RelNode pair and is used to track how input fields are mapped to output 
fields.
  */
 public class PinotDistMapping {
-  private static final int DEFAULT_MAPPING_VALUE = -1;
   private final int _sourceCount;
-  private final Map<Integer, Integer> _sourceToTargetMapping = new HashMap<>();
+  private final Map<Integer, List<Integer>> _sourceToTargetMapping = new 
HashMap<>();
 
   public PinotDistMapping(int sourceCount) {
     _sourceCount = sourceCount;
     for (int i = 0; i < sourceCount; i++) {
-      _sourceToTargetMapping.put(i, DEFAULT_MAPPING_VALUE);
+      _sourceToTargetMapping.put(i, new ArrayList<>());
     }
   }
 
   public static PinotDistMapping identity(int sourceCount) {
     PinotDistMapping mapping = new PinotDistMapping(sourceCount);
     for (int i = 0; i < sourceCount; i++) {
-      mapping.set(i, i);
+      mapping.add(i, i);
     }
     return mapping;
   }
@@ -54,29 +58,57 @@ public class PinotDistMapping {
     return _sourceCount;
   }
 
-  public int getTarget(int source) {
+  public List<Integer> getTargets(int source) {
     Preconditions.checkArgument(source >= 0 && source < _sourceCount, "Invalid 
source index: %s", source);
-    Integer target = _sourceToTargetMapping.get(source);
-    return target == null ? DEFAULT_MAPPING_VALUE : target;
+    List<Integer> target = _sourceToTargetMapping.get(source);
+    return target == null ? List.of() : target;
   }
 
-  public void set(int source, int target) {
+  public void add(int source, int target) {
     Preconditions.checkArgument(source >= 0 && source < _sourceCount, "Invalid 
source index: %s", source);
-    _sourceToTargetMapping.put(source, target);
+    _sourceToTargetMapping.computeIfAbsent(source, (x) -> new 
ArrayList<>()).add(target);
   }
 
-  public List<Integer> getMappedKeys(List<Integer> existingKeys) {
-    List<Integer> result = new ArrayList<>(existingKeys.size());
-    for (int key : existingKeys) {
-      Integer mappedKey = _sourceToTargetMapping.get(key);
-      Preconditions.checkArgument(mappedKey != null,
-          "Key %s not found in mapping with source count: %s", key, 
_sourceCount);
-      if (mappedKey != DEFAULT_MAPPING_VALUE) {
-        result.add(mappedKey);
-      } else {
-        return Collections.emptyList();
+  public List<List<Integer>> getMappedKeys(List<Integer> existingKeys) {
+    List<List<Integer>> result = new ArrayList<>();
+    computeAllMappings(0, existingKeys, this, new ArrayDeque<>(), result);
+    return result;
+  }
+
+  public static RelCollation apply(RelCollation relCollation, PinotDistMapping 
mapping) {
+    if (relCollation.getKeys().isEmpty()) {
+      return relCollation;
+    }
+    List<RelFieldCollation> newFieldCollations = new ArrayList<>();
+    for (RelFieldCollation fieldCollation : relCollation.getFieldCollations()) 
{
+      List<Integer> newFieldIndices = 
mapping.getTargets(fieldCollation.getFieldIndex());
+      if (CollectionUtils.isEmpty(newFieldIndices)) {
+        break;
       }
+      
newFieldCollations.add(fieldCollation.withFieldIndex(newFieldIndices.get(0)));
+    }
+    return RelCollations.of(newFieldCollations);
+  }
+
+  /**
+   * Consider a node that is partitioned on the key: [1]. If there's a project 
node on top of this, with project
+   * expressions as: [RexInputRef#0, RexInputRef#1, RexInputRef#1], then the 
project node will have two hash
+   * distribution descriptors: [1] and [2]. This method computes all such 
mappings for the given key indexes.
+   * <p>
+   *   This is a common occurrence in Calcite plans.
+   * </p>
+   */
+  public static void computeAllMappings(int index, List<Integer> currentKey, 
PinotDistMapping mapping,
+      Deque<Integer> runningKey, List<List<Integer>> newKeysSink) {
+    if (index == currentKey.size()) {
+      newKeysSink.add(new ArrayList<>(runningKey));
+      return;
+    }
+    List<Integer> possibilities = mapping.getTargets(currentKey.get(index));
+    for (int currentKeyPossibility : possibilities) {
+      runningKey.addLast(currentKeyPossibility);
+      computeAllMappings(index + 1, currentKey, mapping, runningKey, 
newKeysSink);
+      runningKey.removeLast();
     }
-    return result;
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
index 4f7d026996..6c2353549f 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalAggregate.java
@@ -113,4 +113,20 @@ public class PhysicalAggregate extends Aggregate 
implements PRelNode {
         getAggCallList(), _nodeId, _pRelInputs.get(0), _pinotDataDistribution, 
true, _aggType, _leafReturnFinalResult,
         _collations, _limit);
   }
+
+  public AggregateNode.AggType getAggType() {
+    return _aggType;
+  }
+
+  public boolean isLeafReturnFinalResult() {
+    return _leafReturnFinalResult;
+  }
+
+  public List<RelFieldCollation> getCollations() {
+    return _collations;
+  }
+
+  public int getLimit() {
+    return _limit;
+  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalExchange.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalExchange.java
index 8571f97568..3128782a6c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalExchange.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/nodes/PhysicalExchange.java
@@ -22,13 +22,17 @@ import com.google.common.base.Preconditions;
 import java.util.Collections;
 import java.util.List;
 import javax.annotation.Nullable;
-import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.RelDistributions;
 import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelWriter;
 import org.apache.calcite.rel.core.Exchange;
+import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType;
+import org.apache.pinot.calcite.rel.traits.PinotExecStrategyTrait;
+import org.apache.pinot.calcite.rel.traits.PinotExecStrategyTraitDef;
 import org.apache.pinot.query.planner.physical.v2.ExchangeStrategy;
 import org.apache.pinot.query.planner.physical.v2.PRelNode;
 import org.apache.pinot.query.planner.physical.v2.PinotDataDistribution;
@@ -73,10 +77,11 @@ public class PhysicalExchange extends Exchange implements 
PRelNode {
    */
   private final RelCollation _relCollation;
 
-  public PhysicalExchange(RelOptCluster cluster, RelDistribution distribution,
-      int nodeId, PRelNode input, @Nullable PinotDataDistribution 
pinotDataDistribution,
-      List<Integer> distributionKeys, ExchangeStrategy exchangeStrategy, 
@Nullable RelCollation relCollation) {
-    super(cluster, EMPTY_TRAIT_SET, input.unwrap(), distribution);
+  public PhysicalExchange(int nodeId, PRelNode input, @Nullable 
PinotDataDistribution pinotDataDistribution,
+      List<Integer> distributionKeys, ExchangeStrategy exchangeStrategy, 
@Nullable RelCollation relCollation,
+      PinotExecStrategyTrait execStrategyTrait) {
+    super(input.unwrap().getCluster(), 
EMPTY_TRAIT_SET.plus(execStrategyTrait), input.unwrap(),
+        getRelDistribution(exchangeStrategy, distributionKeys));
     _nodeId = nodeId;
     _pRelInputs = Collections.singletonList(input);
     _pinotDataDistribution = pinotDataDistribution;
@@ -89,8 +94,8 @@ public class PhysicalExchange extends Exchange implements 
PRelNode {
   public Exchange copy(RelTraitSet traitSet, RelNode newInput, RelDistribution 
newDistribution) {
     Preconditions.checkState(newInput instanceof PRelNode, "Expected input of 
PhysicalExchange to be a PRelNode");
     Preconditions.checkState(traitSet.isEmpty(), "Expected empty trait set for 
PhysicalExchange");
-    return new PhysicalExchange(getCluster(), newDistribution, _nodeId, 
(PRelNode) newInput, _pinotDataDistribution,
-        _distributionKeys, _exchangeStrategy, _relCollation);
+    return new PhysicalExchange(_nodeId, (PRelNode) newInput, 
_pinotDataDistribution, _distributionKeys,
+        _exchangeStrategy, _relCollation, 
PinotExecStrategyTrait.getDefaultExecStrategy());
   }
 
   @Override
@@ -119,9 +124,63 @@ public class PhysicalExchange extends Exchange implements 
PRelNode {
     return false;
   }
 
+  public List<Integer> getDistributionKeys() {
+    return _distributionKeys;
+  }
+
+  public ExchangeStrategy getExchangeStrategy() {
+    return _exchangeStrategy;
+  }
+
+  public RelCollation getRelCollation() {
+    return _relCollation;
+  }
+
+  public PinotExecStrategyTrait getExecStrategy() {
+    PinotExecStrategyTrait trait = 
traitSet.getTrait(PinotExecStrategyTraitDef.INSTANCE);
+    if (trait == null) {
+      return PinotExecStrategyTrait.getDefaultExecStrategy();
+    }
+    return trait;
+  }
+
+  public PinotRelExchangeType getRelExchangeType() {
+    PinotExecStrategyTrait trait = 
traitSet.getTrait(PinotExecStrategyTraitDef.INSTANCE);
+    if (trait == null) {
+      return PinotExecStrategyTrait.getDefaultExecStrategy().getType();
+    }
+    return trait.getType();
+  }
+
   @Override
   public PRelNode with(int newNodeId, List<PRelNode> newInputs, 
PinotDataDistribution newDistribution) {
-    return new PhysicalExchange(getCluster(), getDistribution(), newNodeId, 
newInputs.get(0), newDistribution,
-        _distributionKeys, _exchangeStrategy, _relCollation);
+    return new PhysicalExchange(newNodeId, newInputs.get(0), newDistribution, 
_distributionKeys, _exchangeStrategy,
+        _relCollation, PinotExecStrategyTrait.getDefaultExecStrategy());
+  }
+
+  @Override public RelWriter explainTerms(RelWriter pw) {
+    return pw.item("input", input)
+        .item("exchangeStrategy", _exchangeStrategy)
+        .item("distKeys", _distributionKeys)
+        .item("execStrategy", getRelExchangeType())
+        .item("collation", _relCollation);
+  }
+
+  private static RelDistribution getRelDistribution(ExchangeStrategy 
exchangeStrategy, List<Integer> keys) {
+    switch (exchangeStrategy) {
+      case IDENTITY_EXCHANGE:
+      case PARTITIONING_EXCHANGE:
+      case SUB_PARTITIONING_HASH_EXCHANGE:
+      case COALESCING_PARTITIONING_EXCHANGE:
+        return RelDistributions.hash(keys);
+      case BROADCAST_EXCHANGE:
+        return RelDistributions.BROADCAST_DISTRIBUTED;
+      case SINGLETON_EXCHANGE:
+        return RelDistributions.SINGLETON;
+      case SUB_PARTITIONING_RR_EXCHANGE:
+        return RelDistributions.ROUND_ROBIN_DISTRIBUTED;
+      default:
+        throw new IllegalStateException(String.format("Unexpected exchange 
strategy: %s", exchangeStrategy));
+    }
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PRelOptRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PRelOptRule.java
index e6dcb4fbca..84504f6238 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PRelOptRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PRelOptRule.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.query.planner.physical.v2.opt;
 
+import javax.annotation.Nullable;
 import org.apache.pinot.query.planner.physical.v2.PRelNode;
 
 
@@ -43,4 +44,9 @@ public abstract class PRelOptRule {
   public PRelNode onDone(PRelNode currentNode) {
     return currentNode;
   }
+
+  @Nullable
+  public PRelNode getParentNode(PRelOptRuleCall call) {
+    return call._parents.isEmpty() ? null : call._parents.getLast();
+  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java
index 9db842fabf..76ba9964db 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/PhysicalOptRuleSet.java
@@ -18,10 +18,11 @@
  */
 package org.apache.pinot.query.planner.physical.v2.opt;
 
+import java.util.ArrayList;
 import java.util.List;
-import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.query.context.PhysicalPlannerContext;
+import 
org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageAggregateRule;
 import 
org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageBoundaryRule;
 import 
org.apache.pinot.query.planner.physical.v2.opt.rules.LeafStageWorkerAssignmentRule;
 
@@ -30,13 +31,19 @@ public class PhysicalOptRuleSet {
   private PhysicalOptRuleSet() {
   }
 
-  public static List<Pair<PRelOptRule, RuleExecutors.Type>> 
create(PhysicalPlannerContext context,
-      TableCache tableCache) {
-    return List.of(
-        Pair.of(LeafStageBoundaryRule.INSTANCE, RuleExecutors.Type.POST_ORDER),
-        Pair.of(new LeafStageWorkerAssignmentRule(context, tableCache), 
RuleExecutors.Type.POST_ORDER));
-        // Pair.of(new WorkerExchangeAssignmentRule(context), 
RuleExecutors.Type.IN_ORDER),
-        // Pair.of(AggregatePushdownRule.INSTANCE, 
RuleExecutors.Type.POST_ORDER),
-        // Pair.of(SortPushdownRule.INSTANCE, RuleExecutors.Type.POST_ORDER));
+  public static List<PRelNodeTransformer> create(PhysicalPlannerContext 
context, TableCache tableCache) {
+    List<PRelNodeTransformer> transformers = new ArrayList<>();
+    transformers.add(create(LeafStageBoundaryRule.INSTANCE, 
RuleExecutors.Type.POST_ORDER, context));
+    transformers.add(create(new LeafStageWorkerAssignmentRule(context, 
tableCache), RuleExecutors.Type.POST_ORDER,
+        context));
+    transformers.add(create(new LeafStageAggregateRule(context), 
RuleExecutors.Type.POST_ORDER, context));
+    // transformers.add(new WorkerExchangeAssignmentRule(context));
+    // transformers.add(create(new AggregatePushdownRule(context), 
RuleExecutors.Type.POST_ORDER, context));
+    // transformers.add(create(new SortPushdownRule(context), 
RuleExecutors.Type.POST_ORDER, context));
+    return transformers;
+  }
+
+  private static PRelNodeTransformer create(PRelOptRule rule, 
RuleExecutors.Type type, PhysicalPlannerContext context) {
+    return (pRelNode) -> RuleExecutors.create(type, rule, 
context).execute(pRelNode);
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java
index e8121b063b..7506541503 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageAggregateRule.java
@@ -59,7 +59,7 @@ public class LeafStageAggregateRule extends PRelOptRule {
     if (!(currentNode.unwrap() instanceof Aggregate)) {
       return false;
     }
-    if (!isProjectFilterOrScan(currentNode.getPRelInput(0).unwrap())) {
+    if (!currentNode.getPRelInput(0).isLeafStage() || 
!isProjectFilterOrScan(currentNode.getPRelInput(0).unwrap())) {
       return false;
     }
     // ==> We have: "aggregate (non-leaf stage) > project|filter|table-scan 
(leaf-stage)"
@@ -80,7 +80,7 @@ public class LeafStageAggregateRule extends PRelOptRule {
         currentNode.unwrap(), null);
     PinotDataDistribution derivedDistribution = 
currentNode.getPRelInput(0).getPinotDataDistributionOrThrow()
         .apply(mapping);
-    return currentNode.with(currentNode.getPRelInputs(), derivedDistribution);
+    return currentNode.with(currentNode.getPRelInputs(), 
derivedDistribution).asLeafStage();
   }
 
   private static boolean isPartitionedByHintPresent(PhysicalAggregate aggRel) {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java
index 3295e2cf37..20c891031d 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/opt/rules/LeafStageWorkerAssignmentRule.java
@@ -186,8 +186,7 @@ public class LeafStageWorkerAssignmentRule extends 
PRelOptRule {
         String tableType = partitionedTableTypes.iterator().next();
         String tableNameWithType = 
TableNameBuilder.forType(TableType.valueOf(tableType)).tableNameWithType(tableName);
         TableScanWorkerAssignmentResult assignmentResult = 
attemptPartitionedDistribution(tableNameWithType,
-            fieldNames, 
instanceIdToSegments.getSegmentsMap(TableType.valueOf(tableType)),
-            tpiMap.get(tableType));
+            fieldNames, 
instanceIdToSegments.getSegmentsMap(TableType.valueOf(tableType)), 
tpiMap.get(tableType));
         if (assignmentResult != null) {
           return assignmentResult;
         }
@@ -220,8 +219,10 @@ public class LeafStageWorkerAssignmentRule extends PRelOptRule {
       }
       workers.add(String.format("%s@%s", workers.size(), instanceId));
     }
-    PinotDataDistribution pinotDataDistribution = new PinotDataDistribution(RelDistribution.Type.RANDOM_DISTRIBUTED,
-        workers, workers.hashCode(), null, null);
+    RelDistribution.Type distType = workers.size() == 1 ? RelDistribution.Type.SINGLETON
+        : RelDistribution.Type.RANDOM_DISTRIBUTED;
+    PinotDataDistribution pinotDataDistribution = new PinotDataDistribution(distType, workers, workers.hashCode(),
+        null, null);
     return new TableScanWorkerAssignmentResult(pinotDataDistribution, workerIdToSegmentsMap);
   }
 
@@ -233,7 +234,10 @@ public class LeafStageWorkerAssignmentRule extends PRelOptRule {
   @VisibleForTesting
   static TableScanWorkerAssignmentResult attemptPartitionedDistribution(String tableNameWithType,
       List<String> fieldNames, Map<String, List<String>> instanceIdToSegmentsMap,
-      TablePartitionInfo tablePartitionInfo) {
+      @Nullable TablePartitionInfo tablePartitionInfo) {
+    if (tablePartitionInfo == null) {
+      return null;
+    }
     if (CollectionUtils.isNotEmpty(tablePartitionInfo.getSegmentsWithInvalidPartition())) {
       LOGGER.warn("Table {} has {} segments with invalid partition info. Will assume un-partitioned distribution",
           tableNameWithType, tablePartitionInfo.getSegmentsWithInvalidPartition().size());
@@ -252,6 +256,9 @@ public class LeafStageWorkerAssignmentRule extends PRelOptRule {
       return null;
     } else if (numPartitions < numSelectedServers) {
       return null;
+    } else if (numSelectedServers == 1) {
+      // ==> scan will have a single stream, so partitioned distribution doesn't matter
+      return null;
     }
     // Pre-compute segmentToServer map for quick lookup later.
     Map<String, String> segmentToServer = new HashMap<>();
@@ -332,7 +339,7 @@ public class LeafStageWorkerAssignmentRule extends PRelOptRule {
       String realtimeTableType = TableNameBuilder.REALTIME.tableNameWithType(tableName);
       TablePartitionInfo tablePartitionInfo = _routingManager.getTablePartitionInfo(realtimeTableType);
       if (tablePartitionInfo != null) {
-        result.put("REALTIME", _routingManager.getTablePartitionInfo(tableName));
+        result.put("REALTIME", tablePartitionInfo);
       }
     }
     return result;
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java
index bd537c4be0..a2c7516864 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/v2/mapping/PinotDistMappingTest.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.query.planner.physical.v2.mapping;
 
+import java.util.Collections;
 import java.util.List;
 import org.testng.annotations.Test;
 
@@ -31,7 +32,7 @@ public class PinotDistMappingTest {
     PinotDistMapping mapping = PinotDistMapping.identity(10);
     assertEquals(mapping.getSourceCount(), 10);
     for (int i = 0; i < 10; i++) {
-      assertEquals(mapping.getTarget(i), i);
+      assertEquals(mapping.getTargets(i), List.of(i));
     }
     // Test getMappedKeys always returns the same values as the input.
     List<List<Integer>> testKeys = List.of(
@@ -42,7 +43,7 @@ public class PinotDistMappingTest {
         List.of(4, 2, 9)
     );
     for (List<Integer> keys : testKeys) {
-      assertEquals(mapping.getMappedKeys(keys), keys);
+      assertEquals(mapping.getMappedKeys(keys), List.of(keys));
     }
   }
 
@@ -50,10 +51,10 @@ public class PinotDistMappingTest {
   public void testOutOfBoundsSource() {
     // When the passed source index is out of bounds wrt the sourceCount in the mapping, we should get an exception.
     PinotDistMapping mapping = new PinotDistMapping(5);
-    assertThrows(IllegalArgumentException.class, () -> mapping.getTarget(-1));
-    assertThrows(IllegalArgumentException.class, () -> mapping.getTarget(5));
-    assertThrows(IllegalArgumentException.class, () -> mapping.set(-1, 2));
-    assertThrows(IllegalArgumentException.class, () -> mapping.set(5, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.getTargets(-1));
+    assertThrows(IllegalArgumentException.class, () -> mapping.getTargets(5));
+    assertThrows(IllegalArgumentException.class, () -> mapping.add(-1, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.add(5, 2));
     assertThrows(IllegalArgumentException.class, () -> mapping.getMappedKeys(List.of(5)));
   }
 
@@ -61,25 +62,25 @@ public class PinotDistMappingTest {
   public void testSet() {
     // Test setting a mapping value.
     PinotDistMapping mapping = new PinotDistMapping(5);
-    mapping.set(0, 2);
-    assertEquals(mapping.getTarget(0), 2);
-    assertEquals(mapping.getTarget(1), -1);
-    assertEquals(mapping.getTarget(2), -1);
-    assertEquals(mapping.getTarget(3), -1);
-    assertEquals(mapping.getTarget(4), -1);
+    mapping.add(0, 2);
+    assertEquals(mapping.getTargets(0), List.of(2));
+    assertEquals(mapping.getTargets(1), Collections.emptyList());
+    assertEquals(mapping.getTargets(2), Collections.emptyList());
+    assertEquals(mapping.getTargets(3), Collections.emptyList());
+    assertEquals(mapping.getTargets(4), Collections.emptyList());
 
     // Test setting multiple mapping values.
-    mapping.set(1, 3);
-    mapping.set(2, 4);
-    assertEquals(mapping.getTarget(0), 2);
-    assertEquals(mapping.getTarget(1), 3);
-    assertEquals(mapping.getTarget(2), 4);
-    assertEquals(mapping.getTarget(3), -1);
-    assertEquals(mapping.getTarget(4), -1);
+    mapping.add(1, 3);
+    mapping.add(2, 4);
+    assertEquals(mapping.getTargets(0), List.of(2));
+    assertEquals(mapping.getTargets(1), List.of(3));
+    assertEquals(mapping.getTargets(2), List.of(4));
+    assertEquals(mapping.getTargets(3), Collections.emptyList());
+    assertEquals(mapping.getTargets(4), Collections.emptyList());
 
     // Test setting a mapping value to an invalid index.
-    assertThrows(IllegalArgumentException.class, () -> mapping.set(-1, 2));
-    assertThrows(IllegalArgumentException.class, () -> mapping.set(5, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.add(-1, 2));
+    assertThrows(IllegalArgumentException.class, () -> mapping.add(5, 2));
   }
 
   @Test
@@ -87,27 +88,47 @@ public class PinotDistMappingTest {
     {
       // Test when all passed keys are mapped.
       PinotDistMapping mapping = new PinotDistMapping(5);
-      mapping.set(0, 2);
-      mapping.set(1, 3);
-      mapping.set(2, 4);
+      mapping.add(0, 2);
+      mapping.add(1, 3);
+      mapping.add(2, 4);
       List<Integer> keys = List.of(0, 1, 2);
-      List<Integer> expectedMappedKeys = List.of(2, 3, 4);
+      List<List<Integer>> expectedMappedKeys = List.of(List.of(2, 3, 4));
+      assertEquals(mapping.getMappedKeys(keys), expectedMappedKeys);
+    }
+    {
+      // Test when a key is mapped to multiple targets
+      PinotDistMapping mapping = new PinotDistMapping(5);
+      mapping.add(0, 2);
+      mapping.add(1, 3);
+      mapping.add(1, 4);
+      List<Integer> keys = List.of(0, 1);
+      List<List<Integer>> expectedMappedKeys = List.of(List.of(2, 3), List.of(2, 4));
       assertEquals(mapping.getMappedKeys(keys), expectedMappedKeys);
     }
     {
       // Test when one of the keys is not mapped
       PinotDistMapping mapping = new PinotDistMapping(5);
-      mapping.set(0, 2);
-      mapping.set(1, 3);
+      mapping.add(0, 2);
+      mapping.add(1, 3);
       List<Integer> keys = List.of(0, 1, 2);
-      List<Integer> expectedMappedKeys = List.of();
+      List<List<Integer>> expectedMappedKeys = List.of();
+      assertEquals(mapping.getMappedKeys(keys), expectedMappedKeys);
+    }
+    {
+      // Test when passed keys are empty
+      PinotDistMapping mapping = new PinotDistMapping(5);
+      mapping.add(0, 2);
+      mapping.add(1, 3);
+      mapping.add(1, 4);
+      List<Integer> keys = List.of();
+      List<List<Integer>> expectedMappedKeys = List.of(List.of());
       assertEquals(mapping.getMappedKeys(keys), expectedMappedKeys);
     }
     {
       // Test getting mapped keys with an invalid key.
       PinotDistMapping mapping = new PinotDistMapping(5);
-      mapping.set(0, 2);
-      mapping.set(1, 3);
+      mapping.add(0, 2);
+      mapping.add(1, 3);
       List<Integer> keys = List.of(0, 1, 5);
       assertThrows(IllegalArgumentException.class, () -> mapping.getMappedKeys(keys));
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to