This is an automated email from the ASF dual-hosted git repository.

siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new fb4966b0f7 [multistage] carry partition scheme for optimization (#9100)
fb4966b0f7 is described below

commit fb4966b0f79dcd830e20c53305cd0ecacc0f5440
Author: Rong Rong <walterddr.walter...@gmail.com>
AuthorDate: Wed Aug 3 21:04:22 2022 -0700

    [multistage] carry partition scheme for optimization (#9100)
    
    * adding stage planner and mailbox sender to allow SINGLETON to be an actual singleton
    
    add partition key generator
    
    * [change default join distribution to hash]
    
    * adding logic to use singleton connection
    
    * fix a bug in non-agg group-by random distribution
    
    * add test
    
    * address pr comments
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../pinot/query/planner/logical/StagePlanner.java  | 131 +++++++++++++++++++--
 .../query/planner/stage/AbstractStageNode.java     |  16 +++
 .../pinot/query/planner/stage/AggregateNode.java   |   6 +-
 .../query/planner/stage/MailboxReceiveNode.java    |  11 +-
 .../pinot/query/planner/stage/MailboxSendNode.java |   6 -
 .../pinot/query/planner/stage/StageNode.java       |   6 +
 .../PinotAggregateExchangeNodeInsertRule.java      |   2 +-
 .../rules/PinotJoinExchangeNodeInsertRule.java     |  10 +-
 ...ironmentTest.java => QueryCompilationTest.java} |  76 +++++++++++-
 .../runtime/executor/WorkerQueryExecutor.java      |   6 +-
 .../runtime/operator/MailboxReceiveOperator.java   |  23 +++-
 .../runtime/operator/MailboxSendOperator.java      |  18 ++-
 .../pinot/query/service/QueryDispatcher.java       |   4 +-
 13 files changed, 274 insertions(+), 41 deletions(-)

diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
index cf0b218708..ac30996efa 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
@@ -19,8 +19,10 @@
 package org.apache.pinot.query.planner.logical;
 
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelRoot;
@@ -29,9 +31,15 @@ import org.apache.pinot.query.context.PlannerContext;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
+import org.apache.pinot.query.planner.partitioning.KeySelector;
+import org.apache.pinot.query.planner.stage.AggregateNode;
+import org.apache.pinot.query.planner.stage.FilterNode;
+import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
+import org.apache.pinot.query.planner.stage.ProjectNode;
 import org.apache.pinot.query.planner.stage.StageNode;
+import org.apache.pinot.query.planner.stage.TableScanNode;
 import org.apache.pinot.query.routing.WorkerManager;
 
 
@@ -73,9 +81,9 @@ public class StagePlanner {
     // receiver so doesn't matter what the exchange type is. setting it to 
SINGLETON by default.
     StageNode globalReceiverNode =
         new MailboxReceiveNode(0, globalStageRoot.getDataSchema(), 
globalStageRoot.getStageId(),
-            RelDistribution.Type.SINGLETON);
+            RelDistribution.Type.RANDOM_DISTRIBUTED, null);
     StageNode globalSenderNode = new 
MailboxSendNode(globalStageRoot.getStageId(), globalStageRoot.getDataSchema(),
-        globalReceiverNode.getStageId(), RelDistribution.Type.SINGLETON);
+        globalReceiverNode.getStageId(), 
RelDistribution.Type.RANDOM_DISTRIBUTED, null);
     globalSenderNode.addInput(globalStageRoot);
     _queryStageMap.put(globalSenderNode.getStageId(), globalSenderNode);
     StageMetadata stageMetadata = 
_stageMetadataMap.get(globalSenderNode.getStageId());
@@ -105,17 +113,37 @@ public class StagePlanner {
       RelDistribution.Type exchangeType = distribution.getType();
 
       // 2. make an exchange sender and receiver node pair
-      StageNode mailboxReceiver = new MailboxReceiveNode(currentStageId, 
nextStageRoot.getDataSchema(),
-          nextStageRoot.getStageId(), exchangeType);
-      StageNode mailboxSender = new 
MailboxSendNode(nextStageRoot.getStageId(), nextStageRoot.getDataSchema(),
-          mailboxReceiver.getStageId(), exchangeType, exchangeType == 
RelDistribution.Type.HASH_DISTRIBUTED
-          ? new FieldSelectionKeySelector(distributionKeys) : null);
+      // only HASH_DISTRIBUTED requires a partition key selector; so all other 
types (SINGLETON and BROADCAST)
+      // of exchange will not carry a partition key selector.
+      KeySelector<Object[], Object[]> keySelector = exchangeType == 
RelDistribution.Type.HASH_DISTRIBUTED
+          ? new FieldSelectionKeySelector(distributionKeys) : null;
+
+      StageNode mailboxReceiver;
+      StageNode mailboxSender;
+      if (canSkipShuffle(nextStageRoot, keySelector)) {
+        // Use SINGLETON exchange type indicates a LOCAL-to-LOCAL data 
transfer between execution threads.
+        // TODO: actually implement the SINGLETON exchange without going 
through the over-the-wire GRPC mailbox
+        // sender and receiver.
+        mailboxReceiver = new MailboxReceiveNode(currentStageId, 
nextStageRoot.getDataSchema(),
+            nextStageRoot.getStageId(), RelDistribution.Type.SINGLETON, 
keySelector);
+        mailboxSender = new MailboxSendNode(nextStageRoot.getStageId(), 
nextStageRoot.getDataSchema(),
+            mailboxReceiver.getStageId(), RelDistribution.Type.SINGLETON, 
keySelector);
+      } else {
+        mailboxReceiver = new MailboxReceiveNode(currentStageId, 
nextStageRoot.getDataSchema(),
+            nextStageRoot.getStageId(), exchangeType, keySelector);
+        mailboxSender = new MailboxSendNode(nextStageRoot.getStageId(), 
nextStageRoot.getDataSchema(),
+            mailboxReceiver.getStageId(), exchangeType, keySelector);
+      }
       mailboxSender.addInput(nextStageRoot);
 
       // 3. put the sender side as a completed stage.
       _queryStageMap.put(mailboxSender.getStageId(), mailboxSender);
 
-      // 4. return the receiver (this is considered as a "virtual table scan" 
node for its parent.
+      // 4. update stage metadata.
+      updateStageMetadata(mailboxSender.getStageId(), mailboxSender, 
_stageMetadataMap);
+      updateStageMetadata(mailboxReceiver.getStageId(), mailboxReceiver, 
_stageMetadataMap);
+
+      // 5. return the receiver, this is considered as a "virtual table scan" 
node for its parent.
       return mailboxReceiver;
     } else {
       StageNode stageNode = RelToStageConverter.toStageNode(node, 
currentStageId);
@@ -123,12 +151,95 @@ public class StagePlanner {
       for (RelNode input : inputs) {
         stageNode.addInput(walkRelPlan(input, currentStageId));
       }
-      StageMetadata stageMetadata = 
_stageMetadataMap.computeIfAbsent(currentStageId, (id) -> new StageMetadata());
-      stageMetadata.attach(stageNode);
+      updateStageMetadata(currentStageId, stageNode, _stageMetadataMap);
       return stageNode;
     }
   }
 
+  private boolean canSkipShuffle(StageNode stageNode, KeySelector<Object[], 
Object[]> keySelector) {
+    Set<Integer> originSet = stageNode.getPartitionKeys();
+    if (!originSet.isEmpty() && keySelector != null) {
+      Set<Integer> targetSet = new HashSet<>(((FieldSelectionKeySelector) 
keySelector).getColumnIndices());
+      return targetSet.containsAll(originSet);
+    }
+    return false;
+  }
+
+  private static void updateStageMetadata(int stageId, StageNode node, 
Map<Integer, StageMetadata> stageMetadataMap) {
+    updatePartitionKeys(node);
+    StageMetadata stageMetadata = stageMetadataMap.computeIfAbsent(stageId, 
(id) -> new StageMetadata());
+    stageMetadata.attach(node);
+  }
+
+  private static void updatePartitionKeys(StageNode node) {
+    if (node instanceof ProjectNode) {
+      // any input reference directly carry over should still be a partition 
key.
+      Set<Integer> previousPartitionKeys = 
node.getInputs().get(0).getPartitionKeys();
+      Set<Integer> newPartitionKeys = new HashSet<>();
+      ProjectNode projectNode = (ProjectNode) node;
+      for (int i = 0; i < projectNode.getProjects().size(); i++) {
+        RexExpression rexExpression = projectNode.getProjects().get(i);
+        if (rexExpression instanceof RexExpression.InputRef
+            && previousPartitionKeys.contains(((RexExpression.InputRef) 
rexExpression).getIndex())) {
+          newPartitionKeys.add(i);
+        }
+      }
+      projectNode.setPartitionKeys(newPartitionKeys);
+    } else if (node instanceof FilterNode) {
+      // filter node doesn't change partition keys.
+      node.setPartitionKeys(node.getInputs().get(0).getPartitionKeys());
+    } else if (node instanceof AggregateNode) {
+      // any input reference directly carry over in group set of aggregation 
should still be a partition key.
+      Set<Integer> previousPartitionKeys = 
node.getInputs().get(0).getPartitionKeys();
+      Set<Integer> newPartitionKeys = new HashSet<>();
+      AggregateNode aggregateNode = (AggregateNode) node;
+      for (int i = 0; i < aggregateNode.getGroupSet().size(); i++) {
+        RexExpression rexExpression = aggregateNode.getGroupSet().get(i);
+        if (rexExpression instanceof RexExpression.InputRef
+            && previousPartitionKeys.contains(((RexExpression.InputRef) 
rexExpression).getIndex())) {
+          newPartitionKeys.add(i);
+        }
+      }
+      aggregateNode.setPartitionKeys(newPartitionKeys);
+    } else if (node instanceof JoinNode) {
+      int leftDataSchemaSize = node.getInputs().get(0).getDataSchema().size();
+      Set<Integer> leftPartitionKeys = 
node.getInputs().get(0).getPartitionKeys();
+      Set<Integer> rightPartitionKeys = 
node.getInputs().get(1).getPartitionKeys();
+      // TODO: currently JOIN criteria guarantee to only have one 
FieldSelectionKeySelector. Support more.
+      FieldSelectionKeySelector leftJoinKeySelector =
+          (FieldSelectionKeySelector) ((JoinNode) 
node).getCriteria().get(0).getLeftJoinKeySelector();
+      FieldSelectionKeySelector rightJoinKeySelector =
+          (FieldSelectionKeySelector) ((JoinNode) 
node).getCriteria().get(0).getRightJoinKeySelector();
+      Set<Integer> newPartitionKeys = new HashSet<>();
+      for (int i = 0; i < leftJoinKeySelector.getColumnIndices().size(); i++) {
+        int leftIndex = leftJoinKeySelector.getColumnIndices().get(i);
+        int rightIndex = rightJoinKeySelector.getColumnIndices().get(i);
+        if (leftPartitionKeys.contains(leftIndex)) {
+          newPartitionKeys.add(i);
+        }
+        if (rightPartitionKeys.contains(rightIndex)) {
+          newPartitionKeys.add(leftDataSchemaSize + i);
+        }
+      }
+      node.setPartitionKeys(newPartitionKeys);
+    } else if (node instanceof TableScanNode) {
+      // TODO: add table partition in table config as partition keys. we dont 
have that information yet.
+    } else if (node instanceof MailboxReceiveNode) {
+      // hash distribution key is partition key.
+      FieldSelectionKeySelector keySelector = (FieldSelectionKeySelector)
+          ((MailboxReceiveNode) node).getPartitionKeySelector();
+      if (keySelector != null) {
+        node.setPartitionKeys(new HashSet<>(keySelector.getColumnIndices()));
+      }
+    } else if (node instanceof MailboxSendNode) {
+      FieldSelectionKeySelector keySelector = (FieldSelectionKeySelector)
+          ((MailboxSendNode) node).getPartitionKeySelector();
+      if (keySelector != null) {
+        node.setPartitionKeys(new HashSet<>(keySelector.getColumnIndices()));
+      }
+    }
+  }
+
   private boolean isExchangeNode(RelNode node) {
     return (node instanceof LogicalExchange);
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
index 1de069f0b4..594c6d7e38 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
@@ -19,7 +19,10 @@
 package org.apache.pinot.query.planner.stage;
 
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.serde.ProtoSerializable;
@@ -31,6 +34,7 @@ public abstract class AbstractStageNode implements StageNode, 
ProtoSerializable
   protected final int _stageId;
   protected final List<StageNode> _inputs;
   protected DataSchema _dataSchema;
+  protected Set<Integer> _partitionedKeys;
 
   public AbstractStageNode(int stageId) {
     this(stageId, null);
@@ -40,6 +44,7 @@ public abstract class AbstractStageNode implements StageNode, 
ProtoSerializable
     _stageId = stageId;
     _dataSchema = dataSchema;
     _inputs = new ArrayList<>();
+    _partitionedKeys = new HashSet<>();
   }
 
   @Override
@@ -62,10 +67,21 @@ public abstract class AbstractStageNode implements 
StageNode, ProtoSerializable
     return _dataSchema;
   }
 
+  @Override
   public void setDataSchema(DataSchema dataSchema) {
     _dataSchema = dataSchema;
   }
 
+  @Override
+  public Set<Integer> getPartitionKeys() {
+    return _partitionedKeys;
+  }
+
+  @Override
+  public void setPartitionKeys(Collection<Integer> partitionedKeys) {
+    _partitionedKeys.addAll(partitionedKeys);
+  }
+
   @Override
   public void fromObjectField(Plan.ObjectField objectField) {
     ProtoSerializationUtils.setObjectFieldToObject(this, objectField);
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
index d0a28b0cbd..b67c5a9500 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
@@ -19,7 +19,6 @@
 package org.apache.pinot.query.planner.stage;
 
 import java.util.ArrayList;
-import java.util.Iterator;
 import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.core.AggregateCall;
@@ -43,9 +42,8 @@ public class AggregateNode extends AbstractStageNode {
     super(stageId, dataSchema);
     _aggCalls = 
aggCalls.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
     _groupSet = new ArrayList<>(groupSet.cardinality());
-    Iterator<Integer> groupSetIt = groupSet.iterator();
-    while (groupSetIt.hasNext()) {
-      _groupSet.add(new RexExpression.InputRef(groupSetIt.next()));
+    for (Integer integer : groupSet) {
+      _groupSet.add(new RexExpression.InputRef(integer));
     }
   }
 
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
index abba178865..3fa3a55acd 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
@@ -18,8 +18,10 @@
  */
 package org.apache.pinot.query.planner.stage;
 
+import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.partitioning.KeySelector;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
@@ -28,16 +30,19 @@ public class MailboxReceiveNode extends AbstractStageNode {
   private int _senderStageId;
   @ProtoProperties
   private RelDistribution.Type _exchangeType;
+  @ProtoProperties
+  private KeySelector<Object[], Object[]> _partitionKeySelector;
 
   public MailboxReceiveNode(int stageId) {
     super(stageId);
   }
 
   public MailboxReceiveNode(int stageId, DataSchema dataSchema, int 
senderStageId,
-      RelDistribution.Type exchangeType) {
+      RelDistribution.Type exchangeType, @Nullable KeySelector<Object[], 
Object[]> partitionKeySelector) {
     super(stageId, dataSchema);
     _senderStageId = senderStageId;
     _exchangeType = exchangeType;
+    _partitionKeySelector = partitionKeySelector;
   }
 
   public int getSenderStageId() {
@@ -47,4 +52,8 @@ public class MailboxReceiveNode extends AbstractStageNode {
   public RelDistribution.Type getExchangeType() {
     return _exchangeType;
   }
+
+  public KeySelector<Object[], Object[]> getPartitionKeySelector() {
+    return _partitionKeySelector;
+  }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
index 962dbc73c4..f67f85e573 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
@@ -37,12 +37,6 @@ public class MailboxSendNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public MailboxSendNode(int stageId, DataSchema dataSchema, int 
receiverStageId,
-      RelDistribution.Type exchangeType) {
-    // When exchangeType is not HASH_DISTRIBUTE, no partitionKeySelector is 
needed.
-    this(stageId, dataSchema, receiverStageId, exchangeType, null);
-  }
-
   public MailboxSendNode(int stageId, DataSchema dataSchema, int 
receiverStageId,
       RelDistribution.Type exchangeType, @Nullable KeySelector<Object[], 
Object[]> partitionKeySelector) {
     super(stageId, dataSchema);
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
index 6efa59ce2d..bd69371484 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
@@ -19,7 +19,9 @@
 package org.apache.pinot.query.planner.stage;
 
 import java.io.Serializable;
+import java.util.Collection;
 import java.util.List;
+import java.util.Set;
 import org.apache.pinot.common.utils.DataSchema;
 
 
@@ -42,4 +44,8 @@ public interface StageNode extends Serializable {
   DataSchema getDataSchema();
 
   void setDataSchema(DataSchema dataSchema);
+
+  Set<Integer> getPartitionKeys();
+
+  void setPartitionKeys(Collection<Integer> partitionKeys);
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotAggregateExchangeNodeInsertRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotAggregateExchangeNodeInsertRule.java
index 1150bc2085..180137e323 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotAggregateExchangeNodeInsertRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -106,7 +106,7 @@ public class PinotAggregateExchangeNodeInsertRule extends 
RelOptRule {
     List<Integer> groupSetIndices = ImmutableIntList.range(0, 
oldAggRel.getGroupCount());
     LogicalExchange exchange = null;
     if (groupSetIndices.size() == 0) {
-      exchange = LogicalExchange.create(newLeafAgg, 
RelDistributions.SINGLETON);
+      exchange = LogicalExchange.create(newLeafAgg, 
RelDistributions.hash(Collections.emptyList()));
     } else {
       exchange = LogicalExchange.create(newLeafAgg, 
RelDistributions.hash(groupSetIndices));
     }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotJoinExchangeNodeInsertRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotJoinExchangeNodeInsertRule.java
index 6aaacccac1..f766afb756 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotJoinExchangeNodeInsertRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotJoinExchangeNodeInsertRule.java
@@ -70,7 +70,12 @@ public class PinotJoinExchangeNodeInsertRule extends 
RelOptRule {
     RelNode leftExchange;
     RelNode rightExchange;
     List<RelHint> hints = join.getHints();
-    if (hints.contains(PinotRelationalHints.USE_HASH_DISTRIBUTE)) {
+    if (hints.contains(PinotRelationalHints.USE_BROADCAST_DISTRIBUTE)) {
+      // TODO: determine which side should be the broadcast table based on 
table metadata
+      // TODO: support SINGLETON exchange if the non-broadcast table size is 
small enough to stay local.
+      leftExchange = LogicalExchange.create(leftInput, 
RelDistributions.RANDOM_DISTRIBUTED);
+      rightExchange = LogicalExchange.create(rightInput, 
RelDistributions.BROADCAST_DISTRIBUTED);
+    } else { // if (hints.contains(PinotRelationalHints.USE_HASH_DISTRIBUTE)) {
       RexCall joinCondition = (RexCall) join.getCondition();
       int leftNodeOffset = join.getLeft().getRowType().getFieldNames().size();
       List<List<Integer>> conditions = 
PlannerUtils.parseJoinConditions(joinCondition, leftNodeOffset);
@@ -78,9 +83,6 @@ public class PinotJoinExchangeNodeInsertRule extends 
RelOptRule {
           RelDistributions.hash(conditions.get(0)));
       rightExchange = LogicalExchange.create(rightInput,
           RelDistributions.hash(conditions.get(1)));
-    } else { // if (hints.contains(PinotRelationalHints.USE_BROADCAST_JOIN))
-      leftExchange = LogicalExchange.create(leftInput, 
RelDistributions.SINGLETON);
-      rightExchange = LogicalExchange.create(rightInput, 
RelDistributions.BROADCAST_DISTRIBUTED);
     }
 
     RelNode newJoinNode =
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
similarity index 58%
rename from 
pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java
rename to 
pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 143eabf0f0..31c018d158 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -22,18 +22,26 @@ import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
+import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.context.PlannerContext;
 import org.apache.pinot.query.planner.PlannerUtils;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.planner.stage.AbstractStageNode;
+import org.apache.pinot.query.planner.stage.AggregateNode;
+import org.apache.pinot.query.planner.stage.FilterNode;
+import org.apache.pinot.query.planner.stage.JoinNode;
+import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
+import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.StageNode;
 import org.testng.Assert;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 
-public class QueryEnvironmentTest extends QueryEnvironmentTestBase {
+public class QueryCompilationTest extends QueryEnvironmentTestBase {
 
   @Test(dataProvider = "testQueryParserDataProvider")
   public void testQueryParser(String query, String digest)
@@ -45,7 +53,7 @@ public class QueryEnvironmentTest extends 
QueryEnvironmentTestBase {
   }
 
   @Test(dataProvider = "testQueryDataProvider")
-  public void testQueryToRel(String query)
+  public void testQueryPlanWithoutException(String query)
       throws Exception {
     try {
       QueryPlan queryPlan = _queryEnvironment.planQuery(query);
@@ -65,6 +73,38 @@ public class QueryEnvironmentTest extends 
QueryEnvironmentTestBase {
     }
   }
 
+  @Test
+  public void testQueryGroupByAfterJoinShouldNotDoDataShuffle()
+      throws Exception {
+    String query = "SELECT a.col1, a.col2, AVG(b.col3) FROM a JOIN b ON a.col1 
= b.col2 "
+        + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, 
a.col2";
+    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
+    Assert.assertEquals(queryPlan.getQueryStageMap().size(), 5);
+    Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 5);
+    for (Map.Entry<Integer, StageMetadata> e : 
queryPlan.getStageMetadataMap().entrySet()) {
+      if (e.getValue().getScannedTables().size() == 0 && 
!PlannerUtils.isRootStage(e.getKey())) {
+        StageNode node = queryPlan.getQueryStageMap().get(e.getKey());
+        while (node != null) {
+          if (node instanceof JoinNode) {
+            // JOIN is exchanged with hash distribution (data shuffle)
+            MailboxReceiveNode left = (MailboxReceiveNode) 
node.getInputs().get(0);
+            MailboxReceiveNode right = (MailboxReceiveNode) 
node.getInputs().get(1);
+            Assert.assertEquals(left.getExchangeType(), 
RelDistribution.Type.HASH_DISTRIBUTED);
+            Assert.assertEquals(right.getExchangeType(), 
RelDistribution.Type.HASH_DISTRIBUTED);
+            break;
+          }
+          if (node instanceof AggregateNode && node.getInputs().get(0) 
instanceof MailboxReceiveNode) {
+            // AGG is exchanged with singleton since it has already been 
distributed by JOIN.
+            MailboxReceiveNode input = (MailboxReceiveNode) 
node.getInputs().get(0);
+            Assert.assertEquals(input.getExchangeType(), 
RelDistribution.Type.SINGLETON);
+            break;
+          }
+          node = node.getInputs().get(0);
+        }
+      }
+    }
+  }
+
   @Test
   public void testQueryAndAssertStageContentForJoin()
       throws Exception {
@@ -95,12 +135,38 @@ public class QueryEnvironmentTest extends 
QueryEnvironmentTestBase {
   }
 
   @Test
-  public void testQueryProjectFilterPushdownForJoin() {
+  public void testQueryProjectFilterPushDownForJoin() {
     String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON 
a.col1 = b.col2 "
         + "WHERE a.col3 >= 0 AND a.col2 IN  ('a', 'b') AND b.col3 < 0";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    Assert.assertEquals(queryPlan.getQueryStageMap().size(), 4);
-    Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 4);
+    List<StageNode> intermediateStageRoots =
+        queryPlan.getStageMetadataMap().entrySet().stream().filter(e -> 
e.getValue().getScannedTables().size() == 0)
+            .map(e -> 
queryPlan.getQueryStageMap().get(e.getKey())).collect(Collectors.toList());
+    // Assert that no project of filter node for any intermediate stage 
because all should've been pushed down.
+    for (StageNode roots : intermediateStageRoots) {
+      assertNodeTypeNotIn(roots, ImmutableList.of(ProjectNode.class, 
FilterNode.class));
+    }
+  }
+
+  // --------------------------------------------------------------------------
+  // Test Utils.
+  // --------------------------------------------------------------------------
+
+  private static void assertNodeTypeNotIn(StageNode node, List<Class<? extends 
AbstractStageNode>> bannedNodeType) {
+    Assert.assertFalse(isOneOf(bannedNodeType, node));
+    for (StageNode child : node.getInputs()) {
+      assertNodeTypeNotIn(child, bannedNodeType);
+    }
+  }
+
+  private static boolean isOneOf(List<Class<? extends AbstractStageNode>> 
allowedNodeTypes,
+      StageNode node) {
+    for (Class<? extends AbstractStageNode> allowedNodeType : 
allowedNodeTypes) {
+      if (node.getClass() == allowedNodeType) {
+        return true;
+      }
+    }
+    return false;
   }
 
   @DataProvider(name = "testQueryParserDataProvider")
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
index 9c978f0ad9..7bfe5031e8 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
@@ -21,7 +21,6 @@ package org.apache.pinot.query.runtime.executor;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
-import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.metrics.ServerMetrics;
 import org.apache.pinot.common.proto.Mailbox;
 import org.apache.pinot.core.operator.BaseOperator;
@@ -106,8 +105,9 @@ public class WorkerQueryExecutor {
     if (stageNode instanceof MailboxReceiveNode) {
       MailboxReceiveNode receiveNode = (MailboxReceiveNode) stageNode;
       List<ServerInstance> sendingInstances = 
metadataMap.get(receiveNode.getSenderStageId()).getServerInstances();
-      return new MailboxReceiveOperator(_mailboxService, 
receiveNode.getDataSchema(), RelDistribution.Type.ANY,
-          sendingInstances, _hostName, _port, requestId, 
receiveNode.getSenderStageId());
+      return new MailboxReceiveOperator(_mailboxService, 
receiveNode.getDataSchema(), sendingInstances,
+          receiveNode.getExchangeType(), 
receiveNode.getPartitionKeySelector(), _hostName, _port, requestId,
+          receiveNode.getSenderStageId());
     } else if (stageNode instanceof MailboxSendNode) {
       MailboxSendNode sendNode = (MailboxSendNode) stageNode;
       BaseOperator<TransferableBlock> nextOperator = getOperator(requestId, 
sendNode.getInputs().get(0), metadataMap);
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
index 91d5de3d8d..81481b6311 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperator.java
@@ -18,7 +18,9 @@
  */
 package org.apache.pinot.query.runtime.operator;
 
+import com.google.common.base.Preconditions;
 import java.nio.ByteBuffer;
+import java.util.Collections;
 import java.util.List;
 import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelDistribution;
@@ -34,6 +36,7 @@ import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
 import org.apache.pinot.query.mailbox.StringMailboxIdentifier;
+import org.apache.pinot.query.planner.partitioning.KeySelector;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.service.QueryConfig;
@@ -51,6 +54,7 @@ public class MailboxReceiveOperator extends 
BaseOperator<TransferableBlock> {
 
   private final MailboxService<Mailbox.MailboxContent> _mailboxService;
   private final RelDistribution.Type _exchangeType;
+  private final KeySelector<Object[], Object[]> _keySelector;
   private final List<ServerInstance> _sendingStageInstances;
   private final DataSchema _dataSchema;
   private final String _hostName;
@@ -61,18 +65,31 @@ public class MailboxReceiveOperator extends 
BaseOperator<TransferableBlock> {
   private TransferableBlock _upstreamErrorBlock;
 
   public MailboxReceiveOperator(MailboxService<Mailbox.MailboxContent> 
mailboxService, DataSchema dataSchema,
-      RelDistribution.Type exchangeType, List<ServerInstance> 
sendingStageInstances, String hostName, int port,
-      long jobId, int stageId) {
+      List<ServerInstance> sendingStageInstances, RelDistribution.Type 
exchangeType,
+      KeySelector<Object[], Object[]> keySelector, String hostName, int port, 
long jobId, int stageId) {
     _dataSchema = dataSchema;
     _mailboxService = mailboxService;
     _exchangeType = exchangeType;
-    _sendingStageInstances = sendingStageInstances;
+    if (_exchangeType == RelDistribution.Type.SINGLETON) {
+      ServerInstance singletonInstance = null;
+      for (ServerInstance serverInstance : sendingStageInstances) {
+        if (serverInstance.getHostname().equals(_mailboxService.getHostname())
+            && serverInstance.getQueryMailboxPort() == 
_mailboxService.getMailboxPort()) {
+          Preconditions.checkState(singletonInstance == null, "multiple 
instance found for singleton exchange type!");
+          singletonInstance = serverInstance;
+        }
+      }
+      _sendingStageInstances = Collections.singletonList(singletonInstance);
+    } else {
+      _sendingStageInstances = sendingStageInstances;
+    }
     _hostName = hostName;
     _port = port;
     _jobId = jobId;
     _stageId = stageId;
     _timeout = QueryConfig.DEFAULT_TIMEOUT_NANO;
     _upstreamErrorBlock = null;
+    _keySelector = keySelector;
   }
 
   @Override
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
index bcd49f0b9a..32375dc4fa 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
@@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableSet;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 import java.util.Set;
@@ -77,8 +78,20 @@ public class MailboxSendOperator extends 
BaseOperator<TransferableBlock> {
     _dataSchema = dataSchema;
     _mailboxService = mailboxService;
     _dataTableBlockBaseOperator = dataTableBlockBaseOperator;
-    _receivingStageInstances = receivingStageInstances;
     _exchangeType = exchangeType;
+    if (_exchangeType == RelDistribution.Type.SINGLETON) {
+      ServerInstance singletonInstance = null;
+      for (ServerInstance serverInstance : receivingStageInstances) {
+        if (serverInstance.getHostname().equals(_mailboxService.getHostname())
+            && serverInstance.getQueryMailboxPort() == 
_mailboxService.getMailboxPort()) {
+          Preconditions.checkState(singletonInstance == null, "multiple 
instance found for singleton exchange type!");
+          singletonInstance = serverInstance;
+        }
+      }
+      _receivingStageInstances = Collections.singletonList(singletonInstance);
+    } else {
+      _receivingStageInstances = receivingStageInstances;
+    }
     _keySelector = keySelector;
     _serverHostName = hostName;
     _serverPort = port;
@@ -112,7 +125,8 @@ public class MailboxSendOperator extends 
BaseOperator<TransferableBlock> {
     try {
       switch (_exchangeType) {
         case SINGLETON:
-          // TODO: singleton or random distribution should've been 
distinguished in planning phase.
+          sendDataTableBlock(_receivingStageInstances.get(0), dataBlock);
+          break;
         case RANDOM_DISTRIBUTED:
           if (isEndOfStream) {
             for (ServerInstance serverInstance : _receivingStageInstances) {
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
index 653d2b216d..0323bed72e 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
@@ -151,8 +151,8 @@ public class QueryDispatcher {
       List<ServerInstance> sendingInstances, long jobId, int stageId, 
DataSchema dataSchema, String hostname,
       int port) {
     MailboxReceiveOperator mailboxReceiveOperator =
-        new MailboxReceiveOperator(mailboxService, dataSchema, 
RelDistribution.Type.ANY, sendingInstances, hostname,
-            port, jobId, stageId);
+        new MailboxReceiveOperator(mailboxService, dataSchema, 
sendingInstances,
+            RelDistribution.Type.RANDOM_DISTRIBUTED, null, hostname, port, 
jobId, stageId);
     return mailboxReceiveOperator;
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to