siddharthteotia commented on code in PR #10248:
URL: https://github.com/apache/pinot/pull/10248#discussion_r1118330347


##########
pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java:
##########
@@ -169,36 +185,136 @@ private RelNode makeNewIntermediateAgg(RelOptRuleCall 
ruleCall, Aggregate oldAgg
    */
   private static void convertAggCall(RexBuilder rexBuilder, Aggregate 
oldAggRel, int oldCallIndex,
       AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, 
RexNode> aggCallMapping,
-      List<RexNode> inputExprs) {
+      boolean isLeafStageAggregationPresent, List<Integer> argList) {
     final int nGroups = oldAggRel.getGroupCount();
     final SqlAggFunction oldAggregation = oldCall.getAggregation();
     final SqlKind aggKind = oldAggregation.getKind();
     // Check only the supported AGG functions are provided.
     Preconditions.checkState(SUPPORTED_AGG_KIND.contains(aggKind), 
"Unsupported SQL aggregation "
         + "kind: {}. Only splittable aggregation functions are supported!", 
aggKind);
 
-    // Special treatment on COUNT
     AggregateCall newCall;
-    if (oldAggregation instanceof SqlCountAggFunction) {
-      newCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), 
oldCall.isDistinct(), oldCall.isApproximate(),
-          oldCall.ignoreNulls(), convertArgList(nGroups + oldCallIndex, 
Collections.singletonList(oldCallIndex)),
-          oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, 
oldCall.type, oldCall.getName());
+    if (isLeafStageAggregationPresent) {
+
+      // Special treatment for Count. If count is performed at the Leaf Stage, 
a Sum needs to be performed at the
+      // intermediate stage.
+      if (oldAggregation instanceof SqlCountAggFunction) {
+        newCall =
+            AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), 
oldCall.isDistinct(), oldCall.isApproximate(),
+                oldCall.ignoreNulls(), convertArgList(nGroups + oldCallIndex, 
Collections.singletonList(oldCallIndex)),
+                oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, 
oldCall.type, oldCall.getName());
+      } else {
+        newCall = AggregateCall.create(oldCall.getAggregation(), 
oldCall.isDistinct(), oldCall.isApproximate(),
+            oldCall.ignoreNulls(), convertArgList(nGroups + oldCallIndex, 
oldCall.getArgList()), oldCall.filterArg,
+            oldCall.distinctKeys, oldCall.collation, oldCall.type, 
oldCall.getName());
+      }
     } else {
-      newCall = AggregateCall.create(
-          oldCall.getAggregation(), oldCall.isDistinct(), 
oldCall.isApproximate(), oldCall.ignoreNulls(),
-          convertArgList(nGroups + oldCallIndex, oldCall.getArgList()), 
oldCall.filterArg, oldCall.distinctKeys,
-          oldCall.collation, oldCall.type, oldCall.getName());
+      List<Integer> newArgList = oldCall.getArgList().size() == 0 ? 
Collections.emptyList()
+          : Collections.singletonList(argList.get(oldCallIndex));
+
+      newCall = AggregateCall.create(oldCall.getAggregation(), 
oldCall.isDistinct(), oldCall.isApproximate(),
+          oldCall.ignoreNulls(), newArgList, oldCall.filterArg, 
oldCall.distinctKeys, oldCall.collation, oldCall.type,
+          oldCall.getName());
     }
-    rexBuilder.addAggCall(newCall,
-        nGroups,
-        newCalls,
-        aggCallMapping,
-        oldAggRel.getInput()::fieldIsNullable);
+
+    rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, 
oldAggRel.getInput()::fieldIsNullable);
   }
 
   private static List<Integer> convertArgList(int oldCallIndexWithShift, 
List<Integer> argList) {
     Preconditions.checkArgument(argList.size() <= 1,
         "Unable to convert call as the argList contains more than 1 argument");
     return argList.size() == 1 ? 
Collections.singletonList(oldCallIndexWithShift) : Collections.emptyList();
   }
+
+  private void createPlanWithoutLeafAggregation(RelOptRuleCall call) {
+    Aggregate oldAggRel = call.rel(0);
+    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
+    LogicalProject project;
+
+    List<Integer> newAggArgColumns = new ArrayList<>();
+    List<Integer> newAggGroupByColumns = new ArrayList<>();
+
+    // 1. Create the LogicalProject node if it does not exist. This is to send 
only the relevant columns over
+    //    the wire for intermediate aggregation.
+    if (childRel instanceof Project) {
+      // Avoid creating a new LogicalProject if the child node of aggregation 
is already a project node.
+      project = (LogicalProject) childRel;
+      newAggArgColumns = fetchNewAggArgCols(oldAggRel.getAggCallList());
+      newAggGroupByColumns = oldAggRel.getGroupSet().asList();
+    } else {
+      // Create a leaf stage project. This is done so that only the required 
columns are sent over the wire for
+      // intermediate aggregation. If there are multiple aggregations on the 
same column, the column is projected
+      // only once.
+      project = createLogicalProjectForAggregate(oldAggRel, newAggArgColumns, 
newAggGroupByColumns);
+    }
+
+    // 2. Create an exchange on top of the LogicalProject.
+    LogicalExchange exchange = LogicalExchange.create(project, 
RelDistributions.hash(newAggGroupByColumns));
+
+    // 3. Create an intermediate stage aggregation.
+    RelNode newAggNode =
+        makeNewIntermediateAgg(call, oldAggRel, exchange, false, 
newAggArgColumns, newAggGroupByColumns);
+
+    call.transformTo(newAggNode);
+  }
+
+  private LogicalProject createLogicalProjectForAggregate(Aggregate oldAggRel, 
List<Integer> newAggArgColumns,
+      List<Integer> newAggGroupByCols) {
+    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
+    RexBuilder childRexBuilder = childRel.getCluster().getRexBuilder();
+    List<RelDataTypeField> fieldList = childRel.getRowType().getFieldList();
+
+    List<RexNode> projectColRexNodes = new ArrayList<>();
+    List<String> projectColNames = new ArrayList<>();
+    // Maintains a mapping from the column to the corresponding index in 
projectColRexNodes.
+    Map<Integer, Integer> projectSet = new HashMap<>();
+
+    int projectIndex = 0;
+    for (int group : oldAggRel.getGroupSet().asSet()) {
+      projectColNames.add(fieldList.get(group).getName());
+      projectColRexNodes.add(childRexBuilder.makeInputRef(childRel, group));
+      projectSet.put(group, projectColRexNodes.size() - 1);

Review Comment:
   (nit) We already track the running size (`projectIndex`), so computing
`projectColRexNodes.size() - 1` here is probably unnecessary.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to