This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new e2c5e73970 Pass literal within AggregateCall via rexList (#13282)
e2c5e73970 is described below

commit e2c5e73970b1e8f64df7c763c5bcac36ff19d2a6
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Fri May 31 18:00:00 2024 -0700

    Pass literal within AggregateCall via rexList (#13282)
---
 .../pinot/calcite/rel/hint/PinotHintOptions.java   |  13 -
 .../PinotAggregateExchangeNodeInsertRule.java      | 422 ++++++++++-----------
 .../rules/PinotAggregateLiteralAttachmentRule.java | 107 ------
 .../calcite/rel/rules/PinotQueryRuleSets.java      |   5 -
 .../org/apache/pinot/query/QueryEnvironment.java   |   4 -
 .../query/parser/CalciteRexExpressionParser.java   |   4 +-
 .../query/planner/logical/LiteralHintUtils.java    |  85 -----
 .../query/planner/logical/RexExpressionUtils.java  |   6 +-
 .../apache/pinot/query/QueryCompilationTest.java   |   3 +-
 .../src/test/resources/queries/GroupByPlans.json   |  18 +-
 .../src/test/resources/queries/OrderByPlans.json   |   4 +-
 .../test/resources/queries/PinotHintablePlans.json |  33 +-
 .../query/runtime/operator/AggregateOperator.java  | 125 ++----
 .../src/test/resources/queries/QueryHints.json     |   8 +-
 .../pinot/segment/spi/AggregationFunctionType.java |   7 +-
 15 files changed, 256 insertions(+), 588 deletions(-)

diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
index 1d53a3184e..99e07b61df 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java
@@ -20,7 +20,6 @@ package org.apache.pinot.calcite.rel.hint;
 
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.hint.RelHint;
-import org.apache.pinot.query.planner.logical.LiteralHintUtils;
 
 
 /**
@@ -47,18 +46,6 @@ public class PinotHintOptions {
 
   public static class InternalAggregateOptions {
     public static final String AGG_TYPE = "agg_type";
-    /**
-     * agg call signature is used to store LITERAL inputs to the Aggregate 
Call. which is not supported in Calcite
-     * here
-     * 1. we store the Map of Pair[aggCallIdx, argListIdx] to RexLiteral to 
indicate the RexLiteral being passed into
-     *     the aggregateCalls[aggCallIdx].operandList[argListIdx] is supposed 
to be a RexLiteral.
-     * 2. not all RexLiteral types are supported to be part of the input 
constant call signature.
-     * 3. RexLiteral are encoded as String and decoded as Pinot Literal 
objects.
-     *
-     * see: {@link LiteralHintUtils}.
-     * see: https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-5833
-     */
-    public static final String AGG_CALL_SIGNATURE = "agg_call_signature";
   }
 
   public static class AggregateOptions {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index ffe0741751..0e6e13b0e7 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -19,20 +19,16 @@
 package org.apache.pinot.calcite.rel.rules;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
-import java.util.Locale;
 import java.util.Map;
-import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelDistributions;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.RelNode;
@@ -44,16 +40,16 @@ import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
 import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
-import org.apache.calcite.util.Util;
 import org.apache.calcite.util.mapping.Mapping;
 import org.apache.calcite.util.mapping.MappingType;
 import org.apache.calcite.util.mapping.Mappings;
@@ -88,8 +84,6 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
 public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
   public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
       new 
PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);
-  public static final Set<String> LIST_AGG_FUNCTION_NAMES =
-      ImmutableSet.of("LISTAGG", "LIST_AGG", "ARRAYsAGG", "ARRAY_AGG");
 
   public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) {
     super(operand(LogicalAggregate.class, any()), factory, null);
@@ -119,137 +113,104 @@ public class PinotAggregateExchangeNodeInsertRule 
extends RelOptRule {
    */
   @Override
   public void onMatch(RelOptRuleCall call) {
-    Aggregate oldAggRel = call.rel(0);
-    ImmutableList<RelHint> oldHints = oldAggRel.getHints();
-    // Both collation and distinct are not supported in leaf stage aggregation.
-    boolean hasCollation = hasCollation(oldAggRel);
-    boolean hasDistinct = hasDistinct(oldAggRel);
-    Aggregate newAgg;
-    if (!oldAggRel.getGroupSet().isEmpty() && 
PinotHintStrategyTable.isHintOptionTrue(oldHints,
-        PinotHintOptions.AGGREGATE_HINT_OPTIONS, 
PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
-      // 
------------------------------------------------------------------------
-      // If the "is_partitioned_by_group_by_keys" aggregate hint option is 
set, just add additional hints indicating
-      // this is a single stage aggregation.
-      List<RelHint> newHints = 
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(),
-          PinotHintOptions.INTERNAL_AGG_OPTIONS, 
PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
-          AggType.DIRECT.name());
-      newAgg =
-          new LogicalAggregate(oldAggRel.getCluster(), 
oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(),
-              oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), 
oldAggRel.getAggCallList());
-    } else if (hasCollation || hasDistinct || 
(!oldAggRel.getGroupSet().isEmpty()
-        && PinotHintStrategyTable.isHintOptionTrue(oldHints, 
PinotHintOptions.AGGREGATE_HINT_OPTIONS,
+    Aggregate argRel = call.rel(0);
+    ImmutableList<RelHint> hints = argRel.getHints();
+    // Collation is not supported in leaf stage aggregation.
+    RelCollation collation = extractWithInGroupCollation(argRel);
+    boolean hasGroupBy = !argRel.getGroupSet().isEmpty();
+    if (collation != null || (hasGroupBy && 
PinotHintStrategyTable.isHintOptionTrue(hints,
+        PinotHintOptions.AGGREGATE_HINT_OPTIONS,
         
PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION))) {
-      // 
------------------------------------------------------------------------
-      // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf 
stage aggregation is skipped.
-      newAgg = (Aggregate) createPlanWithExchangeDirectAggregation(call);
+      call.transformTo(createPlanWithExchangeDirectAggregation(call, 
collation));
+    } else if (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints, 
PinotHintOptions.AGGREGATE_HINT_OPTIONS,
+        PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
+      call.transformTo(createPlanWithDirectAggregation(call));
     } else {
-      // 
------------------------------------------------------------------------
-      newAgg = (Aggregate) createPlanWithLeafExchangeFinalAggregate(call);
+      call.transformTo(createPlanWithLeafExchangeFinalAggregate(call));
     }
-    call.transformTo(newAgg);
   }
 
-  private boolean hasDistinct(Aggregate aggRel) {
+  // TODO: Currently it only handles one WITHIN GROUP collation across all 
AggregateCalls.
+  @Nullable
+  private static RelCollation extractWithInGroupCollation(Aggregate aggRel) {
     for (AggregateCall aggCall : aggRel.getAggCallList()) {
-      // If the aggregation function is a list aggregation function and it is 
distinct, we can skip leaf stage.
-      // For COUNT(DISTINCT), there could be more leaf stage optimization.
-      if (aggCall.isDistinct() && 
LIST_AGG_FUNCTION_NAMES.contains(aggCall.getAggregation().getName().toUpperCase()))
 {
-        return true;
+      List<RelFieldCollation> fieldCollations = 
aggCall.getCollation().getFieldCollations();
+      if (!fieldCollations.isEmpty()) {
+        return RelCollations.of(fieldCollations);
       }
     }
-    return false;
+    return null;
   }
 
-  private boolean hasCollation(Aggregate aggRel) {
-    for (AggregateCall aggCall : aggRel.getAggCallList()) {
-      if (!aggCall.getCollation().getKeys().isEmpty()) {
-        return true;
-      }
-    }
-    return false;
+  private static RelNode createPlanWithDirectAggregation(RelOptRuleCall call) {
+    Aggregate aggRel = call.rel(0);
+    List<RelHint> newHints =
+        PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), 
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+            PinotHintOptions.InternalAggregateOptions.AGG_TYPE, 
AggType.DIRECT.name());
+    return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), 
newHints, aggRel.getInput(),
+        aggRel.getGroupSet(), aggRel.getGroupSets(), buildAggCalls(aggRel, 
AggType.DIRECT));
   }
 
   /**
    * Aggregate node will be split into LEAF + exchange + FINAL.
-   * optionally we can insert INTERMEDIATE to reduce hotspot in the future.
+   * TODO: Add optional INTERMEDIATE stage to reduce hotspot.
    */
-  private RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall 
call) {
-    // TODO: add optional intermediate stage here when hinted.
-    Aggregate oldAggRel = call.rel(0);
-    // 1. attach leaf agg RelHint to original agg. Perform any aggregation 
call conversions necessary
-    Aggregate leafAgg = convertAggForLeafInput(oldAggRel);
-    // 2. attach exchange.
-    List<Integer> groupSetIndices = ImmutableIntList.range(0, 
oldAggRel.getGroupCount());
-    PinotLogicalExchange exchange;
-    if (groupSetIndices.size() == 0) {
-      exchange = PinotLogicalExchange.create(leafAgg, 
RelDistributions.hash(Collections.emptyList()));
-    } else {
-      exchange = PinotLogicalExchange.create(leafAgg, 
RelDistributions.hash(groupSetIndices));
-    }
-    // 3. attach final agg stage.
-    return convertAggFromIntermediateInput(call, oldAggRel, exchange, 
AggType.FINAL);
+  private static RelNode 
createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) {
+    Aggregate aggRel = call.rel(0);
+    // Create a LEAF aggregate.
+    Aggregate leafAggRel = convertAggForLeafInput(aggRel);
+    // Create an exchange node over the LEAF aggregate.
+    PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel,
+        RelDistributions.hash(ImmutableIntList.range(0, 
aggRel.getGroupCount())));
+    // Create a FINAL aggregate over the exchange.
+    return convertAggFromIntermediateInput(call, exchange, AggType.FINAL);
   }
 
   /**
    * Use this group by optimization to skip leaf stage aggregation when 
aggregating at leaf level is not desired.
    * Many situation could be wasted effort to do group-by on leaf, eg: when 
cardinality of group by column is very high.
    */
-  private RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call) 
{
-    Aggregate oldAggRel = call.rel(0);
-    List<RelHint> newHints = 
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(),
-        PinotHintOptions.INTERNAL_AGG_OPTIONS, 
PinotHintOptions.InternalAggregateOptions.AGG_TYPE,
-        AggType.DIRECT.name());
-
-    // Convert Aggregate WithGroup Collation into a Sort
-    RelCollation relCollation = extractWithInGroupCollation(oldAggRel);
+  private static RelNode 
createPlanWithExchangeDirectAggregation(RelOptRuleCall call,
+      @Nullable RelCollation collation) {
+    Aggregate aggRel = call.rel(0);
+    RelNode input = aggRel.getInput();
+    // Create Project when there's none below the aggregate.
+    if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) {
+      aggRel = (Aggregate) generateProjectUnderAggregate(call);
+      input = aggRel.getInput();
+    }
 
-    // create project when there's none below the aggregate to reduce exchange 
overhead
-    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
-    if (!(childRel instanceof Project)) {
-      return convertAggForExchangeDirectAggregate(call, newHints, 
relCollation);
+    ImmutableBitSet groupSet = aggRel.getGroupSet();
+    RelDistribution distribution = RelDistributions.hash(groupSet.asList());
+    RelNode exchange;
+    if (collation != null) {
+      // Insert a LogicalSort node between exchange and aggregate when 
collation exists.
+      exchange = PinotLogicalSortExchange.create(input, distribution, 
collation, false, true);
     } else {
-      // create normal exchange
-      List<Integer> groupSetIndices = new ArrayList<>();
-      oldAggRel.getGroupSet().forEach(groupSetIndices::add);
-      RelNode newAggChild;
-      if (relCollation != null) {
-        newAggChild =
-            (groupSetIndices.isEmpty()) ? 
PinotLogicalSortExchange.create(childRel, RelDistributions.SINGLETON,
-                relCollation, false, true)
-                : PinotLogicalSortExchange.create(childRel, 
RelDistributions.hash(groupSetIndices),
-                    relCollation, false, true);
-      } else {
-        newAggChild = PinotLogicalExchange.create(childRel, 
RelDistributions.hash(groupSetIndices));
-      }
-      return new LogicalAggregate(oldAggRel.getCluster(), 
oldAggRel.getTraitSet(), newHints, newAggChild,
-          oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), 
oldAggRel.getAggCallList());
+      exchange = PinotLogicalExchange.create(input, distribution);
     }
-  }
 
-  // Extract the first collation in the AggregateCall list
-  @Nullable
-  private RelCollation extractWithInGroupCollation(Aggregate aggRel) {
-    for (AggregateCall aggCall : aggRel.getAggCallList()) {
-      List<RelFieldCollation> fieldCollations = 
aggCall.getCollation().getFieldCollations();
-      if (!fieldCollations.isEmpty()) {
-        return RelCollations.of(fieldCollations);
-      }
-    }
-    return null;
+    List<RelHint> newHints =
+        PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), 
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+            PinotHintOptions.InternalAggregateOptions.AGG_TYPE, 
AggType.DIRECT.name());
+    return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), 
newHints, exchange, groupSet,
+        aggRel.getGroupSets(), buildAggCalls(aggRel, AggType.DIRECT));
   }
 
   /**
-   * The following is copied from {@link 
AggregateExtractProjectRule#onMatch(RelOptRuleCall)}
-   * with modification to insert an exchange in between the Aggregate and 
Project
+   * The following is copied from {@link 
AggregateExtractProjectRule#onMatch(RelOptRuleCall)} with modification to take
+   * aggregate input as input.
    */
-  private RelNode convertAggForExchangeDirectAggregate(RelOptRuleCall call, 
List<RelHint> newHints,
-      @Nullable RelCollation collation) {
+  private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) {
     final Aggregate aggregate = call.rel(0);
+    // --------------- MODIFIED ---------------
     final RelNode input = aggregate.getInput();
+    // final RelNode input = call.rel(1);
+    // ------------- END MODIFIED -------------
+
     // Compute which input fields are used.
     // 1. group fields are always used
-    final ImmutableBitSet.Builder inputFieldsUsed =
-        aggregate.getGroupSet().rebuild();
+    final ImmutableBitSet.Builder inputFieldsUsed = 
aggregate.getGroupSet().rebuild();
     // 2. agg functions
     for (AggregateCall aggCall : aggregate.getAggCallList()) {
       for (int i : aggCall.getArgList()) {
@@ -259,149 +220,164 @@ public class PinotAggregateExchangeNodeInsertRule 
extends RelOptRule {
         inputFieldsUsed.set(aggCall.filterArg);
       }
     }
-    final RelBuilder relBuilder1 = call.builder().push(input);
+    final RelBuilder relBuilder = call.builder().push(input);
     final List<RexNode> projects = new ArrayList<>();
     final Mapping mapping =
-        Mappings.create(MappingType.INVERSE_SURJECTION,
-            aggregate.getInput().getRowType().getFieldCount(),
+        Mappings.create(MappingType.INVERSE_SURJECTION, 
aggregate.getInput().getRowType().getFieldCount(),
             inputFieldsUsed.cardinality());
     int j = 0;
     for (int i : inputFieldsUsed.build()) {
-      projects.add(relBuilder1.field(i));
+      projects.add(relBuilder.field(i));
       mapping.set(i, j++);
     }
-    relBuilder1.project(projects);
-    final ImmutableBitSet newGroupSet =
-        Mappings.apply(mapping, aggregate.getGroupSet());
-    Project project = (Project) relBuilder1.build();
 
-    // ------------------------------------------------------------------------
-    RelNode newAggChild;
-    if (collation != null) {
-      // Insert a LogicalSort node between the exchange and the aggregate
-      newAggChild = newGroupSet.isEmpty() ? 
PinotLogicalSortExchange.create(project, RelDistributions.SINGLETON,
-          collation, false, true)
-          : PinotLogicalSortExchange.create(project, 
RelDistributions.hash(newGroupSet.asList()),
-              collation, false, true);
-    } else {
-      newAggChild = PinotLogicalExchange.create(project, 
RelDistributions.hash(newGroupSet.asList()));
-    }
-    // ------------------------------------------------------------------------
+    relBuilder.project(projects);
 
-    final RelBuilder relBuilder2 = call.builder().push(newAggChild);
+    final ImmutableBitSet newGroupSet = Mappings.apply(mapping, 
aggregate.getGroupSet());
     final List<ImmutableBitSet> newGroupSets =
-        aggregate.getGroupSets().stream()
-            .map(bitSet -> Mappings.apply(mapping, bitSet))
-            .collect(Util.toImmutableList());
+        aggregate.getGroupSets().stream().map(bitSet -> 
Mappings.apply(mapping, bitSet))
+            .collect(ImmutableList.toImmutableList());
     final List<RelBuilder.AggCall> newAggCallList =
-        aggregate.getAggCallList().stream()
-            .map(aggCall -> relBuilder2.aggregateCall(aggCall, mapping))
-            .collect(Util.toImmutableList());
-    final RelBuilder.GroupKey groupKey =
-        relBuilder2.groupKey(newGroupSet, newGroupSets);
-    relBuilder2.aggregate(groupKey, newAggCallList).hints(newHints);
-    return relBuilder2.build();
+        aggregate.getAggCallList().stream().map(aggCall -> 
relBuilder.aggregateCall(aggCall, mapping))
+            .collect(ImmutableList.toImmutableList());
+
+    final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, 
newGroupSets);
+    relBuilder.aggregate(groupKey, newAggCallList);
+    return relBuilder.build();
   }
 
-  private Aggregate convertAggForLeafInput(Aggregate oldAggRel) {
-    List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
-    List<AggregateCall> newCalls = new ArrayList<>();
-    for (AggregateCall oldCall : oldCalls) {
-      newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall, 
oldCall.getArgList(), oldAggRel.getGroupCount(),
-          AggType.LEAF));
-    }
-    List<RelHint> newHints = 
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(),
-        PinotHintOptions.INTERNAL_AGG_OPTIONS, 
PinotHintOptions.InternalAggregateOptions.AGG_TYPE, AggType.LEAF.name());
-    return new LogicalAggregate(oldAggRel.getCluster(), 
oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(),
-        oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
+  private static Aggregate convertAggForLeafInput(Aggregate aggRel) {
+    List<RelHint> newHints =
+        PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), 
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+            PinotHintOptions.InternalAggregateOptions.AGG_TYPE, 
AggType.LEAF.name());
+    return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), 
newHints, aggRel.getInput(),
+        aggRel.getGroupSet(), aggRel.getGroupSets(), buildAggCalls(aggRel, 
AggType.LEAF));
   }
 
-  private RelNode convertAggFromIntermediateInput(RelOptRuleCall ruleCall, 
Aggregate oldAggRel,
-      PinotLogicalExchange exchange, AggType aggType) {
-    // add the exchange as the input node to the relation builder.
-    RelBuilder relBuilder = ruleCall.builder();
-    relBuilder.push(exchange);
+  private static RelNode convertAggFromIntermediateInput(RelOptRuleCall call, 
PinotLogicalExchange exchange,
+      AggType aggType) {
+    Aggregate aggRel = call.rel(0);
+    RelNode input = PinotRuleUtils.unboxRel(aggRel.getInput());
+    List<RexNode> projects = (input instanceof Project) ? ((Project) 
input).getProjects() : null;
 
-    // make input ref to the exchange after the leaf aggregate, all groups 
should be at the front
     RexBuilder rexBuilder = exchange.getCluster().getRexBuilder();
-    final int nGroups = oldAggRel.getGroupCount();
-    for (int i = 0; i < nGroups; i++) {
-      rexBuilder.makeInputRef(oldAggRel, i);
-    }
-
-    List<AggregateCall> newCalls = new ArrayList<>();
+    int groupCount = aggRel.getGroupCount();
+    List<AggregateCall> orgAggCalls = aggRel.getAggCallList();
+    int numAggCalls = orgAggCalls.size();
+    List<AggregateCall> aggCalls = new ArrayList<>(numAggCalls);
     Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
 
-    // create new aggregate function calls from exchange input, all aggCalls 
are followed one by one from exchange
-    // b/c the exchange produces intermediate results, thus the input to the 
newCall will be indexed at
-    // [nGroup + oldCallIndex]
-    List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
-    for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) 
{
-      AggregateCall oldCall = oldCalls.get(oldCallIndex);
-      // intermediate stage input only supports single argument inputs.
-      List<Integer> argList = Collections.singletonList(nGroups + 
oldCallIndex);
-      AggregateCall newCall = buildAggregateCall(exchange, oldCall, argList, 
nGroups, aggType);
-      rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, 
oldAggRel.getInput()::fieldIsNullable);
+    // Create new AggregateCalls from exchange input. Exchange produces 
results with group keys followed by intermediate
+    // aggregate results.
+    for (int i = 0; i < numAggCalls; i++) {
+      AggregateCall orgAggCall = orgAggCalls.get(i);
+      List<Integer> argList = orgAggCall.getArgList();
+      int index = groupCount + i;
+      RexInputRef inputRef = RexInputRef.of(index, aggRel.getRowType());
+      // Generate rexList from argList and replace literal reference with 
literal. Keep the first argument as is.
+      int numArguments = argList.size();
+      List<RexNode> rexList;
+      if (numArguments <= 1) {
+        rexList = ImmutableList.of(inputRef);
+      } else {
+        rexList = new ArrayList<>(numArguments);
+        rexList.add(inputRef);
+        for (int j = 1; j < numArguments; j++) {
+          int argument = argList.get(j);
+          if (projects != null && projects.get(argument) instanceof 
RexLiteral) {
+            rexList.add(projects.get(argument));
+          } else {
+            // Replace all the input references in the rexList with the new input 
reference.
+            rexList.add(inputRef);
+          }
+        }
+      }
+      AggregateCall newAggregate = buildAggCall(exchange, orgAggCall, rexList, 
groupCount, aggType);
+      rexBuilder.addAggCall(newAggregate, groupCount, aggCalls, 
aggCallMapping, aggRel.getInput()::fieldIsNullable);
     }
 
-    // create new aggregate relation.
-    ImmutableList<RelHint> orgHints = oldAggRel.getHints();
-    List<RelHint> newAggHint = 
PinotHintStrategyTable.replaceHintOptions(orgHints,
-        PinotHintOptions.INTERNAL_AGG_OPTIONS, 
PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name());
-    ImmutableBitSet groupSet = ImmutableBitSet.range(nGroups);
-    relBuilder.aggregate(relBuilder.groupKey(groupSet, 
ImmutableList.of(groupSet)), newCalls);
-    relBuilder.hints(newAggHint);
+    RelBuilder relBuilder = call.builder();
+    relBuilder.push(exchange);
+    ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount);
+    relBuilder.aggregate(relBuilder.groupKey(groupSet, 
ImmutableList.of(groupSet)), aggCalls);
+    List<RelHint> newHints =
+        PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), 
PinotHintOptions.INTERNAL_AGG_OPTIONS,
+            PinotHintOptions.InternalAggregateOptions.AGG_TYPE, 
aggType.name());
+    relBuilder.hints(newHints);
     return relBuilder.build();
   }
 
-  private static AggregateCall buildAggregateCall(RelNode inputNode, 
AggregateCall orgAggCall, List<Integer> argList,
-      int numberGroups, AggType aggType) {
-    final SqlAggFunction oldAggFunction = orgAggCall.getAggregation();
-    final SqlKind aggKind = oldAggFunction.getKind();
-    String functionName = getFunctionNameFromAggregateCall(orgAggCall);
-    AggregationFunctionType functionType = 
AggregationFunctionType.getAggregationFunctionType(functionName);
-    // create the aggFunction
-    SqlAggFunction sqlAggFunction;
-    if (functionType.getIntermediateReturnTypeInference() != null) {
-      switch (aggType) {
-        case LEAF:
-          sqlAggFunction = new 
PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
-              functionType.getSqlKind(), 
functionType.getIntermediateReturnTypeInference(), null,
-              functionType.getOperandTypeChecker(), 
functionType.getSqlFunctionCategory());
-          break;
-        case INTERMEDIATE:
-          sqlAggFunction = new 
PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
-              functionType.getSqlKind(), 
functionType.getIntermediateReturnTypeInference(), null,
-              OperandTypes.ANY, functionType.getSqlFunctionCategory());
-          break;
-        case FINAL:
-          sqlAggFunction = new 
PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
-              functionType.getSqlKind(), 
ReturnTypes.explicit(orgAggCall.getType()), null,
-              OperandTypes.ANY, functionType.getSqlFunctionCategory());
-          break;
-        default:
-          throw new UnsupportedOperationException("Unsuppoted aggType: " + 
aggType + " for " + functionName);
+  private static List<AggregateCall> buildAggCalls(Aggregate aggRel, AggType 
aggType) {
+    RelNode input = PinotRuleUtils.unboxRel(aggRel.getInput());
+    List<RexNode> projects = (input instanceof Project) ? ((Project) 
input).getProjects() : null;
+    List<AggregateCall> orgAggCalls = aggRel.getAggCallList();
+    List<AggregateCall> aggCalls = new ArrayList<>(orgAggCalls.size());
+    for (AggregateCall orgAggCall : orgAggCalls) {
+      // Generate rexList from argList and replace literal reference with 
literal. Keep the first argument as is.
+      List<Integer> argList = orgAggCall.getArgList();
+      int numArguments = argList.size();
+      List<RexNode> rexList;
+      if (numArguments == 0) {
+        rexList = ImmutableList.of();
+      } else if (numArguments == 1) {
+        rexList = ImmutableList.of(RexInputRef.of(argList.get(0), 
input.getRowType()));
+      } else {
+        rexList = new ArrayList<>(numArguments);
+        rexList.add(RexInputRef.of(argList.get(0), input.getRowType()));
+        for (int i = 1; i < numArguments; i++) {
+          int argument = argList.get(i);
+          if (projects != null && projects.get(argument) instanceof 
RexLiteral) {
+            rexList.add(projects.get(argument));
+          } else {
+            rexList.add(RexInputRef.of(argument, input.getRowType()));
+          }
+        }
       }
-    } else {
-      sqlAggFunction = oldAggFunction;
+      aggCalls.add(buildAggCall(input, orgAggCall, rexList, 
aggRel.getGroupCount(), aggType));
     }
-
-    return AggregateCall.create(sqlAggFunction,
-        functionName.equals("distinctCount") || orgAggCall.isDistinct(),
-        orgAggCall.isApproximate(),
-        orgAggCall.ignoreNulls(),
-        argList,
-        aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg,
-        orgAggCall.distinctKeys,
-        orgAggCall.collation,
-        numberGroups,
-        inputNode,
-        null,
-        null);
+    return aggCalls;
   }
 
-  private static String getFunctionNameFromAggregateCall(AggregateCall 
aggregateCall) {
-    return aggregateCall.getAggregation().getName().equalsIgnoreCase("COUNT") 
&& aggregateCall.isDistinct()
-        ? "distinctCount" : aggregateCall.getAggregation().getName();
+  // TODO: Revisit the following logic:
+  //   - DISTINCT is resolved here
+  //   - argList is replaced with rexList
+  private static AggregateCall buildAggCall(RelNode input, AggregateCall 
orgAggCall, List<RexNode> rexList,
+      int numGroups, AggType aggType) {
+    String functionName = orgAggCall.getAggregation().getName();
+    if (orgAggCall.isDistinct()) {
+      if (functionName.equals("COUNT")) {
+        functionName = "DISTINCTCOUNT";
+      } else if (functionName.equals("LISTAGG")) {
+        rexList.add(input.getCluster().getRexBuilder().makeLiteral(true));
+      }
+    }
+    AggregationFunctionType functionType = 
AggregationFunctionType.getAggregationFunctionType(functionName);
+    SqlAggFunction sqlAggFunction;
+    switch (aggType) {
+      case DIRECT:
+        sqlAggFunction = new PinotSqlAggFunction(functionName, null, 
functionType.getSqlKind(),
+            ReturnTypes.explicit(orgAggCall.getType()), null, 
functionType.getOperandTypeChecker(),
+            functionType.getSqlFunctionCategory());
+        break;
+      case LEAF:
+        sqlAggFunction = new PinotSqlAggFunction(functionName, null, 
functionType.getSqlKind(),
+            functionType.getIntermediateReturnTypeInference(), null, 
functionType.getOperandTypeChecker(),
+            functionType.getSqlFunctionCategory());
+        break;
+      case INTERMEDIATE:
+        sqlAggFunction = new PinotSqlAggFunction(functionName, null, 
functionType.getSqlKind(),
+            functionType.getIntermediateReturnTypeInference(), null, 
OperandTypes.ANY,
+            functionType.getSqlFunctionCategory());
+        break;
+      case FINAL:
+        sqlAggFunction = new PinotSqlAggFunction(functionName, null, 
functionType.getSqlKind(),
+            ReturnTypes.explicit(orgAggCall.getType()), null, 
OperandTypes.ANY, functionType.getSqlFunctionCategory());
+        break;
+      default:
+        throw new IllegalStateException("Unsupported AggType: " + aggType);
+    }
+    return AggregateCall.create(sqlAggFunction, false, 
orgAggCall.isApproximate(), orgAggCall.ignoreNulls(), rexList,
+        ImmutableList.of(), aggType.isInputIntermediateFormat() ? -1 : 
orgAggCall.filterArg, orgAggCall.distinctKeys,
+        orgAggCall.collation, numGroups, input, null, null);
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java
deleted file mode 100644
index 74af35b47a..0000000000
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.calcite.rel.rules;
-
-import com.google.common.collect.ImmutableList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.Project;
-import org.apache.calcite.rel.hint.RelHint;
-import org.apache.calcite.rel.logical.LogicalAggregate;
-import org.apache.calcite.rex.RexLiteral;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.tools.RelBuilderFactory;
-import org.apache.calcite.util.Pair;
-import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
-import org.apache.pinot.calcite.rel.hint.PinotHintStrategyTable;
-import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
-import org.apache.pinot.query.planner.logical.LiteralHintUtils;
-import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.planner.logical.RexExpressionUtils;
-
-
-/**
- * Special rule to attach Literal to Aggregate call.
- */
-public class PinotAggregateLiteralAttachmentRule extends RelOptRule {
-  public static final PinotAggregateLiteralAttachmentRule INSTANCE =
-      new 
PinotAggregateLiteralAttachmentRule(PinotRuleUtils.PINOT_REL_FACTORY);
-
-  public PinotAggregateLiteralAttachmentRule(RelBuilderFactory factory) {
-    super(operand(LogicalAggregate.class, any()), factory, null);
-  }
-
-  @Override
-  public boolean matches(RelOptRuleCall call) {
-    if (call.rels.length < 1) {
-      return false;
-    }
-    if (call.rel(0) instanceof Aggregate) {
-      Aggregate agg = call.rel(0);
-      ImmutableList<RelHint> hints = agg.getHints();
-      return !PinotHintStrategyTable.containsHintOption(hints,
-          PinotHintOptions.INTERNAL_AGG_OPTIONS, 
PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE);
-    }
-    return false;
-  }
-
-  @Override
-  public void onMatch(RelOptRuleCall call) {
-    Aggregate aggregate = call.rel(0);
-    Map<Pair<Integer, Integer>, RexExpression.Literal> rexLiterals = 
extractLiterals(call);
-    List<RelHint> newHints = 
PinotHintStrategyTable.replaceHintOptions(aggregate.getHints(),
-        PinotHintOptions.INTERNAL_AGG_OPTIONS, 
PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE,
-        LiteralHintUtils.literalMapToHintString(rexLiterals));
-    // TODO: validate against AggregationFunctionType to see if expected 
literal positions are properly attached
-    call.transformTo(new LogicalAggregate(aggregate.getCluster(), 
aggregate.getTraitSet(), newHints,
-        aggregate.getInput(), aggregate.getGroupSet(), 
aggregate.getGroupSets(), aggregate.getAggCallList()));
-  }
-
-  private static Map<Pair<Integer, Integer>, RexExpression.Literal> 
extractLiterals(RelOptRuleCall call) {
-    Aggregate aggregate = call.rel(0);
-    RelNode input = PinotRuleUtils.unboxRel(aggregate.getInput());
-    List<RexNode> rexNodes = (input instanceof Project) ? ((Project) 
input).getProjects() : null;
-    List<AggregateCall> aggCallList = aggregate.getAggCallList();
-    final Map<Pair<Integer, Integer>, RexExpression.Literal> rexLiteralMap = 
new HashMap<>();
-    for (int aggIdx = 0; aggIdx < aggCallList.size(); aggIdx++) {
-      AggregateCall aggCall = aggCallList.get(aggIdx);
-      int argSize = aggCall.getArgList().size();
-      if (argSize > 1) {
-        // use -1 argIdx to indicate size of the agg operands.
-        rexLiteralMap.put(new Pair<>(aggIdx, -1), new 
RexExpression.Literal(ColumnDataType.INT, argSize));
-        // put the literals in to the map.
-        for (int argIdx = 0; argIdx < argSize; argIdx++) {
-          if (rexNodes != null) {
-            RexNode field = rexNodes.get(aggCall.getArgList().get(argIdx));
-            if (field instanceof RexLiteral) {
-              rexLiteralMap.put(new Pair<>(aggIdx, argIdx), 
RexExpressionUtils.fromRexLiteral((RexLiteral) field));
-            }
-          }
-        }
-      }
-    }
-    return rexLiteralMap;
-  }
-}
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
index cbac4de9e3..6c2498c70b 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
@@ -117,11 +117,6 @@ public class PinotQueryRuleSets {
       PruneEmptyRules.UNION_INSTANCE
   );
 
-  // Pinot specific rules to run using a single RuleCollection since we attach 
aggregate info after optimizer.
-  public static final Collection<RelOptRule> PINOT_AGG_PROCESS_RULES = 
ImmutableList.of(
-      PinotAggregateLiteralAttachmentRule.INSTANCE
-  );
-
   // Pinot specific rules that should be run AFTER all other rules
   public static final Collection<RelOptRule> PINOT_POST_RULES = 
ImmutableList.of(
       // Evaluate the Literal filter nodes
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 059faac2d4..9c53cdee6a 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -328,10 +328,6 @@ public class QueryEnvironment {
       hepProgramBuilder.addRuleInstance(relOptRule);
     }
 
-    // ----
-    // Run Pinot rule to attach aggregation auxiliary info
-    
hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.PINOT_AGG_PROCESS_RULES);
-
     // ----
     // Pushdown filters using a single HepInstruction.
     
hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.FILTER_PUSHDOWN_RULES);
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index debe59d0ab..1862adf95e 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -231,7 +231,7 @@ public class CalciteRexExpressionParser {
         }
         break;
       default:
-        functionName = functionKind.name();
+        functionName = canonicalizeFunctionName(functionKind.name());
         break;
     }
     List<RexExpression> childNodes = rexCall.getFunctionOperands();
@@ -288,7 +288,7 @@ public class CalciteRexExpressionParser {
 
   private static Expression getFunctionExpression(String canonicalName) {
     Expression expression = new Expression(ExpressionType.FUNCTION);
-    Function function = new Function(canonicalizeFunctionName(canonicalName));
+    Function function = new Function(canonicalName);
     expression.setFunctionCall(function);
     return expression;
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java
deleted file mode 100644
index ea854e9aba..0000000000
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.planner.logical;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.calcite.util.Pair;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.pinot.common.request.Literal;
-import org.apache.pinot.spi.data.FieldSpec;
-import org.apache.pinot.spi.utils.BytesUtils;
-
-
-public class LiteralHintUtils {
-  private LiteralHintUtils() {
-  }
-
-  public static String literalMapToHintString(Map<Pair<Integer, Integer>, 
RexExpression.Literal> literals) {
-    List<String> literalStrings = new ArrayList<>(literals.size());
-    for (Map.Entry<Pair<Integer, Integer>, RexExpression.Literal> e : 
literals.entrySet()) {
-      // individual literal parts are joined with `|`
-      literalStrings.add(
-          String.format("%d|%d|%s|%s", e.getKey().left, e.getKey().right, 
e.getValue().getDataType().name(),
-              e.getValue().getValue()));
-    }
-    // semi-colon is used to separate between encoded literals
-    return "{" + StringUtils.join(literalStrings, ";:;") + "}";
-  }
-
-  public static Map<Integer, Map<Integer, Literal>> 
hintStringToLiteralMap(String literalString) {
-    Map<Integer, Map<Integer, Literal>> aggCallToLiteralArgsMap = new 
HashMap<>();
-    if (StringUtils.isNotEmpty(literalString) && !"{}".equals(literalString)) {
-      String[] literalStringArr = literalString.substring(1, 
literalString.length() - 1).split(";:;");
-      for (String literalStr : literalStringArr) {
-        String[] literalStrParts = literalStr.split("\\|", 4);
-        int aggIdx = Integer.parseInt(literalStrParts[0]);
-        int argListIdx = Integer.parseInt(literalStrParts[1]);
-        String dataTypeNameStr = literalStrParts[2];
-        String valueStr = literalStrParts[3];
-        Map<Integer, Literal> literalArgs = 
aggCallToLiteralArgsMap.computeIfAbsent(aggIdx, i -> new HashMap<>());
-        literalArgs.put(argListIdx, stringToLiteral(dataTypeNameStr, 
valueStr));
-      }
-    }
-    return aggCallToLiteralArgsMap;
-  }
-
-  private static Literal stringToLiteral(String dataTypeStr, String valueStr) {
-    FieldSpec.DataType dataType = FieldSpec.DataType.valueOf(dataTypeStr);
-    switch (dataType) {
-      case BOOLEAN:
-        return Literal.boolValue(valueStr.equals("1"));
-      case INT:
-        return Literal.intValue(Integer.parseInt(valueStr));
-      case LONG:
-        return Literal.longValue(Long.parseLong(valueStr));
-      case FLOAT:
-      case DOUBLE:
-        return Literal.doubleValue(Double.parseDouble(valueStr));
-      case STRING:
-        return Literal.stringValue(valueStr);
-      case BYTES:
-        return Literal.binaryValue(BytesUtils.toBytes(valueStr));
-      default:
-        throw new UnsupportedOperationException("Unsupported RexLiteral type: 
" + dataTypeStr);
-    }
-  }
-}
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
index 5a80cd2596..c2e9890358 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
@@ -246,8 +246,10 @@ public class RexExpressionUtils {
   }
 
   public static RexExpression fromAggregateCall(AggregateCall aggregateCall) {
-    List<RexExpression> operands =
-        
aggregateCall.getArgList().stream().map(RexExpression.InputRef::new).collect(Collectors.toList());
+    List<RexExpression> operands = new 
ArrayList<>(aggregateCall.rexList.size());
+    for (RexNode rexNode : aggregateCall.rexList) {
+      operands.add(fromRexNode(rexNode));
+    }
     return new 
RexExpression.FunctionCall(aggregateCall.getAggregation().getKind(),
         
RelToPlanNodeConverter.convertToColumnDataType(aggregateCall.getType()),
         aggregateCall.getAggregation().getName(), operands, 
aggregateCall.isDistinct());
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 810202ca49..8e74660e7a 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -255,7 +255,8 @@ public class QueryCompilationTest extends 
QueryEnvironmentTestBase {
   public void testQueryWithHint() {
     // Hinting the query to use final stage aggregation makes server directly 
return final result
     // This is useful when data is already partitioned by col1
-    String query = "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ col1, 
COUNT(*) FROM b GROUP BY col1";
+    String query =
+        "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ 
col1, COUNT(*) FROM b GROUP BY col1";
     DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(query);
     List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
     int numStages = stagePlans.size();
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json 
b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index a7a4b1a8be..8a0878d6e1 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -102,7 +102,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a 
GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n  PinotLogicalExchange(distribution=[hash[0]])",
           "\n    LogicalProject(col1=[$0], col3=[$2])",
           "\n      LogicalTableScan(table=[[default, a]])",
@@ -128,7 +128,7 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE 
NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
-          "\n  LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)])",
           "\n    PinotLogicalExchange(distribution=[hash[0]])",
           "\n      LogicalProject(col1=[$0], col3=[$2])",
           "\n        LogicalTableScan(table=[[default, a]])",
@@ -140,7 +140,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a 
WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n  PinotLogicalExchange(distribution=[hash[0]])",
           "\n    LogicalProject(col1=[$0], col3=[$2])",
           "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
@@ -153,7 +153,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), 
MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
EXPR$2=[MAX($1)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
agg#1=[MAX($1)])",
           "\n  PinotLogicalExchange(distribution=[hash[0]])",
           "\n    LogicalProject(col1=[$0], col3=[$2])",
           "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
@@ -167,7 +167,7 @@
         "notes": "TODO: Needs follow up. Project should only keep a.col1 since 
the other columns are pushed to the filter, but it currently keeps them all",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[COUNT()])",
           "\n  PinotLogicalExchange(distribution=[hash[0]])",
           "\n    LogicalProject(col1=[$0])",
           "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
@@ -181,7 +181,7 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
-          "\n  LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
           "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n        LogicalFilter(condition=[AND(>=($2, 0), =($0, 
_UTF-8'a'))])",
@@ -196,7 +196,7 @@
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
           "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT()], 
agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
           "\n        LogicalProject(col1=[$0], col3=[$2])",
           "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
@@ -211,7 +211,7 @@
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1])",
           "\n  LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), 
=(/(CAST($1):DOUBLE NOT NULL, $4), 5))])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
           "\n        LogicalProject(col1=[$0], col3=[$2])",
           "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
@@ -226,7 +226,7 @@
           "Execution Plan",
           "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
           "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
-          "\n    LogicalAggregate(group=[{0}], count=[COUNT()], 
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT()], 
agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
           "\n        LogicalProject(col1=[$0], col3=[$2])",
           "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json 
b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
index 7b97f583ea..32d1eb65f8 100644
--- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
@@ -93,7 +93,7 @@
           "Execution Plan",
           "\nLogicalSort(sort0=[$0], dir0=[ASC])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], 
isSortOnSender=[false], isSortOnReceiver=[true])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
           "\n        LogicalProject(col1=[$0], col3=[$2])",
           "\n          LogicalTableScan(table=[[default, a]])",
@@ -121,7 +121,7 @@
           "Execution Plan",
           "\nLogicalSort(sort0=[$0], dir0=[ASC])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], 
isSortOnSender=[false], isSortOnReceiver=[true])",
-          "\n    LogicalAggregate(group=[{0}], sum=[$SUM0($1)])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
           "\n        LogicalProject(col1=[$0], col3=[$2])",
           "\n          LogicalTableScan(table=[[default, a]])",
diff --git 
a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json 
b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
index 5841c442ff..3f0a4cd0f0 100644
--- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
@@ -100,10 +100,10 @@
       },
       {
         "description": "semi-join with dynamic_broadcast join strategy then 
group-by on same key",
-        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptionsInternal(agg_type='DIRECT') */ a.col1, SUM(a.col3) FROM a WHERE 
a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col1, SUM(a.col3) FROM 
a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n  LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
           "\n    LogicalProject(col1=[$0], col3=[$2])",
           "\n      LogicalTableScan(table=[[default, a]])",
@@ -138,7 +138,7 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
-          "\n  LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
           "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n        LogicalFilter(condition=[AND(>=($2, 0), =($0, 
_UTF-8'a'))])",
@@ -153,7 +153,7 @@
           "Execution Plan",
           "\nLogicalProject(col2=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$3])",
           "\n  LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT()], 
agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
           "\n        LogicalProject(col2=[$1], col3=[$2], 
$f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
           "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
@@ -162,24 +162,11 @@
         ]
       },
       {
-        "description": "aggregate with skip intermediate stage hint (via 
hinting the leaf stage group by as final stage_",
-        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptionsInternal(agg_type='DIRECT') */ a.col2, COUNT(*), SUM(a.col3), 
SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col2 HAVING 
COUNT(*) > 10",
-        "output": [
-          "Execution Plan",
-          "\nLogicalFilter(condition=[>($1, 10)])",
-          "\n  LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)])",
-          "\n    LogicalProject(col2=[$1], col3=[$2], 
$f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
-          "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])",
-          "\n        LogicalTableScan(table=[[default, a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "aggregate with skip leaf stage hint (via hint option 
is_partitioned_by_group_by_keys",
+        "description": "aggregate with skip intermediate stage hint (via hint 
option is_partitioned_by_group_by_keys)",
         "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, COUNT(*), 
SUM(a.col3), SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY 
a.col2",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], 
agg#2=[$SUM0($2)])",
           "\n  LogicalProject(col2=[$1], col3=[$2], 
$f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
           "\n    LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
           "\n      LogicalTableScan(table=[[default, a]])",
@@ -409,7 +396,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM 
a /*+ tableOptions(partition_function='hashcode', partition_key='col2', 
partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ 
tableOptions(partition_function='hashcode', partition_key='col1', 
partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n  LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
           "\n    LogicalProject(col2=[$1], col3=[$2])",
           "\n      LogicalTableScan(table=[[default, a]])",
@@ -425,7 +412,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ 
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM 
a /*+ tableOptions(partition_function='hashcode', partition_key='col2', 
partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) 
GROUP BY 1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n  LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
           "\n    LogicalProject(col2=[$1], col3=[$2])",
           "\n      LogicalTableScan(table=[[default, a]])",
@@ -443,7 +430,7 @@
           "Execution Plan",
           "\nLogicalProject(col2=[$0], EXPR$1=[$1])",
           "\n  LogicalFilter(condition=[>($2, 5)])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[COUNT()])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
           "\n      LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
           "\n        LogicalProject(col2=[$1], col3=[$2])",
           "\n          LogicalTableScan(table=[[default, a]])",
@@ -461,7 +448,7 @@
           "Execution Plan",
           "\nLogicalSort(sort0=[$1], dir0=[DESC])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[1 
DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
           "\n      LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
           "\n        LogicalProject(col2=[$1], col3=[$2])",
           "\n          LogicalTableScan(table=[[default, a]])",
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 7cf7d5f2a7..a19ff64d4e 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -27,10 +27,8 @@ import java.util.List;
 import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.calcite.sql.SqlKind;
-import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.datatable.StatMap;
-import org.apache.pinot.common.request.Literal;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.utils.DataSchema;
@@ -43,13 +41,13 @@ import 
org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import 
org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
 import 
org.apache.pinot.core.query.aggregation.function.CountAggregationFunction;
 import org.apache.pinot.core.util.DataBlockExtractUtils;
-import org.apache.pinot.query.planner.logical.LiteralHintUtils;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
-import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.utils.BooleanUtils;
 import org.roaringbitmap.RoaringBitmap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -65,11 +63,9 @@ public class AggregateOperator extends MultiStageOperator {
   private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
   private static final CountAggregationFunction COUNT_STAR_AGG_FUNCTION =
       new 
CountAggregationFunction(Collections.singletonList(ExpressionContext.forIdentifier("*")),
 false);
-  private static final ExpressionContext PLACEHOLDER_IDENTIFIER = 
ExpressionContext.forIdentifier("__PLACEHOLDER__");
 
   private final MultiStageOperator _inputOperator;
   private final DataSchema _resultSchema;
-  private final AggType _aggType;
   private final MultistageAggregationExecutor _aggregationExecutor;
   private final MultistageGroupByExecutor _groupByExecutor;
   @Nullable
@@ -78,29 +74,15 @@ public class AggregateOperator extends MultiStageOperator {
 
   private boolean _hasConstructedAggregateBlock;
 
-  public AggregateOperator(OpChainExecutionContext context, MultiStageOperator 
inputOperator,
-      DataSchema resultSchema, List<RexExpression> aggCalls, 
List<RexExpression> groupSet, AggType aggType,
-      List<Integer> filterArgIndices, @Nullable AbstractPlanNode.NodeHint 
nodeHint) {
+  public AggregateOperator(OpChainExecutionContext context, MultiStageOperator 
inputOperator, DataSchema resultSchema,
+      List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType 
aggType, List<Integer> filterArgIndices,
+      @Nullable AbstractPlanNode.NodeHint nodeHint) {
     super(context);
     _inputOperator = inputOperator;
     _resultSchema = resultSchema;
-    _aggType = aggType;
-
-    // Process literal hints
-    Map<Integer, Map<Integer, Literal>> literalArgumentsMap = null;
-    if (nodeHint != null) {
-      Map<String, String> aggOptions = 
nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS);
-      if (aggOptions != null) {
-        literalArgumentsMap = LiteralHintUtils.hintStringToLiteralMap(
-            
aggOptions.get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE));
-      }
-    }
-    if (literalArgumentsMap == null) {
-      literalArgumentsMap = Collections.emptyMap();
-    }
 
     // Initialize the aggregation functions
-    AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(aggCalls, 
literalArgumentsMap);
+    AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(aggCalls);
 
     // Process the filter argument indices
     int numFunctions = aggFunctions.length;
@@ -214,27 +196,16 @@ public class AggregateOperator extends MultiStageOperator 
{
     return block;
   }
 
-  private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression> 
aggCalls,
-      Map<Integer, Map<Integer, Literal>> literalArgumentsMap) {
+  private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression> 
aggCalls) {
     int numFunctions = aggCalls.size();
     AggregationFunction<?, ?>[] aggFunctions = new 
AggregationFunction[numFunctions];
-    if (!_aggType.isInputIntermediateFormat()) {
-      for (int i = 0; i < numFunctions; i++) {
-        Map<Integer, Literal> literalArguments = 
literalArgumentsMap.getOrDefault(i, Collections.emptyMap());
-        aggFunctions[i] = 
getAggFunctionForRawInput((RexExpression.FunctionCall) aggCalls.get(i), 
literalArguments);
-      }
-    } else {
-      for (int i = 0; i < numFunctions; i++) {
-        Map<Integer, Literal> literalArguments = 
literalArgumentsMap.getOrDefault(i, Collections.emptyMap());
-        aggFunctions[i] =
-            getAggFunctionForIntermediateInput((RexExpression.FunctionCall) 
aggCalls.get(i), literalArguments);
-      }
+    for (int i = 0; i < numFunctions; i++) {
+      aggFunctions[i] = getAggFunction((RexExpression.FunctionCall) 
aggCalls.get(i));
     }
     return aggFunctions;
   }
 
-  private AggregationFunction<?, ?> 
getAggFunctionForRawInput(RexExpression.FunctionCall functionCall,
-      Map<Integer, Literal> literalArguments) {
+  private AggregationFunction<?, ?> getAggFunction(RexExpression.FunctionCall 
functionCall) {
     String functionName = functionCall.getFunctionName();
     List<RexExpression> operands = functionCall.getFunctionOperands();
     int numArguments = operands.size();
@@ -244,78 +215,26 @@ public class AggregateOperator extends MultiStageOperator 
{
       return COUNT_STAR_AGG_FUNCTION;
     }
     List<ExpressionContext> arguments = new ArrayList<>(numArguments);
-    for (int i = 0; i < numArguments; i++) {
-      Literal literalArgument = literalArguments.get(i);
-      if (literalArgument != null) {
-        arguments.add(ExpressionContext.forLiteralContext(literalArgument));
+    for (RexExpression operand : operands) {
+      if (operand instanceof RexExpression.InputRef) {
+        RexExpression.InputRef inputRef = (RexExpression.InputRef) operand;
+        
arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex())));
       } else {
-        RexExpression operand = operands.get(i);
-        switch (operand.getKind()) {
-          case INPUT_REF:
-            RexExpression.InputRef inputRef = (RexExpression.InputRef) operand;
-            
arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex())));
-            break;
-          case LITERAL:
-            RexExpression.Literal literalRexExp = (RexExpression.Literal) 
operand;
-            
arguments.add(ExpressionContext.forLiteralContext(literalRexExp.getDataType().toDataType(),
-                literalRexExp.getValue()));
-            break;
-          default:
-            throw new IllegalStateException("Illegal aggregation function 
operand type: " + operand.getKind());
+        assert operand instanceof RexExpression.Literal;
+        RexExpression.Literal literal = (RexExpression.Literal) operand;
+        DataType dataType = literal.getDataType().toDataType();
+        Object value = literal.getValue();
+        // TODO: Fix BOOLEAN literal to directly store true/false
+        if (dataType == DataType.BOOLEAN) {
+          value = BooleanUtils.fromNonNullInternalValue(value);
         }
+        arguments.add(ExpressionContext.forLiteralContext(dataType, value));
       }
     }
-    handleListAggDistinctArg(functionName, functionCall, arguments);
     return AggregationFunctionFactory.getAggregationFunction(
         new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, 
arguments), true);
   }
 
-  private static AggregationFunction<?, ?> 
getAggFunctionForIntermediateInput(RexExpression.FunctionCall functionCall,
-      Map<Integer, Literal> literalArguments) {
-    String functionName = functionCall.getFunctionName();
-    List<RexExpression> operands = functionCall.getFunctionOperands();
-    int numArguments = operands.size();
-    Preconditions.checkState(numArguments == 1, "Intermediate aggregate must 
have 1 argument, got: %s", numArguments);
-    RexExpression operand = operands.get(0);
-    Preconditions.checkState(operand.getKind() == SqlKind.INPUT_REF,
-        "Intermediate aggregate argument must be an input reference, got: %s", 
operand.getKind());
-    // We might need to append extra arguments extracted from the hint to 
match the signature of the aggregation
-    Literal numArgumentsLiteral = literalArguments.get(-1);
-    if (numArgumentsLiteral == null) {
-      return AggregationFunctionFactory.getAggregationFunction(
-          new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, 
Collections.singletonList(
-              
ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) 
operand).getIndex())))),
-          true);
-    } else {
-      int numExpectedArguments = numArgumentsLiteral.getIntValue();
-      List<ExpressionContext> arguments = new 
ArrayList<>(numExpectedArguments);
-      arguments.add(
-          
ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) 
operand).getIndex())));
-      for (int i = 1; i < numExpectedArguments; i++) {
-        Literal literalArgument = literalArguments.get(i);
-        if (literalArgument != null) {
-          arguments.add(ExpressionContext.forLiteralContext(literalArgument));
-        } else {
-          arguments.add(PLACEHOLDER_IDENTIFIER);
-        }
-      }
-      handleListAggDistinctArg(functionName, functionCall, arguments);
-      return AggregationFunctionFactory.getAggregationFunction(
-          new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, 
arguments), true);
-    }
-  }
-
-  private static void handleListAggDistinctArg(String functionName, 
RexExpression.FunctionCall functionCall,
-      List<ExpressionContext> arguments) {
-    String upperCaseFunctionName =
-        
AggregationFunctionType.getNormalizedAggregationFunctionName(functionName);
-    if (upperCaseFunctionName.equals("LISTAGG")) {
-      if (functionCall.isDistinct()) {
-        
arguments.add(ExpressionContext.forLiteralContext(Literal.boolValue(true)));
-      }
-    }
-  }
-
   private static String fromColIdToIdentifier(int colId) {
     return "$" + colId;
   }
diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json 
b/pinot-query-runtime/src/test/resources/queries/QueryHints.json
index f8c850fcd3..81a939c2e1 100644
--- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json
+++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json
@@ -275,7 +275,7 @@
       },
       {
         "description": "semi-join with dynamic_broadcast join strategy then 
group-by on same key",
-        "sql": "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ 
{tbl1}.num, SUM({tbl1}.val) FROM {tbl1} WHERE {tbl1}.name IN (SELECT id FROM 
{tbl2} WHERE {tbl2}.data > 0) GROUP BY {tbl1}.num"
+        "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') 
*/ {tbl1}.num, SUM({tbl1}.val) FROM {tbl1} WHERE {tbl1}.name IN (SELECT id FROM 
{tbl2} WHERE {tbl2}.data > 0) GROUP BY {tbl1}.num"
       },
       {
         "description": "semi-join with dynamic_broadcast join strategy then 
group-by on different key",
@@ -290,11 +290,7 @@
         "sql": "SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ 
{tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE 
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num HAVING COUNT(*) > 10 
AND MAX({tbl1}.val) >= 0 AND MIN({tbl1}.val) < 20 AND SUM({tbl1}.val) <= 10 AND 
AVG({tbl1}.val) = 5"
       },
       {
-        "description": "aggregate with skip intermediate stage hint (via 
hinting the leaf stage group by as final stage_",
-        "sql": "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ 
{tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE 
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num HAVING COUNT(*) > 10"
-      },
-      {
-        "description": "aggregate with skip leaf stage hint (via hint option 
is_partitioned_by_group_by_keys",
+        "description": "aggregate with skip intermediate stage hint (via hint 
option is_partitioned_by_group_by_keys)",
         "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') 
*/ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE 
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num"
       },
       {
diff --git 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index a6c468d8fe..877ac7f232 100644
--- 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -460,9 +460,10 @@ public enum AggregationFunctionType {
    * <p>NOTE: Underscores in the function name are ignored.
    */
   public static AggregationFunctionType getAggregationFunctionType(String 
functionName) {
-    if (functionName.regionMatches(true, 0, "percentile", 0, 10)) {
+    String normalizedFunctionName = 
getNormalizedAggregationFunctionName(functionName);
+    if (normalizedFunctionName.regionMatches(false, 0, "PERCENTILE", 0, 10)) {
       // This style of aggregation functions is not supported in the 
multistage engine
-      String remainingFunctionName = 
getNormalizedAggregationFunctionName(functionName).substring(10).toUpperCase();
+      String remainingFunctionName = 
normalizedFunctionName.substring(10).toUpperCase();
       if (remainingFunctionName.isEmpty() || 
remainingFunctionName.matches("\\d+")) {
         return PERCENTILE;
       } else if (remainingFunctionName.equals("EST") || 
remainingFunctionName.matches("EST\\d+")) {
@@ -496,7 +497,7 @@ public enum AggregationFunctionType {
       }
     } else {
       try {
-        return 
AggregationFunctionType.valueOf(getNormalizedAggregationFunctionName(functionName));
+        return AggregationFunctionType.valueOf(normalizedFunctionName);
       } catch (IllegalArgumentException e) {
         throw new IllegalArgumentException("Invalid aggregation function name: 
" + functionName);
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org

Reply via email to