This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new e2c5e73970 Pass literal within AggregateCall via rexList (#13282) e2c5e73970 is described below commit e2c5e73970b1e8f64df7c763c5bcac36ff19d2a6 Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com> AuthorDate: Fri May 31 18:00:00 2024 -0700 Pass literal within AggregateCall via rexList (#13282) --- .../pinot/calcite/rel/hint/PinotHintOptions.java | 13 - .../PinotAggregateExchangeNodeInsertRule.java | 422 ++++++++++----------- .../rules/PinotAggregateLiteralAttachmentRule.java | 107 ------ .../calcite/rel/rules/PinotQueryRuleSets.java | 5 - .../org/apache/pinot/query/QueryEnvironment.java | 4 - .../query/parser/CalciteRexExpressionParser.java | 4 +- .../query/planner/logical/LiteralHintUtils.java | 85 ----- .../query/planner/logical/RexExpressionUtils.java | 6 +- .../apache/pinot/query/QueryCompilationTest.java | 3 +- .../src/test/resources/queries/GroupByPlans.json | 18 +- .../src/test/resources/queries/OrderByPlans.json | 4 +- .../test/resources/queries/PinotHintablePlans.json | 33 +- .../query/runtime/operator/AggregateOperator.java | 125 ++---- .../src/test/resources/queries/QueryHints.json | 8 +- .../pinot/segment/spi/AggregationFunctionType.java | 7 +- 15 files changed, 256 insertions(+), 588 deletions(-) diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java index 1d53a3184e..99e07b61df 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java @@ -20,7 +20,6 @@ package org.apache.pinot.calcite.rel.hint; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.hint.RelHint; -import org.apache.pinot.query.planner.logical.LiteralHintUtils; /** @@ -47,18 +46,6 @@ public class 
PinotHintOptions { public static class InternalAggregateOptions { public static final String AGG_TYPE = "agg_type"; - /** - * agg call signature is used to store LITERAL inputs to the Aggregate Call. which is not supported in Calcite - * here - * 1. we store the Map of Pair[aggCallIdx, argListIdx] to RexLiteral to indicate the RexLiteral being passed into - * the aggregateCalls[aggCallIdx].operandList[argListIdx] is supposed to be a RexLiteral. - * 2. not all RexLiteral types are supported to be part of the input constant call signature. - * 3. RexLiteral are encoded as String and decoded as Pinot Literal objects. - * - * see: {@link LiteralHintUtils}. - * see: https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-5833 - */ - public static final String AGG_CALL_SIGNATURE = "agg_call_signature"; } public static class AggregateOptions { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index ffe0741751..0e6e13b0e7 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -19,20 +19,16 @@ package org.apache.pinot.calcite.rel.rules; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Set; import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; +import 
org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistributions; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; @@ -44,16 +40,16 @@ import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.ImmutableIntList; -import org.apache.calcite.util.Util; import org.apache.calcite.util.mapping.Mapping; import org.apache.calcite.util.mapping.MappingType; import org.apache.calcite.util.mapping.Mappings; @@ -88,8 +84,6 @@ import org.apache.pinot.segment.spi.AggregationFunctionType; public class PinotAggregateExchangeNodeInsertRule extends RelOptRule { public static final PinotAggregateExchangeNodeInsertRule INSTANCE = new PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY); - public static final Set<String> LIST_AGG_FUNCTION_NAMES = - ImmutableSet.of("LISTAGG", "LIST_AGG", "ARRAYsAGG", "ARRAY_AGG"); public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) { super(operand(LogicalAggregate.class, any()), factory, null); @@ -119,137 +113,104 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule { */ @Override public void onMatch(RelOptRuleCall call) { - Aggregate oldAggRel = call.rel(0); - ImmutableList<RelHint> oldHints = oldAggRel.getHints(); - // Both collation and distinct are not supported in leaf stage aggregation. 
- boolean hasCollation = hasCollation(oldAggRel); - boolean hasDistinct = hasDistinct(oldAggRel); - Aggregate newAgg; - if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.isHintOptionTrue(oldHints, - PinotHintOptions.AGGREGATE_HINT_OPTIONS, PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) { - // ------------------------------------------------------------------------ - // If the "is_partitioned_by_group_by_keys" aggregate hint option is set, just add additional hints indicating - // this is a single stage aggregation. - List<RelHint> newHints = PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(), - PinotHintOptions.INTERNAL_AGG_OPTIONS, PinotHintOptions.InternalAggregateOptions.AGG_TYPE, - AggType.DIRECT.name()); - newAgg = - new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(), - oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), oldAggRel.getAggCallList()); - } else if (hasCollation || hasDistinct || (!oldAggRel.getGroupSet().isEmpty() - && PinotHintStrategyTable.isHintOptionTrue(oldHints, PinotHintOptions.AGGREGATE_HINT_OPTIONS, + Aggregate argRel = call.rel(0); + ImmutableList<RelHint> hints = argRel.getHints(); + // Collation is not supported in leaf stage aggregation. + RelCollation collation = extractWithInGroupCollation(argRel); + boolean hasGroupBy = !argRel.getGroupSet().isEmpty(); + if (collation != null || (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints, + PinotHintOptions.AGGREGATE_HINT_OPTIONS, PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION))) { - // ------------------------------------------------------------------------ - // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf stage aggregation is skipped. 
- newAgg = (Aggregate) createPlanWithExchangeDirectAggregation(call); + call.transformTo(createPlanWithExchangeDirectAggregation(call, collation)); + } else if (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints, PinotHintOptions.AGGREGATE_HINT_OPTIONS, + PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) { + call.transformTo(createPlanWithDirectAggregation(call)); } else { - // ------------------------------------------------------------------------ - newAgg = (Aggregate) createPlanWithLeafExchangeFinalAggregate(call); + call.transformTo(createPlanWithLeafExchangeFinalAggregate(call)); } - call.transformTo(newAgg); } - private boolean hasDistinct(Aggregate aggRel) { + // TODO: Currently it only handles one WITHIN GROUP collation across all AggregateCalls. + @Nullable + private static RelCollation extractWithInGroupCollation(Aggregate aggRel) { for (AggregateCall aggCall : aggRel.getAggCallList()) { - // If the aggregation function is a list aggregation function and it is distinct, we can skip leaf stage. - // For COUNT(DISTINCT), there could be more leaf stage optimization. 
- if (aggCall.isDistinct() && LIST_AGG_FUNCTION_NAMES.contains(aggCall.getAggregation().getName().toUpperCase())) { - return true; + List<RelFieldCollation> fieldCollations = aggCall.getCollation().getFieldCollations(); + if (!fieldCollations.isEmpty()) { + return RelCollations.of(fieldCollations); } } - return false; + return null; } - private boolean hasCollation(Aggregate aggRel) { - for (AggregateCall aggCall : aggRel.getAggCallList()) { - if (!aggCall.getCollation().getKeys().isEmpty()) { - return true; - } - } - return false; + private static RelNode createPlanWithDirectAggregation(RelOptRuleCall call) { + Aggregate aggRel = call.rel(0); + List<RelHint> newHints = + PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), PinotHintOptions.INTERNAL_AGG_OPTIONS, + PinotHintOptions.InternalAggregateOptions.AGG_TYPE, AggType.DIRECT.name()); + return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), newHints, aggRel.getInput(), + aggRel.getGroupSet(), aggRel.getGroupSets(), buildAggCalls(aggRel, AggType.DIRECT)); } /** * Aggregate node will be split into LEAF + exchange + FINAL. - * optionally we can insert INTERMEDIATE to reduce hotspot in the future. + * TODO: Add optional INTERMEDIATE stage to reduce hotspot. */ - private RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) { - // TODO: add optional intermediate stage here when hinted. - Aggregate oldAggRel = call.rel(0); - // 1. attach leaf agg RelHint to original agg. Perform any aggregation call conversions necessary - Aggregate leafAgg = convertAggForLeafInput(oldAggRel); - // 2. attach exchange. - List<Integer> groupSetIndices = ImmutableIntList.range(0, oldAggRel.getGroupCount()); - PinotLogicalExchange exchange; - if (groupSetIndices.size() == 0) { - exchange = PinotLogicalExchange.create(leafAgg, RelDistributions.hash(Collections.emptyList())); - } else { - exchange = PinotLogicalExchange.create(leafAgg, RelDistributions.hash(groupSetIndices)); - } - // 3. 
attach final agg stage. - return convertAggFromIntermediateInput(call, oldAggRel, exchange, AggType.FINAL); + private static RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) { + Aggregate aggRel = call.rel(0); + // Create a LEAF aggregate. + Aggregate leafAggRel = convertAggForLeafInput(aggRel); + // Create an exchange node over the LEAF aggregate. + PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel, + RelDistributions.hash(ImmutableIntList.range(0, aggRel.getGroupCount()))); + // Create a FINAL aggregate over the exchange. + return convertAggFromIntermediateInput(call, exchange, AggType.FINAL); } /** * Use this group by optimization to skip leaf stage aggregation when aggregating at leaf level is not desired. * Many situation could be wasted effort to do group-by on leaf, eg: when cardinality of group by column is very high. */ - private RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call) { - Aggregate oldAggRel = call.rel(0); - List<RelHint> newHints = PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(), - PinotHintOptions.INTERNAL_AGG_OPTIONS, PinotHintOptions.InternalAggregateOptions.AGG_TYPE, - AggType.DIRECT.name()); - - // Convert Aggregate WithGroup Collation into a Sort - RelCollation relCollation = extractWithInGroupCollation(oldAggRel); + private static RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call, + @Nullable RelCollation collation) { + Aggregate aggRel = call.rel(0); + RelNode input = aggRel.getInput(); + // Create Project when there's none below the aggregate. 
+ if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) { + aggRel = (Aggregate) generateProjectUnderAggregate(call); + input = aggRel.getInput(); + } - // create project when there's none below the aggregate to reduce exchange overhead - RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel(); - if (!(childRel instanceof Project)) { - return convertAggForExchangeDirectAggregate(call, newHints, relCollation); + ImmutableBitSet groupSet = aggRel.getGroupSet(); + RelDistribution distribution = RelDistributions.hash(groupSet.asList()); + RelNode exchange; + if (collation != null) { + // Insert a LogicalSort node between exchange and aggregate when collation exists. + exchange = PinotLogicalSortExchange.create(input, distribution, collation, false, true); } else { - // create normal exchange - List<Integer> groupSetIndices = new ArrayList<>(); - oldAggRel.getGroupSet().forEach(groupSetIndices::add); - RelNode newAggChild; - if (relCollation != null) { - newAggChild = - (groupSetIndices.isEmpty()) ? 
PinotLogicalSortExchange.create(childRel, RelDistributions.SINGLETON, - relCollation, false, true) - : PinotLogicalSortExchange.create(childRel, RelDistributions.hash(groupSetIndices), - relCollation, false, true); - } else { - newAggChild = PinotLogicalExchange.create(childRel, RelDistributions.hash(groupSetIndices)); - } - return new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, newAggChild, - oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), oldAggRel.getAggCallList()); + exchange = PinotLogicalExchange.create(input, distribution); } - } - // Extract the first collation in the AggregateCall list - @Nullable - private RelCollation extractWithInGroupCollation(Aggregate aggRel) { - for (AggregateCall aggCall : aggRel.getAggCallList()) { - List<RelFieldCollation> fieldCollations = aggCall.getCollation().getFieldCollations(); - if (!fieldCollations.isEmpty()) { - return RelCollations.of(fieldCollations); - } - } - return null; + List<RelHint> newHints = + PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), PinotHintOptions.INTERNAL_AGG_OPTIONS, + PinotHintOptions.InternalAggregateOptions.AGG_TYPE, AggType.DIRECT.name()); + return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), newHints, exchange, groupSet, + aggRel.getGroupSets(), buildAggCalls(aggRel, AggType.DIRECT)); } /** - * The following is copied from {@link AggregateExtractProjectRule#onMatch(RelOptRuleCall)} - * with modification to insert an exchange in between the Aggregate and Project + * The following is copied from {@link AggregateExtractProjectRule#onMatch(RelOptRuleCall)} with modification to take + * aggregate input as input. 
*/ - private RelNode convertAggForExchangeDirectAggregate(RelOptRuleCall call, List<RelHint> newHints, - @Nullable RelCollation collation) { + private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { final Aggregate aggregate = call.rel(0); + // --------------- MODIFIED --------------- final RelNode input = aggregate.getInput(); + // final RelNode input = call.rel(1); + // ------------- END MODIFIED ------------- + // Compute which input fields are used. // 1. group fields are always used - final ImmutableBitSet.Builder inputFieldsUsed = - aggregate.getGroupSet().rebuild(); + final ImmutableBitSet.Builder inputFieldsUsed = aggregate.getGroupSet().rebuild(); // 2. agg functions for (AggregateCall aggCall : aggregate.getAggCallList()) { for (int i : aggCall.getArgList()) { @@ -259,149 +220,164 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule { inputFieldsUsed.set(aggCall.filterArg); } } - final RelBuilder relBuilder1 = call.builder().push(input); + final RelBuilder relBuilder = call.builder().push(input); final List<RexNode> projects = new ArrayList<>(); final Mapping mapping = - Mappings.create(MappingType.INVERSE_SURJECTION, - aggregate.getInput().getRowType().getFieldCount(), + Mappings.create(MappingType.INVERSE_SURJECTION, aggregate.getInput().getRowType().getFieldCount(), inputFieldsUsed.cardinality()); int j = 0; for (int i : inputFieldsUsed.build()) { - projects.add(relBuilder1.field(i)); + projects.add(relBuilder.field(i)); mapping.set(i, j++); } - relBuilder1.project(projects); - final ImmutableBitSet newGroupSet = - Mappings.apply(mapping, aggregate.getGroupSet()); - Project project = (Project) relBuilder1.build(); - // ------------------------------------------------------------------------ - RelNode newAggChild; - if (collation != null) { - // Insert a LogicalSort node between the exchange and the aggregate - newAggChild = newGroupSet.isEmpty() ? 
PinotLogicalSortExchange.create(project, RelDistributions.SINGLETON, - collation, false, true) - : PinotLogicalSortExchange.create(project, RelDistributions.hash(newGroupSet.asList()), - collation, false, true); - } else { - newAggChild = PinotLogicalExchange.create(project, RelDistributions.hash(newGroupSet.asList())); - } - // ------------------------------------------------------------------------ + relBuilder.project(projects); - final RelBuilder relBuilder2 = call.builder().push(newAggChild); + final ImmutableBitSet newGroupSet = Mappings.apply(mapping, aggregate.getGroupSet()); final List<ImmutableBitSet> newGroupSets = - aggregate.getGroupSets().stream() - .map(bitSet -> Mappings.apply(mapping, bitSet)) - .collect(Util.toImmutableList()); + aggregate.getGroupSets().stream().map(bitSet -> Mappings.apply(mapping, bitSet)) + .collect(ImmutableList.toImmutableList()); final List<RelBuilder.AggCall> newAggCallList = - aggregate.getAggCallList().stream() - .map(aggCall -> relBuilder2.aggregateCall(aggCall, mapping)) - .collect(Util.toImmutableList()); - final RelBuilder.GroupKey groupKey = - relBuilder2.groupKey(newGroupSet, newGroupSets); - relBuilder2.aggregate(groupKey, newAggCallList).hints(newHints); - return relBuilder2.build(); + aggregate.getAggCallList().stream().map(aggCall -> relBuilder.aggregateCall(aggCall, mapping)) + .collect(ImmutableList.toImmutableList()); + + final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets); + relBuilder.aggregate(groupKey, newAggCallList); + return relBuilder.build(); } - private Aggregate convertAggForLeafInput(Aggregate oldAggRel) { - List<AggregateCall> oldCalls = oldAggRel.getAggCallList(); - List<AggregateCall> newCalls = new ArrayList<>(); - for (AggregateCall oldCall : oldCalls) { - newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall, oldCall.getArgList(), oldAggRel.getGroupCount(), - AggType.LEAF)); - } - List<RelHint> newHints = 
PinotHintStrategyTable.replaceHintOptions(oldAggRel.getHints(), - PinotHintOptions.INTERNAL_AGG_OPTIONS, PinotHintOptions.InternalAggregateOptions.AGG_TYPE, AggType.LEAF.name()); - return new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(), - oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls); + private static Aggregate convertAggForLeafInput(Aggregate aggRel) { + List<RelHint> newHints = + PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), PinotHintOptions.INTERNAL_AGG_OPTIONS, + PinotHintOptions.InternalAggregateOptions.AGG_TYPE, AggType.LEAF.name()); + return new LogicalAggregate(aggRel.getCluster(), aggRel.getTraitSet(), newHints, aggRel.getInput(), + aggRel.getGroupSet(), aggRel.getGroupSets(), buildAggCalls(aggRel, AggType.LEAF)); } - private RelNode convertAggFromIntermediateInput(RelOptRuleCall ruleCall, Aggregate oldAggRel, - PinotLogicalExchange exchange, AggType aggType) { - // add the exchange as the input node to the relation builder. - RelBuilder relBuilder = ruleCall.builder(); - relBuilder.push(exchange); + private static RelNode convertAggFromIntermediateInput(RelOptRuleCall call, PinotLogicalExchange exchange, + AggType aggType) { + Aggregate aggRel = call.rel(0); + RelNode input = PinotRuleUtils.unboxRel(aggRel.getInput()); + List<RexNode> projects = (input instanceof Project) ? 
((Project) input).getProjects() : null; - // make input ref to the exchange after the leaf aggregate, all groups should be at the front RexBuilder rexBuilder = exchange.getCluster().getRexBuilder(); - final int nGroups = oldAggRel.getGroupCount(); - for (int i = 0; i < nGroups; i++) { - rexBuilder.makeInputRef(oldAggRel, i); - } - - List<AggregateCall> newCalls = new ArrayList<>(); + int groupCount = aggRel.getGroupCount(); + List<AggregateCall> orgAggCalls = aggRel.getAggCallList(); + int numAggCalls = orgAggCalls.size(); + List<AggregateCall> aggCalls = new ArrayList<>(numAggCalls); Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>(); - // create new aggregate function calls from exchange input, all aggCalls are followed one by one from exchange - // b/c the exchange produces intermediate results, thus the input to the newCall will be indexed at - // [nGroup + oldCallIndex] - List<AggregateCall> oldCalls = oldAggRel.getAggCallList(); - for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) { - AggregateCall oldCall = oldCalls.get(oldCallIndex); - // intermediate stage input only supports single argument inputs. - List<Integer> argList = Collections.singletonList(nGroups + oldCallIndex); - AggregateCall newCall = buildAggregateCall(exchange, oldCall, argList, nGroups, aggType); - rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); + // Create new AggregateCalls from exchange input. Exchange produces results with group keys followed by intermediate + // aggregate results. + for (int i = 0; i < numAggCalls; i++) { + AggregateCall orgAggCall = orgAggCalls.get(i); + List<Integer> argList = orgAggCall.getArgList(); + int index = groupCount + i; + RexInputRef inputRef = RexInputRef.of(index, aggRel.getRowType()); + // Generate rexList from argList and replace literal reference with literal. Keep the first argument as is. 
+ int numArguments = argList.size(); + List<RexNode> rexList; + if (numArguments <= 1) { + rexList = ImmutableList.of(inputRef); + } else { + rexList = new ArrayList<>(numArguments); + rexList.add(inputRef); + for (int j = 1; j < numArguments; j++) { + int argument = argList.get(j); + if (projects != null && projects.get(argument) instanceof RexLiteral) { + rexList.add(projects.get(argument)); + } else { + // Replace all the input reference in the rexList to the new input reference. + rexList.add(inputRef); + } + } + } + AggregateCall newAggregate = buildAggCall(exchange, orgAggCall, rexList, groupCount, aggType); + rexBuilder.addAggCall(newAggregate, groupCount, aggCalls, aggCallMapping, aggRel.getInput()::fieldIsNullable); } - // create new aggregate relation. - ImmutableList<RelHint> orgHints = oldAggRel.getHints(); - List<RelHint> newAggHint = PinotHintStrategyTable.replaceHintOptions(orgHints, - PinotHintOptions.INTERNAL_AGG_OPTIONS, PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name()); - ImmutableBitSet groupSet = ImmutableBitSet.range(nGroups); - relBuilder.aggregate(relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)), newCalls); - relBuilder.hints(newAggHint); + RelBuilder relBuilder = call.builder(); + relBuilder.push(exchange); + ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount); + relBuilder.aggregate(relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)), aggCalls); + List<RelHint> newHints = + PinotHintStrategyTable.replaceHintOptions(aggRel.getHints(), PinotHintOptions.INTERNAL_AGG_OPTIONS, + PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name()); + relBuilder.hints(newHints); return relBuilder.build(); } - private static AggregateCall buildAggregateCall(RelNode inputNode, AggregateCall orgAggCall, List<Integer> argList, - int numberGroups, AggType aggType) { - final SqlAggFunction oldAggFunction = orgAggCall.getAggregation(); - final SqlKind aggKind = oldAggFunction.getKind(); - String functionName = 
getFunctionNameFromAggregateCall(orgAggCall); - AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); - // create the aggFunction - SqlAggFunction sqlAggFunction; - if (functionType.getIntermediateReturnTypeInference() != null) { - switch (aggType) { - case LEAF: - sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - break; - case INTERMEDIATE: - sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null, - OperandTypes.ANY, functionType.getSqlFunctionCategory()); - break; - case FINAL: - sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), ReturnTypes.explicit(orgAggCall.getType()), null, - OperandTypes.ANY, functionType.getSqlFunctionCategory()); - break; - default: - throw new UnsupportedOperationException("Unsuppoted aggType: " + aggType + " for " + functionName); + private static List<AggregateCall> buildAggCalls(Aggregate aggRel, AggType aggType) { + RelNode input = PinotRuleUtils.unboxRel(aggRel.getInput()); + List<RexNode> projects = (input instanceof Project) ? ((Project) input).getProjects() : null; + List<AggregateCall> orgAggCalls = aggRel.getAggCallList(); + List<AggregateCall> aggCalls = new ArrayList<>(orgAggCalls.size()); + for (AggregateCall orgAggCall : orgAggCalls) { + // Generate rexList from argList and replace literal reference with literal. Keep the first argument as is. 
+ List<Integer> argList = orgAggCall.getArgList(); + int numArguments = argList.size(); + List<RexNode> rexList; + if (numArguments == 0) { + rexList = ImmutableList.of(); + } else if (numArguments == 1) { + rexList = ImmutableList.of(RexInputRef.of(argList.get(0), input.getRowType())); + } else { + rexList = new ArrayList<>(numArguments); + rexList.add(RexInputRef.of(argList.get(0), input.getRowType())); + for (int i = 1; i < numArguments; i++) { + int argument = argList.get(i); + if (projects != null && projects.get(argument) instanceof RexLiteral) { + rexList.add(projects.get(argument)); + } else { + rexList.add(RexInputRef.of(argument, input.getRowType())); + } + } } - } else { - sqlAggFunction = oldAggFunction; + aggCalls.add(buildAggCall(input, orgAggCall, rexList, aggRel.getGroupCount(), aggType)); } - - return AggregateCall.create(sqlAggFunction, - functionName.equals("distinctCount") || orgAggCall.isDistinct(), - orgAggCall.isApproximate(), - orgAggCall.ignoreNulls(), - argList, - aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg, - orgAggCall.distinctKeys, - orgAggCall.collation, - numberGroups, - inputNode, - null, - null); + return aggCalls; } - private static String getFunctionNameFromAggregateCall(AggregateCall aggregateCall) { - return aggregateCall.getAggregation().getName().equalsIgnoreCase("COUNT") && aggregateCall.isDistinct() - ? 
"distinctCount" : aggregateCall.getAggregation().getName(); + // TODO: Revisit the following logic: + // - DISTINCT is resolved here + // - argList is replaced with rexList + private static AggregateCall buildAggCall(RelNode input, AggregateCall orgAggCall, List<RexNode> rexList, + int numGroups, AggType aggType) { + String functionName = orgAggCall.getAggregation().getName(); + if (orgAggCall.isDistinct()) { + if (functionName.equals("COUNT")) { + functionName = "DISTINCTCOUNT"; + } else if (functionName.equals("LISTAGG")) { + rexList.add(input.getCluster().getRexBuilder().makeLiteral(true)); + } + } + AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); + SqlAggFunction sqlAggFunction; + switch (aggType) { + case DIRECT: + sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), + ReturnTypes.explicit(orgAggCall.getType()), null, functionType.getOperandTypeChecker(), + functionType.getSqlFunctionCategory()); + break; + case LEAF: + sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), + functionType.getIntermediateReturnTypeInference(), null, functionType.getOperandTypeChecker(), + functionType.getSqlFunctionCategory()); + break; + case INTERMEDIATE: + sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), + functionType.getIntermediateReturnTypeInference(), null, OperandTypes.ANY, + functionType.getSqlFunctionCategory()); + break; + case FINAL: + sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), + ReturnTypes.explicit(orgAggCall.getType()), null, OperandTypes.ANY, functionType.getSqlFunctionCategory()); + break; + default: + throw new IllegalStateException("Unsupported AggType: " + aggType); + } + return AggregateCall.create(sqlAggFunction, false, orgAggCall.isApproximate(), orgAggCall.ignoreNulls(), rexList, + ImmutableList.of(), aggType.isInputIntermediateFormat() ? 
-1 : orgAggCall.filterArg, orgAggCall.distinctKeys, + orgAggCall.collation, numGroups, input, null, null); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java deleted file mode 100644 index 74af35b47a..0000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateLiteralAttachmentRule.java +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.pinot.calcite.rel.rules; - -import com.google.common.collect.ImmutableList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.calcite.plan.RelOptRule; -import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.Aggregate; -import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rel.hint.RelHint; -import org.apache.calcite.rel.logical.LogicalAggregate; -import org.apache.calcite.rex.RexLiteral; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.tools.RelBuilderFactory; -import org.apache.calcite.util.Pair; -import org.apache.pinot.calcite.rel.hint.PinotHintOptions; -import org.apache.pinot.calcite.rel.hint.PinotHintStrategyTable; -import org.apache.pinot.common.utils.DataSchema.ColumnDataType; -import org.apache.pinot.query.planner.logical.LiteralHintUtils; -import org.apache.pinot.query.planner.logical.RexExpression; -import org.apache.pinot.query.planner.logical.RexExpressionUtils; - - -/** - * Special rule to attach Literal to Aggregate call. 
- */ -public class PinotAggregateLiteralAttachmentRule extends RelOptRule { - public static final PinotAggregateLiteralAttachmentRule INSTANCE = - new PinotAggregateLiteralAttachmentRule(PinotRuleUtils.PINOT_REL_FACTORY); - - public PinotAggregateLiteralAttachmentRule(RelBuilderFactory factory) { - super(operand(LogicalAggregate.class, any()), factory, null); - } - - @Override - public boolean matches(RelOptRuleCall call) { - if (call.rels.length < 1) { - return false; - } - if (call.rel(0) instanceof Aggregate) { - Aggregate agg = call.rel(0); - ImmutableList<RelHint> hints = agg.getHints(); - return !PinotHintStrategyTable.containsHintOption(hints, - PinotHintOptions.INTERNAL_AGG_OPTIONS, PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE); - } - return false; - } - - @Override - public void onMatch(RelOptRuleCall call) { - Aggregate aggregate = call.rel(0); - Map<Pair<Integer, Integer>, RexExpression.Literal> rexLiterals = extractLiterals(call); - List<RelHint> newHints = PinotHintStrategyTable.replaceHintOptions(aggregate.getHints(), - PinotHintOptions.INTERNAL_AGG_OPTIONS, PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE, - LiteralHintUtils.literalMapToHintString(rexLiterals)); - // TODO: validate against AggregationFunctionType to see if expected literal positions are properly attached - call.transformTo(new LogicalAggregate(aggregate.getCluster(), aggregate.getTraitSet(), newHints, - aggregate.getInput(), aggregate.getGroupSet(), aggregate.getGroupSets(), aggregate.getAggCallList())); - } - - private static Map<Pair<Integer, Integer>, RexExpression.Literal> extractLiterals(RelOptRuleCall call) { - Aggregate aggregate = call.rel(0); - RelNode input = PinotRuleUtils.unboxRel(aggregate.getInput()); - List<RexNode> rexNodes = (input instanceof Project) ? 
((Project) input).getProjects() : null; - List<AggregateCall> aggCallList = aggregate.getAggCallList(); - final Map<Pair<Integer, Integer>, RexExpression.Literal> rexLiteralMap = new HashMap<>(); - for (int aggIdx = 0; aggIdx < aggCallList.size(); aggIdx++) { - AggregateCall aggCall = aggCallList.get(aggIdx); - int argSize = aggCall.getArgList().size(); - if (argSize > 1) { - // use -1 argIdx to indicate size of the agg operands. - rexLiteralMap.put(new Pair<>(aggIdx, -1), new RexExpression.Literal(ColumnDataType.INT, argSize)); - // put the literals in to the map. - for (int argIdx = 0; argIdx < argSize; argIdx++) { - if (rexNodes != null) { - RexNode field = rexNodes.get(aggCall.getArgList().get(argIdx)); - if (field instanceof RexLiteral) { - rexLiteralMap.put(new Pair<>(aggIdx, argIdx), RexExpressionUtils.fromRexLiteral((RexLiteral) field)); - } - } - } - } - } - return rexLiteralMap; - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java index cbac4de9e3..6c2498c70b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java @@ -117,11 +117,6 @@ public class PinotQueryRuleSets { PruneEmptyRules.UNION_INSTANCE ); - // Pinot specific rules to run using a single RuleCollection since we attach aggregate info after optimizer. 
- public static final Collection<RelOptRule> PINOT_AGG_PROCESS_RULES = ImmutableList.of( - PinotAggregateLiteralAttachmentRule.INSTANCE - ); - // Pinot specific rules that should be run AFTER all other rules public static final Collection<RelOptRule> PINOT_POST_RULES = ImmutableList.of( // Evaluate the Literal filter nodes diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index 059faac2d4..9c53cdee6a 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -328,10 +328,6 @@ public class QueryEnvironment { hepProgramBuilder.addRuleInstance(relOptRule); } - // ---- - // Run Pinot rule to attach aggregation auxiliary info - hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.PINOT_AGG_PROCESS_RULES); - // ---- // Pushdown filters using a single HepInstruction. 
hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.FILTER_PUSHDOWN_RULES); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index debe59d0ab..1862adf95e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -231,7 +231,7 @@ public class CalciteRexExpressionParser { } break; default: - functionName = functionKind.name(); + functionName = canonicalizeFunctionName(functionKind.name()); break; } List<RexExpression> childNodes = rexCall.getFunctionOperands(); @@ -288,7 +288,7 @@ public class CalciteRexExpressionParser { private static Expression getFunctionExpression(String canonicalName) { Expression expression = new Expression(ExpressionType.FUNCTION); - Function function = new Function(canonicalizeFunctionName(canonicalName)); + Function function = new Function(canonicalName); expression.setFunctionCall(function); return expression; } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java deleted file mode 100644 index ea854e9aba..0000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralHintUtils.java +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.query.planner.logical; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.calcite.util.Pair; -import org.apache.commons.lang3.StringUtils; -import org.apache.pinot.common.request.Literal; -import org.apache.pinot.spi.data.FieldSpec; -import org.apache.pinot.spi.utils.BytesUtils; - - -public class LiteralHintUtils { - private LiteralHintUtils() { - } - - public static String literalMapToHintString(Map<Pair<Integer, Integer>, RexExpression.Literal> literals) { - List<String> literalStrings = new ArrayList<>(literals.size()); - for (Map.Entry<Pair<Integer, Integer>, RexExpression.Literal> e : literals.entrySet()) { - // individual literal parts are joined with `|` - literalStrings.add( - String.format("%d|%d|%s|%s", e.getKey().left, e.getKey().right, e.getValue().getDataType().name(), - e.getValue().getValue())); - } - // semi-colon is used to separate between encoded literals - return "{" + StringUtils.join(literalStrings, ";:;") + "}"; - } - - public static Map<Integer, Map<Integer, Literal>> hintStringToLiteralMap(String literalString) { - Map<Integer, Map<Integer, Literal>> aggCallToLiteralArgsMap = new HashMap<>(); - if (StringUtils.isNotEmpty(literalString) && !"{}".equals(literalString)) { - String[] literalStringArr = literalString.substring(1, literalString.length() - 1).split(";:;"); - for (String literalStr : literalStringArr) { - String[] literalStrParts = literalStr.split("\\|", 4); - int aggIdx = Integer.parseInt(literalStrParts[0]); - 
int argListIdx = Integer.parseInt(literalStrParts[1]); - String dataTypeNameStr = literalStrParts[2]; - String valueStr = literalStrParts[3]; - Map<Integer, Literal> literalArgs = aggCallToLiteralArgsMap.computeIfAbsent(aggIdx, i -> new HashMap<>()); - literalArgs.put(argListIdx, stringToLiteral(dataTypeNameStr, valueStr)); - } - } - return aggCallToLiteralArgsMap; - } - - private static Literal stringToLiteral(String dataTypeStr, String valueStr) { - FieldSpec.DataType dataType = FieldSpec.DataType.valueOf(dataTypeStr); - switch (dataType) { - case BOOLEAN: - return Literal.boolValue(valueStr.equals("1")); - case INT: - return Literal.intValue(Integer.parseInt(valueStr)); - case LONG: - return Literal.longValue(Long.parseLong(valueStr)); - case FLOAT: - case DOUBLE: - return Literal.doubleValue(Double.parseDouble(valueStr)); - case STRING: - return Literal.stringValue(valueStr); - case BYTES: - return Literal.binaryValue(BytesUtils.toBytes(valueStr)); - default: - throw new UnsupportedOperationException("Unsupported RexLiteral type: " + dataTypeStr); - } - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java index 5a80cd2596..c2e9890358 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java @@ -246,8 +246,10 @@ public class RexExpressionUtils { } public static RexExpression fromAggregateCall(AggregateCall aggregateCall) { - List<RexExpression> operands = - aggregateCall.getArgList().stream().map(RexExpression.InputRef::new).collect(Collectors.toList()); + List<RexExpression> operands = new ArrayList<>(aggregateCall.rexList.size()); + for (RexNode rexNode : aggregateCall.rexList) { + operands.add(fromRexNode(rexNode)); + } return new 
RexExpression.FunctionCall(aggregateCall.getAggregation().getKind(), RelToPlanNodeConverter.convertToColumnDataType(aggregateCall.getType()), aggregateCall.getAggregation().getName(), operands, aggregateCall.isDistinct()); diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java index 810202ca49..8e74660e7a 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java @@ -255,7 +255,8 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase { public void testQueryWithHint() { // Hinting the query to use final stage aggregation makes server directly return final result // This is useful when data is already partitioned by col1 - String query = "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ col1, COUNT(*) FROM b GROUP BY col1"; + String query = + "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ col1, COUNT(*) FROM b GROUP BY col1"; DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query); List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList(); int numStages = stagePlans.size(); diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json index a7a4b1a8be..8a0878d6e1 100644 --- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json +++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json @@ -102,7 +102,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a GROUP BY a.col1", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n 
LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -128,7 +128,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -140,7 +140,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -153,7 +153,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], EXPR$2=[MAX($1)])", + "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -167,7 +167,7 @@ "notes": "TODO: Needs follow up. 
Project should only keep a.col1 since the other columns are pushed to the filter, but it currently keeps them all", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()])", + "\nLogicalAggregate(group=[{0}], agg#0=[COUNT()])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -181,7 +181,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])", - "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])", + "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])", @@ -196,7 +196,7 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -211,7 +211,7 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1])", "\n LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), =(/(CAST($1):DOUBLE NOT NULL, $4), 5))])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -226,7 +226,7 @@ "Execution 
Plan", "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n LogicalAggregate(group=[{0}], count=[COUNT()], SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json index 7b97f583ea..32d1eb65f8 100644 --- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json +++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json @@ -93,7 +93,7 @@ "Execution Plan", "\nLogicalSort(sort0=[$0], dir0=[ASC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -121,7 +121,7 @@ "Execution Plan", "\nLogicalSort(sort0=[$0], dir0=[ASC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n LogicalAggregate(group=[{0}], sum=[$SUM0($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", diff --git a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json index 5841c442ff..3f0a4cd0f0 100644 --- 
a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json +++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json @@ -100,10 +100,10 @@ }, { "description": "semi-join with dynamic_broadcast join strategy then group-by on same key", - "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ a.col1, SUM(a.col3) FROM a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col1, SUM(a.col3) FROM a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -138,7 +138,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])", - "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])", + "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])", @@ -153,7 +153,7 @@ "Execution Plan", "\nLogicalProject(col2=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$3])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 
_UTF-8'a'))])", @@ -162,24 +162,11 @@ ] }, { - "description": "aggregate with skip intermediate stage hint (via hinting the leaf stage group by as final stage_", - "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ a.col2, COUNT(*), SUM(a.col3), SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col2 HAVING COUNT(*) > 10", - "output": [ - "Execution Plan", - "\nLogicalFilter(condition=[>($1, 10)])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)])", - "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])", - "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", - "\n LogicalTableScan(table=[[default, a]])", - "\n" - ] - }, - { - "description": "aggregate with skip leaf stage hint (via hint option is_partitioned_by_group_by_keys", + "description": "aggregate with skip intermediate stage hint (via hint option is_partitioned_by_group_by_keys)", "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, COUNT(*), SUM(a.col3), SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col2", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)])", + "\nLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)])", "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", @@ -409,7 +396,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1", "output": [ 
"Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -425,7 +412,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -443,7 +430,7 @@ "Execution Plan", "\nLogicalProject(col2=[$0], EXPR$1=[$1])", "\n LogicalFilter(condition=[>($2, 5)])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[COUNT()])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -461,7 +448,7 @@ "Execution Plan", "\nLogicalSort(sort0=[$1], dir0=[DESC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java index 7cf7d5f2a7..a19ff64d4e 100644 --- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java @@ -27,10 +27,8 @@ import java.util.List; import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.sql.SqlKind; -import org.apache.pinot.calcite.rel.hint.PinotHintOptions; import org.apache.pinot.common.datablock.DataBlock; import org.apache.pinot.common.datatable.StatMap; -import org.apache.pinot.common.request.Literal; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.request.context.FunctionContext; import org.apache.pinot.common.utils.DataSchema; @@ -43,13 +41,13 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory; import org.apache.pinot.core.query.aggregation.function.CountAggregationFunction; import org.apache.pinot.core.util.DataBlockExtractUtils; -import org.apache.pinot.query.planner.logical.LiteralHintUtils; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.plannode.AbstractPlanNode; import org.apache.pinot.query.planner.plannode.AggregateNode.AggType; import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; -import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.apache.pinot.spi.utils.BooleanUtils; import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,11 +63,9 @@ public class AggregateOperator extends MultiStageOperator { private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR"; private static final CountAggregationFunction COUNT_STAR_AGG_FUNCTION = new 
CountAggregationFunction(Collections.singletonList(ExpressionContext.forIdentifier("*")), false); - private static final ExpressionContext PLACEHOLDER_IDENTIFIER = ExpressionContext.forIdentifier("__PLACEHOLDER__"); private final MultiStageOperator _inputOperator; private final DataSchema _resultSchema; - private final AggType _aggType; private final MultistageAggregationExecutor _aggregationExecutor; private final MultistageGroupByExecutor _groupByExecutor; @Nullable @@ -78,29 +74,15 @@ public class AggregateOperator extends MultiStageOperator { private boolean _hasConstructedAggregateBlock; - public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, - DataSchema resultSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType aggType, - List<Integer> filterArgIndices, @Nullable AbstractPlanNode.NodeHint nodeHint) { + public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, DataSchema resultSchema, + List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType aggType, List<Integer> filterArgIndices, + @Nullable AbstractPlanNode.NodeHint nodeHint) { super(context); _inputOperator = inputOperator; _resultSchema = resultSchema; - _aggType = aggType; - - // Process literal hints - Map<Integer, Map<Integer, Literal>> literalArgumentsMap = null; - if (nodeHint != null) { - Map<String, String> aggOptions = nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS); - if (aggOptions != null) { - literalArgumentsMap = LiteralHintUtils.hintStringToLiteralMap( - aggOptions.get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE)); - } - } - if (literalArgumentsMap == null) { - literalArgumentsMap = Collections.emptyMap(); - } // Initialize the aggregation functions - AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(aggCalls, literalArgumentsMap); + AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(aggCalls); // Process the filter argument indices 
int numFunctions = aggFunctions.length; @@ -214,27 +196,16 @@ public class AggregateOperator extends MultiStageOperator { return block; } - private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression> aggCalls, - Map<Integer, Map<Integer, Literal>> literalArgumentsMap) { + private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression> aggCalls) { int numFunctions = aggCalls.size(); AggregationFunction<?, ?>[] aggFunctions = new AggregationFunction[numFunctions]; - if (!_aggType.isInputIntermediateFormat()) { - for (int i = 0; i < numFunctions; i++) { - Map<Integer, Literal> literalArguments = literalArgumentsMap.getOrDefault(i, Collections.emptyMap()); - aggFunctions[i] = getAggFunctionForRawInput((RexExpression.FunctionCall) aggCalls.get(i), literalArguments); - } - } else { - for (int i = 0; i < numFunctions; i++) { - Map<Integer, Literal> literalArguments = literalArgumentsMap.getOrDefault(i, Collections.emptyMap()); - aggFunctions[i] = - getAggFunctionForIntermediateInput((RexExpression.FunctionCall) aggCalls.get(i), literalArguments); - } + for (int i = 0; i < numFunctions; i++) { + aggFunctions[i] = getAggFunction((RexExpression.FunctionCall) aggCalls.get(i)); } return aggFunctions; } - private AggregationFunction<?, ?> getAggFunctionForRawInput(RexExpression.FunctionCall functionCall, - Map<Integer, Literal> literalArguments) { + private AggregationFunction<?, ?> getAggFunction(RexExpression.FunctionCall functionCall) { String functionName = functionCall.getFunctionName(); List<RexExpression> operands = functionCall.getFunctionOperands(); int numArguments = operands.size(); @@ -244,78 +215,26 @@ public class AggregateOperator extends MultiStageOperator { return COUNT_STAR_AGG_FUNCTION; } List<ExpressionContext> arguments = new ArrayList<>(numArguments); - for (int i = 0; i < numArguments; i++) { - Literal literalArgument = literalArguments.get(i); - if (literalArgument != null) { - 
arguments.add(ExpressionContext.forLiteralContext(literalArgument)); + for (RexExpression operand : operands) { + if (operand instanceof RexExpression.InputRef) { + RexExpression.InputRef inputRef = (RexExpression.InputRef) operand; + arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex()))); } else { - RexExpression operand = operands.get(i); - switch (operand.getKind()) { - case INPUT_REF: - RexExpression.InputRef inputRef = (RexExpression.InputRef) operand; - arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex()))); - break; - case LITERAL: - RexExpression.Literal literalRexExp = (RexExpression.Literal) operand; - arguments.add(ExpressionContext.forLiteralContext(literalRexExp.getDataType().toDataType(), - literalRexExp.getValue())); - break; - default: - throw new IllegalStateException("Illegal aggregation function operand type: " + operand.getKind()); + assert operand instanceof RexExpression.Literal; + RexExpression.Literal literal = (RexExpression.Literal) operand; + DataType dataType = literal.getDataType().toDataType(); + Object value = literal.getValue(); + // TODO: Fix BOOLEAN literal to directly store true/false + if (dataType == DataType.BOOLEAN) { + value = BooleanUtils.fromNonNullInternalValue(value); } + arguments.add(ExpressionContext.forLiteralContext(dataType, value)); } } - handleListAggDistinctArg(functionName, functionCall, arguments); return AggregationFunctionFactory.getAggregationFunction( new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true); } - private static AggregationFunction<?, ?> getAggFunctionForIntermediateInput(RexExpression.FunctionCall functionCall, - Map<Integer, Literal> literalArguments) { - String functionName = functionCall.getFunctionName(); - List<RexExpression> operands = functionCall.getFunctionOperands(); - int numArguments = operands.size(); - Preconditions.checkState(numArguments == 1, "Intermediate aggregate must have 
1 argument, got: %s", numArguments); - RexExpression operand = operands.get(0); - Preconditions.checkState(operand.getKind() == SqlKind.INPUT_REF, - "Intermediate aggregate argument must be an input reference, got: %s", operand.getKind()); - // We might need to append extra arguments extracted from the hint to match the signature of the aggregation - Literal numArgumentsLiteral = literalArguments.get(-1); - if (numArgumentsLiteral == null) { - return AggregationFunctionFactory.getAggregationFunction( - new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, Collections.singletonList( - ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) operand).getIndex())))), - true); - } else { - int numExpectedArguments = numArgumentsLiteral.getIntValue(); - List<ExpressionContext> arguments = new ArrayList<>(numExpectedArguments); - arguments.add( - ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) operand).getIndex()))); - for (int i = 1; i < numExpectedArguments; i++) { - Literal literalArgument = literalArguments.get(i); - if (literalArgument != null) { - arguments.add(ExpressionContext.forLiteralContext(literalArgument)); - } else { - arguments.add(PLACEHOLDER_IDENTIFIER); - } - } - handleListAggDistinctArg(functionName, functionCall, arguments); - return AggregationFunctionFactory.getAggregationFunction( - new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true); - } - } - - private static void handleListAggDistinctArg(String functionName, RexExpression.FunctionCall functionCall, - List<ExpressionContext> arguments) { - String upperCaseFunctionName = - AggregationFunctionType.getNormalizedAggregationFunctionName(functionName); - if (upperCaseFunctionName.equals("LISTAGG")) { - if (functionCall.isDistinct()) { - arguments.add(ExpressionContext.forLiteralContext(Literal.boolValue(true))); - } - } - } - private static String fromColIdToIdentifier(int colId) { return "$" + 
colId; } diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json b/pinot-query-runtime/src/test/resources/queries/QueryHints.json index f8c850fcd3..81a939c2e1 100644 --- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json +++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json @@ -275,7 +275,7 @@ }, { "description": "semi-join with dynamic_broadcast join strategy then group-by on same key", - "sql": "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ {tbl1}.num, SUM({tbl1}.val) FROM {tbl1} WHERE {tbl1}.name IN (SELECT id FROM {tbl2} WHERE {tbl2}.data > 0) GROUP BY {tbl1}.num" + "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ {tbl1}.num, SUM({tbl1}.val) FROM {tbl1} WHERE {tbl1}.name IN (SELECT id FROM {tbl2} WHERE {tbl2}.data > 0) GROUP BY {tbl1}.num" }, { "description": "semi-join with dynamic_broadcast join strategy then group-by on different key", @@ -290,11 +290,7 @@ "sql": "SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE {tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num HAVING COUNT(*) > 10 AND MAX({tbl1}.val) >= 0 AND MIN({tbl1}.val) < 20 AND SUM({tbl1}.val) <= 10 AND AVG({tbl1}.val) = 5" }, { - "description": "aggregate with skip intermediate stage hint (via hinting the leaf stage group by as final stage_", - "sql": "SELECT /*+ aggOptionsInternal(agg_type='DIRECT') */ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE {tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num HAVING COUNT(*) > 10" - }, - { - "description": "aggregate with skip leaf stage hint (via hint option is_partitioned_by_group_by_keys", + "description": "aggregate with skip intermediate stage hint (via hint option is_partitioned_by_group_by_keys)", "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num) FROM {tbl1} WHERE 
{tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num" }, { diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index a6c468d8fe..877ac7f232 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -460,9 +460,10 @@ public enum AggregationFunctionType { * <p>NOTE: Underscores in the function name are ignored. */ public static AggregationFunctionType getAggregationFunctionType(String functionName) { - if (functionName.regionMatches(true, 0, "percentile", 0, 10)) { + String normalizedFunctionName = getNormalizedAggregationFunctionName(functionName); + if (normalizedFunctionName.regionMatches(false, 0, "PERCENTILE", 0, 10)) { // This style of aggregation functions is not supported in the multistage engine - String remainingFunctionName = getNormalizedAggregationFunctionName(functionName).substring(10).toUpperCase(); + String remainingFunctionName = normalizedFunctionName.substring(10).toUpperCase(); if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) { return PERCENTILE; } else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) { @@ -496,7 +497,7 @@ public enum AggregationFunctionType { } } else { try { - return AggregationFunctionType.valueOf(getNormalizedAggregationFunctionName(functionName)); + return AggregationFunctionType.valueOf(normalizedFunctionName); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Invalid aggregation function name: " + functionName); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org