This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push: new ffa9541 Pre-generate aggregation functions in QueryContext (#5805) ffa9541 is described below commit ffa954194e61e330c625a795cae94c15a54a2694 Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com> AuthorDate: Wed Aug 5 21:23:37 2020 -0700 Pre-generate aggregation functions in QueryContext (#5805) `AggregationFunction` itself is stateless, so we can share it among all the segments to prevent the overhead of creating it per segment. This can significantly improve the performance of high selectivity queries that hit lots of segments. - Remove the `accept(visitor)` from the `AggregationFunction` interface which may make it stateful - Make `DistinctCount` and `DistinctCountBitmap` stateless by caching the dictionary within the result holder --- pinot-common/src/test/resources/pql_queries.list | 8 +- pinot-common/src/test/resources/sql_queries.list | 8 +- .../core/common/datatable/DataTableUtils.java | 7 +- .../operator/combine/GroupByCombineOperator.java | 4 +- .../combine/GroupByOrderByCombineOperator.java | 4 +- .../plan/AggregationGroupByOrderByPlanNode.java | 3 +- .../core/plan/AggregationGroupByPlanNode.java | 3 +- .../pinot/core/plan/AggregationPlanNode.java | 3 +- .../plan/DictionaryBasedAggregationPlanNode.java | 4 +- .../plan/MetadataBasedAggregationPlanNode.java | 4 +- .../aggregation/function/AggregationFunction.java | 9 +- .../function/AggregationFunctionUtils.java | 6 +- .../function/AggregationFunctionVisitorBase.java | 114 ----------------- .../function/AvgAggregationFunction.java | 5 - .../function/AvgMVAggregationFunction.java | 5 - .../function/CountAggregationFunction.java | 5 - .../function/CountMVAggregationFunction.java | 5 - .../function/DistinctAggregationFunction.java | 5 - .../function/DistinctCountAggregationFunction.java | 114 +++++++++++------ .../DistinctCountBitmapAggregationFunction.java | 140 +++++++++++++-------- .../DistinctCountBitmapMVAggregationFunction.java | 47 +++---- .../DistinctCountHLLAggregationFunction.java | 5 - .../DistinctCountHLLMVAggregationFunction.java | 5 - .../DistinctCountMVAggregationFunction.java | 32 ++--- .../DistinctCountRawHLLAggregationFunction.java | 5 - ...inctCountRawThetaSketchAggregationFunction.java | 5 - ...istinctCountThetaSketchAggregationFunction.java | 26 ++-- .../function/FastHLLAggregationFunction.java | 5 - .../function/MaxAggregationFunction.java | 5 - .../function/MaxMVAggregationFunction.java | 5 - .../function/MinAggregationFunction.java | 5 - .../function/MinMVAggregationFunction.java | 5 - .../function/MinMaxRangeAggregationFunction.java | 5 - .../function/MinMaxRangeMVAggregationFunction.java | 5 - .../function/PercentileAggregationFunction.java | 5 - .../function/PercentileEstAggregationFunction.java | 5 - .../PercentileEstMVAggregationFunction.java | 5 - .../function/PercentileMVAggregationFunction.java | 5 - .../PercentileTDigestAggregationFunction.java | 5 - .../PercentileTDigestMVAggregationFunction.java | 5 - ...artitionedDistinctCountAggregationFunction.java | 5 - .../function/StUnionAggregationFunction.java | 5 - .../function/SumAggregationFunction.java | 5 - .../function/SumMVAggregationFunction.java | 5 - .../core/query/reduce/ResultReducerFactory.java | 6 +- .../core/query/request/context/QueryContext.java | 27 +++- .../request/context/utils/QueryContextUtils.java | 19 +-- .../pinot/core/startree/v2/BaseStarTreeV2Test.java | 3 +- .../DefaultAggregationExecutorTest.java | 4 +- .../apache/pinot/perf/BenchmarkCombineGroupBy.java | 4 +- .../apache/pinot/perf/BenchmarkIndexedTable.java | 4 +- 51 files changed, 271 insertions(+), 462 deletions(-) diff --git a/pinot-common/src/test/resources/pql_queries.list b/pinot-common/src/test/resources/pql_queries.list index e5ca687..51be85f 100644 --- a/pinot-common/src/test/resources/pql_queries.list +++ b/pinot-common/src/test/resources/pql_queries.list @@ -818,10 +818,10 @@ select mapKeys(mapField) from baseballStats where DIV(numberOfGames,10) = 100 select mapKey(mapField,k1) from baseballStats where DIV(numberOfGames,10) = 100 select mapKey(mapField,k1) from baseballStats where mapKey(mapField,k1) = 'v1' SELECT count(c1), sum(c1), min(c1), max(c1), avg(c1), minmaxrange(c1), distinctcount(c1), distinctcounthll(c1) FROM foo -SELECT distinctcountrawhll(c1), fasthll(c1), percentile(c1), percentileest(c1), percentiletdigest(c1) FROM foo +SELECT distinctcountrawhll(c1), fasthll(c1), percentile90(c1), percentileest95(c1), percentiletdigest(c1, 99) FROM foo SELECT countmv(c1), summv(c1), minmv(c1), maxmv(c1), avgmv(c1), minmaxrangemv(c1), distinctcountmv(c1), distinctcounthllmv(c1) FROM foo -SELECT distinctcountrawhllmv(c1), fasthllmv(c1), percentilemv(c1), percentileestmv(c1), percentiletdigestmv(c1) FROM foo +SELECT distinctcountrawhllmv(c1), fasthllmv(c1), percentile90mv(c1), percentileest95mv(c1), percentiletdigestmv(c1, 99) FROM foo SELECT count(c1), sum(add(c1,c2)), min(div(c1,c2)), max(sub(c1,c2)), avg(add(c1,c2)), minmaxrange(add(c1,c2)), distinctcount(sub(c1,c2)), distinctcounthll(c1) FROM foo -SELECT distinctcountrawhll(sub(c1,c2)), fasthll(div(c1,c2)), percentile(sub(c1,c2)), percentileest(add(c1,c2)), percentiletdigest(add(c1,c2)) FROM foo +SELECT distinctcountrawhll(sub(c1,c2)), fasthll(div(c1,c2)), percentile90(sub(c1,c2)), percentileest95(add(c1,c2)), percentiletdigest(add(c1,c2), 99) FROM foo SELECT countmv(c1), summv(add(c1,c2)), minmv(div(c1,c2)), maxmv(sub(c1,c2)), avgmv(add(c1,c2)), minmaxrangemv(min(c1,c2)), distinctcountmv(add(c1,c2)), distinctcounthllmv(min(c1,c2)) FROM foo -SELECT distinctcountrawhllmv(c1), fasthllmv(add(sub(c1,c2),div(c1,c2))), percentilemv(min(c1,c2)), percentileestmv(add(c1,c2)), percentiletdigestmv(min(c1,c2)) FROM foo \ No newline at end of file +SELECT distinctcountrawhllmv(c1), fasthllmv(add(sub(c1,c2),div(c1,c2))), percentile90mv(min(c1,c2)), percentileest95mv(add(c1,c2)), percentiletdigestmv(min(c1,c2), 99) FROM foo \ No newline at end of file diff --git a/pinot-common/src/test/resources/sql_queries.list b/pinot-common/src/test/resources/sql_queries.list index 2a574d9..3897ffa 100644 --- a/pinot-common/src/test/resources/sql_queries.list +++ b/pinot-common/src/test/resources/sql_queries.list @@ -818,10 +818,10 @@ select mapKeys(mapField) from baseballStats where DIV(numberOfGames,10) = 100 select mapKey(mapField,k1) from baseballStats where DIV(numberOfGames,10) = 100 select mapKey(mapField,k1) from baseballStats where mapKey(mapField,k1) = 'v1' SELECT count(c1), sum(c1), min(c1), max(c1), avg(c1), minmaxrange(c1), distinctcount(c1), distinctcounthll(c1) FROM foo -SELECT distinctcountrawhll(c1), fasthll(c1), percentile(c1), percentileest(c1), percentiletdigest(c1) FROM foo +SELECT distinctcountrawhll(c1), fasthll(c1), percentile90(c1), percentileest95(c1), percentiletdigest(c1, 99) FROM foo SELECT countmv(c1), summv(c1), minmv(c1), maxmv(c1), avgmv(c1), minmaxrangemv(c1), distinctcountmv(c1), distinctcounthllmv(c1) FROM foo -SELECT distinctcountrawhllmv(c1), fasthllmv(c1), percentilemv(c1), percentileestmv(c1), percentiletdigestmv(c1) FROM foo +SELECT distinctcountrawhllmv(c1), fasthllmv(c1), percentile90mv(c1), percentileest95mv(c1), percentiletdigestmv(c1, 99) FROM foo SELECT count(c1), sum(add(c1,c2)), min(div(c1,c2)), max(sub(c1,c2)), avg(add(c1,c2)), minmaxrange(add(c1,c2)), distinctcount(sub(c1,c2)), distinctcounthll(c1) FROM foo -SELECT distinctcountrawhll(sub(c1,c2)), fasthll(div(c1,c2)), percentile(sub(c1,c2)), percentileest(add(c1,c2)), percentiletdigest(add(c1,c2)) FROM foo +SELECT distinctcountrawhll(sub(c1,c2)), fasthll(div(c1,c2)), percentile90(sub(c1,c2)), percentileest95(add(c1,c2)), percentiletdigest(add(c1,c2), 99) FROM foo SELECT countmv(c1), summv(add(c1,c2)), minmv(div(c1,c2)), maxmv(sub(c1,c2)), avgmv(add(c1,c2)), minmaxrangemv(min(c1,c2)), distinctcountmv(add(c1,c2)), distinctcounthllmv(min(c1,c2)) FROM foo -SELECT distinctcountrawhllmv(c1), fasthllmv(add(sub(c1,c2),div(c1,c2))), percentilemv(min(c1,c2)), percentileestmv(add(c1,c2)), percentiletdigestmv(min(c1,c2)) FROM foo \ No newline at end of file +SELECT distinctcountrawhllmv(c1), fasthllmv(add(sub(c1,c2),div(c1,c2))), percentile90mv(min(c1,c2)), percentileest95mv(add(c1,c2)), percentiletdigestmv(min(c1,c2), 99) FROM foo \ No newline at end of file diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java index 70a7f48..b3dea8a 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java @@ -25,10 +25,8 @@ import java.util.List; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataTable; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.query.request.context.QueryContext; -import org.apache.pinot.core.query.request.context.utils.QueryContextUtils; import org.apache.pinot.core.util.QueryOptions; @@ -88,8 +86,10 @@ public class DataTableUtils { */ public static DataTable buildEmptyDataTable(QueryContext queryContext) throws IOException { + AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); + // Selection query. - if (!QueryContextUtils.isAggregationQuery(queryContext)) { + if (aggregationFunctions == null) { List<ExpressionContext> selectExpressions = queryContext.getSelectExpressions(); int numSelectExpressions = selectExpressions.size(); String[] columnNames = new String[numSelectExpressions]; @@ -104,7 +104,6 @@ public class DataTableUtils { } // Aggregation query. - AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); int numAggregations = aggregationFunctions.length; List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions(); if (groupByExpressions != null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java index 476b459..e9c8b00 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java @@ -37,7 +37,6 @@ import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.operator.BaseOperator; import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService; import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; @@ -109,7 +108,8 @@ public class GroupByCombineOperator extends BaseOperator<IntermediateResultsBloc AtomicInteger numGroups = new AtomicInteger(); ConcurrentLinkedQueue<ProcessingException> mergedProcessingExceptions = new ConcurrentLinkedQueue<>(); - AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_queryContext); + AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); + assert aggregationFunctions != null; int numAggregationFunctions = aggregationFunctions.length; // We use a CountDownLatch to track if all Futures are finished by the query timeout, and cancel the unfinished diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByOrderByCombineOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByOrderByCombineOperator.java index 627e9f4..ae7a6c9 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByOrderByCombineOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByOrderByCombineOperator.java @@ -41,7 +41,6 @@ import org.apache.pinot.core.data.table.Record; import org.apache.pinot.core.operator.BaseOperator; import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; import org.apache.pinot.core.query.exception.EarlyTerminationException; @@ -99,7 +98,8 @@ public class GroupByOrderByCombineOperator extends BaseOperator<IntermediateResu */ @Override protected IntermediateResultsBlock getNextBlock() { - AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_queryContext); + AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); + assert aggregationFunctions != null; int numAggregationFunctions = aggregationFunctions.length; assert _queryContext.getGroupByExpressions() != null; int numGroupByExpressions = _queryContext.getGroupByExpressions().size(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java index 5e46dae..5cb94bd 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java @@ -52,7 +52,8 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode { _indexSegment = indexSegment; _maxInitialResultHolderCapacity = maxInitialResultHolderCapacity; _numGroupsLimit = numGroupsLimit; - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + _aggregationFunctions = queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions(); assert groupByExpressions != null; _groupByExpressions = groupByExpressions.toArray(new ExpressionContext[0]); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java index 2f1b08e..acce657 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java @@ -52,7 +52,8 @@ public class AggregationGroupByPlanNode implements PlanNode { _indexSegment = indexSegment; _maxInitialResultHolderCapacity = maxInitialResultHolderCapacity; _numGroupsLimit = numGroupsLimit; - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + _aggregationFunctions = queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions(); assert groupByExpressions != null; _groupByExpressions = groupByExpressions.toArray(new ExpressionContext[0]); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java index 0255fca..4a267a2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java @@ -46,7 +46,8 @@ public class AggregationPlanNode implements PlanNode { public AggregationPlanNode(IndexSegment indexSegment, QueryContext queryContext) { _indexSegment = indexSegment; - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + _aggregationFunctions = queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; List<StarTreeV2> starTrees = indexSegment.getStarTrees(); if (starTrees != null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java index 04ad98b..d96ee1d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java @@ -23,7 +23,6 @@ import java.util.Map; import org.apache.pinot.core.indexsegment.IndexSegment; import org.apache.pinot.core.operator.query.DictionaryBasedAggregationOperator; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.segment.index.readers.Dictionary; @@ -46,7 +45,8 @@ public class DictionaryBasedAggregationPlanNode implements PlanNode { */ public DictionaryBasedAggregationPlanNode(IndexSegment indexSegment, QueryContext queryContext) { _indexSegment = indexSegment; - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + _aggregationFunctions = queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; _dictionaryMap = new HashMap<>(); for (AggregationFunction aggregationFunction : _aggregationFunctions) { String column = ((ExpressionContext) aggregationFunction.getInputExpressions().get(0)).getIdentifier(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java index d3aa3a3..af47928 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java @@ -25,7 +25,6 @@ import org.apache.pinot.core.common.DataSource; import org.apache.pinot.core.indexsegment.IndexSegment; import org.apache.pinot.core.operator.query.MetadataBasedAggregationOperator; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.query.request.context.QueryContext; @@ -47,7 +46,8 @@ public class MetadataBasedAggregationPlanNode implements PlanNode { */ public MetadataBasedAggregationPlanNode(IndexSegment indexSegment, QueryContext queryContext) { _indexSegment = indexSegment; - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + _aggregationFunctions = queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; _dataSourceMap = new HashMap<>(); for (AggregationFunction aggregationFunction : _aggregationFunctions) { if (aggregationFunction.getType() != AggregationFunctionType.COUNT) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java index 6caa303..b72cbf8 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java @@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function; import java.util.List; import java.util.Map; +import javax.annotation.concurrent.ThreadSafe; import org.apache.pinot.common.function.AggregationFunctionType; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.BlockValSet; @@ -30,10 +31,13 @@ import org.apache.pinot.core.query.request.context.ExpressionContext; /** * Interface for aggregation functions. + * <p>The implementation should be stateless, and can be shared among multiple segments in multiple threads. The result + * for each segment should be stored and passed in via the result holder. * * @param <IntermediateResult> Intermediate result generated from segment * @param <FinalResult> Final result used in broker response */ +@ThreadSafe @SuppressWarnings("rawtypes") public interface AggregationFunction<IntermediateResult, FinalResult extends Comparable> { @@ -58,11 +62,6 @@ public interface AggregationFunction<IntermediateResult, FinalResult extends Com List<ExpressionContext> getInputExpressions(); /** - * Accepts an aggregation function visitor to visit. - */ - void accept(AggregationFunctionVisitorBase visitor); - - /** * Returns an aggregation result holder for this function (aggregation only). */ AggregationResultHolder createAggregationResultHolder(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java index b5784fb..40d67be 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java @@ -69,9 +69,9 @@ public class AggregationFunctionUtils { } /** - * Creates an array of {@link AggregationFunction}s based on the given {@link QueryContext}. + * Creates a list of {@link AggregationFunction}s based on the given {@link QueryContext}. */ - public static AggregationFunction[] getAggregationFunctions(QueryContext queryContext) { + public static List<AggregationFunction> getAggregationFunctions(QueryContext queryContext) { List<ExpressionContext> selectExpressions = queryContext.getSelectExpressions(); Set<FunctionContext> functions = new HashSet<>(); List<AggregationFunction> aggregationFunctions = new ArrayList<>(); @@ -94,7 +94,7 @@ public class AggregationFunctionUtils { } } } - return aggregationFunctions.toArray(new AggregationFunction[0]); + return aggregationFunctions; } /** diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java deleted file mode 100644 index 2710ba7..0000000 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.core.query.aggregation.function; - -/** - * No-op base class for aggregation function visitor. - */ -@SuppressWarnings("unused") -public class AggregationFunctionVisitorBase { - - public void visit(AvgAggregationFunction function) { - } - - public void visit(AvgMVAggregationFunction function) { - } - - public void visit(CountAggregationFunction function) { - } - - public void visit(CountMVAggregationFunction function) { - } - - public void visit(DistinctAggregationFunction function) { - } - - public void visit(DistinctCountAggregationFunction function) { - } - - public void visit(DistinctCountMVAggregationFunction function) { - } - - public void visit(DistinctCountBitmapAggregationFunction function) { - } - - public void visit(DistinctCountBitmapMVAggregationFunction function) { - } - - public void visit(SegmentPartitionedDistinctCountAggregationFunction function) { - } - - public void visit(DistinctCountHLLAggregationFunction function) { - } - - public void visit(DistinctCountHLLMVAggregationFunction function) { - } - - public void visit(FastHLLAggregationFunction function) { - } - - public void visit(MaxAggregationFunction function) { - } - - public void visit(MaxMVAggregationFunction function) { - } - - public void visit(MinAggregationFunction function) { - } - - public void visit(MinMVAggregationFunction function) { - } - - public void visit(MinMaxRangeAggregationFunction function) { - } - - public void visit(MinMaxRangeMVAggregationFunction function) { - } - - public void visit(PercentileAggregationFunction function) { - } - - public void visit(PercentileMVAggregationFunction function) { - } - - public void visit(PercentileEstAggregationFunction function) { - } - - public void visit(PercentileEstMVAggregationFunction function) { - } - - public void visit(PercentileTDigestAggregationFunction function) { - } - - public void visit(PercentileTDigestMVAggregationFunction function) { - } - - public void visit(SumAggregationFunction function) { - } - - public void visit(SumMVAggregationFunction function) { - } - - public void visit(DistinctCountThetaSketchAggregationFunction function) { - } - - public void visit(StUnionAggregationFunction function) { - } -} - diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java index be472f4..81f78c4 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java @@ -45,11 +45,6 @@ public class AvgAggregationFunction extends BaseSingleInputAggregationFunction<A } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java index 57fb7d4..1d034aa 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java @@ -38,11 +38,6 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction { } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java index b202464..c24f5f1 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java @@ -61,11 +61,6 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long> } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new DoubleAggregationResultHolder(DEFAULT_INITIAL_VALUE); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java index 5546a70..44b0990 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java @@ -61,11 +61,6 @@ public class CountMVAggregationFunction extends CountAggregationFunction { } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java index db093a1..72e7edc 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java @@ -101,11 +101,6 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java index c13c8c0..e8e7e97 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java @@ -18,7 +18,6 @@ */ package org.apache.pinot.core.query.aggregation.function; -import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import java.util.Arrays; import java.util.Map; @@ -32,10 +31,11 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.segment.index.readers.Dictionary; import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.roaringbitmap.PeekableIntIterator; +import org.roaringbitmap.RoaringBitmap; public class DistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<IntOpenHashSet, Integer> { - protected Dictionary _dictionary; public DistinctCountAggregationFunction(ExpressionContext expression) { super(expression); @@ -47,11 +47,6 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } @@ -65,20 +60,17 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - IntOpenHashSet valueSet = getValueSet(aggregationResultHolder); - // For dictionary-encoded expression, store dictionary ids into the value set + // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[] dictIds = blockValSet.getDictionaryIdsSV(); - for (int i = 0; i < length; i++) { - valueSet.add(dictIds[i]); - } + getDictIdBitmap(aggregationResultHolder, dictionary).addN(dictIds, 0, length); return; } // For non-dictionary-encoded expression, store hash code of the values into the value set + IntOpenHashSet valueSet = getValueSet(aggregationResultHolder); DataType valueType = blockValSet.getValueType(); switch (valueType) { case INT: @@ -127,13 +119,12 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - // For dictionary-encoded expression, store dictionary ids into the value set + // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[] dictIds = blockValSet.getDictionaryIdsSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(dictIds[i]); + getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]); } return; } @@ -187,13 +178,12 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - // For dictionary-encoded expression, store dictionary ids into the value set + // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[] dictIds = blockValSet.getDictionaryIdsSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], dictIds[i]); + setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]); } return; } @@ -244,33 +234,33 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation @Override public IntOpenHashSet extractAggregationResult(AggregationResultHolder aggregationResultHolder) { - IntOpenHashSet valueSet = aggregationResultHolder.getResult(); - if (valueSet == null) { + Object result = aggregationResultHolder.getResult(); + if (result == null) { return new IntOpenHashSet(); } - if (_dictionary != null) { + if (result instanceof DictIdsWrapper) { // For dictionary-encoded expression, convert dictionary ids to hash code of the values - return convertToValueSet(valueSet, _dictionary); + return convertToValueSet((DictIdsWrapper) result); } else { // For non-dictionary-encoded expression, directly return the value set - return valueSet; + return (IntOpenHashSet) result; } } @Override public IntOpenHashSet extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { - IntOpenHashSet valueSet = groupByResultHolder.getResult(groupKey); - if (valueSet == null) { + Object result = groupByResultHolder.getResult(groupKey); + if (result == null) { return new IntOpenHashSet(); } - if (_dictionary != null) { + if (result instanceof DictIdsWrapper) { // For dictionary-encoded expression, convert dictionary ids to hash code of the values - return convertToValueSet(valueSet, _dictionary); + return convertToValueSet((DictIdsWrapper) result); } else { // For non-dictionary-encoded expression, directly return the value set - return valueSet; + return (IntOpenHashSet) result; } } @@ -301,6 +291,19 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } /** + * Returns the dictionary id bitmap from the result holder or creates a new one if it does not exist. + */ + protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder, + Dictionary dictionary) { + DictIdsWrapper dictIdsWrapper = aggregationResultHolder.getResult(); + if (dictIdsWrapper == null) { + dictIdsWrapper = new DictIdsWrapper(dictionary); + aggregationResultHolder.setValue(dictIdsWrapper); + } + return dictIdsWrapper._dictIdBitmap; + } + + /** * Returns the value set from the result holder or creates a new one if it does not exist. */ protected static IntOpenHashSet getValueSet(AggregationResultHolder aggregationResultHolder) { @@ -313,6 +316,19 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } /** + * Returns the dictionary id bitmap for the given group key or creates a new one if it does not exist. + */ + protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey, + Dictionary dictionary) { + DictIdsWrapper dictIdsWrapper = groupByResultHolder.getResult(groupKey); + if (dictIdsWrapper == null) { + dictIdsWrapper = new DictIdsWrapper(dictionary); + groupByResultHolder.setValueForKey(groupKey, dictIdsWrapper); + } + return dictIdsWrapper._dictIdBitmap; + } + + /** * Returns the value set for the given group key or creates a new one if it does not exist. */ protected static IntOpenHashSet getValueSet(GroupByResultHolder groupByResultHolder, int groupKey) { @@ -325,6 +341,16 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } /** + * Helper method to set dictionary id for the given group keys into the result holder. + */ + private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, + Dictionary dictionary, int dictId) { + for (int groupKey : groupKeys) { + getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId); + } + } + + /** * Helper method to set value for the given group keys into the result holder. */ private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, int value) { @@ -337,39 +363,41 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation * Helper method to read dictionary and convert dictionary ids to hash code of the values for dictionary-encoded * expression. */ - private static IntOpenHashSet convertToValueSet(IntOpenHashSet dictIdSet, Dictionary dictionary) { - IntOpenHashSet valueSet = new IntOpenHashSet(dictIdSet.size()); - IntIterator iterator = dictIdSet.iterator(); + private static IntOpenHashSet convertToValueSet(DictIdsWrapper dictIdsWrapper) { + Dictionary dictionary = dictIdsWrapper._dictionary; + RoaringBitmap dictIdBitmap = dictIdsWrapper._dictIdBitmap; + IntOpenHashSet valueSet = new IntOpenHashSet(dictIdBitmap.getCardinality()); + PeekableIntIterator iterator = dictIdBitmap.getIntIterator(); DataType valueType = dictionary.getValueType(); switch (valueType) { case INT: while (iterator.hasNext()) { - valueSet.add(dictionary.getIntValue(iterator.nextInt())); + valueSet.add(dictionary.getIntValue(iterator.next())); } break; case LONG: while (iterator.hasNext()) { - valueSet.add(Long.hashCode(dictionary.getLongValue(iterator.nextInt()))); + valueSet.add(Long.hashCode(dictionary.getLongValue(iterator.next()))); } break; case FLOAT: while (iterator.hasNext()) { - valueSet.add(Float.hashCode(dictionary.getFloatValue(iterator.nextInt()))); + valueSet.add(Float.hashCode(dictionary.getFloatValue(iterator.next()))); } break; case DOUBLE: while (iterator.hasNext()) { - valueSet.add(Double.hashCode(dictionary.getDoubleValue(iterator.nextInt()))); + valueSet.add(Double.hashCode(dictionary.getDoubleValue(iterator.next()))); } break; case STRING: while (iterator.hasNext()) { - valueSet.add(dictionary.getStringValue(iterator.nextInt()).hashCode()); + valueSet.add(dictionary.getStringValue(iterator.next()).hashCode()); } break; case BYTES: while (iterator.hasNext()) { - valueSet.add(Arrays.hashCode(dictionary.getBytesValue(iterator.nextInt()))); + valueSet.add(Arrays.hashCode(dictionary.getBytesValue(iterator.next()))); } break; default: @@ -377,4 +405,14 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } return valueSet; } + + private static final class DictIdsWrapper { + final Dictionary _dictionary; + final RoaringBitmap _dictIdBitmap; + + private DictIdsWrapper(Dictionary dictionary) { + _dictionary = dictionary; + _dictIdBitmap = new RoaringBitmap(); + } + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java index fbb5153..c5fb934 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java @@ -40,7 +40,6 @@ import org.roaringbitmap.RoaringBitmap; * values for other data types (values with the same hash code will only be counted once). */ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggregationFunction<RoaringBitmap, Integer> { - protected Dictionary _dictionary; public DistinctCountBitmapAggregationFunction(ExpressionContext expression) { super(expression); @@ -52,11 +51,6 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } @@ -75,60 +69,58 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre DataType valueType = blockValSet.getValueType(); if (valueType == DataType.BYTES) { byte[][] bytesValues = blockValSet.getBytesValuesSV(); - RoaringBitmap bitmap = aggregationResultHolder.getResult(); - if (bitmap != null) { + RoaringBitmap valueBitmap = aggregationResultHolder.getResult(); + if (valueBitmap != null) { for (int i = 0; i < length; i++) { - bitmap.or(ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[i])); + valueBitmap.or(ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[i])); } } else { - bitmap = ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[0]); - aggregationResultHolder.setValue(bitmap); + valueBitmap = ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[0]); + aggregationResultHolder.setValue(valueBitmap); for (int i = 1; i < length; i++) { - bitmap.or(ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[i])); + valueBitmap.or(ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[i])); } } return; } - RoaringBitmap bitmap = getBitmap(aggregationResultHolder); - // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[] dictIds = blockValSet.getDictionaryIdsSV(); - bitmap.addN(dictIds, 0, length); + getDictIdBitmap(aggregationResultHolder, dictionary).addN(dictIds, 0, length); return; } // For non-dictionary-encoded expression, store hash code of the values into the bitmap + RoaringBitmap valueBitmap = getValueBitmap(aggregationResultHolder); switch (valueType) { case INT: int[] intValues = blockValSet.getIntValuesSV(); - bitmap.addN(intValues, 0, length); + valueBitmap.addN(intValues, 0, length); break; case LONG: long[] longValues = blockValSet.getLongValuesSV(); for (int i = 0; i < length; i++) { - bitmap.add(Long.hashCode(longValues[i])); + valueBitmap.add(Long.hashCode(longValues[i])); } break; case FLOAT: float[] floatValues = blockValSet.getFloatValuesSV(); for (int i = 0; i < length; i++) { - bitmap.add(Float.hashCode(floatValues[i])); + valueBitmap.add(Float.hashCode(floatValues[i])); } break; case DOUBLE: double[] doubleValues = blockValSet.getDoubleValuesSV(); for (int i = 0; i < length; i++) { - bitmap.add(Double.hashCode(doubleValues[i])); + valueBitmap.add(Double.hashCode(doubleValues[i])); } break; case STRING: String[] stringValues = blockValSet.getStringValuesSV(); for (int i = 0; i < length; i++) { - bitmap.add(stringValues[i].hashCode()); + valueBitmap.add(stringValues[i].hashCode()); } break; default: @@ -149,9 +141,9 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre for (int i = 0; i < length; i++) { RoaringBitmap value = ObjectSerDeUtils.ROARING_BITMAP_SER_DE.deserialize(bytesValues[i]); int groupKey = groupKeyArray[i]; - RoaringBitmap bitmap = groupByResultHolder.getResult(groupKey); - if (bitmap != null) { - bitmap.or(value); + RoaringBitmap valueBitmap = groupByResultHolder.getResult(groupKey); + if (valueBitmap != null) { + valueBitmap.or(value); } else { groupByResultHolder.setValueForKey(groupKey, value); } @@ -162,10 +154,9 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[] dictIds = blockValSet.getDictionaryIdsSV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(dictIds[i]); + getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]); } return; } @@ -175,31 +166,31 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre case INT: int[] intValues = blockValSet.getIntValuesSV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(intValues[i]); + getValueBitmap(groupByResultHolder, groupKeyArray[i]).add(intValues[i]); } break; case LONG: long[] longValues = blockValSet.getLongValuesSV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(Long.hashCode(longValues[i])); + getValueBitmap(groupByResultHolder, groupKeyArray[i]).add(Long.hashCode(longValues[i])); } break; case FLOAT: float[] floatValues = blockValSet.getFloatValuesSV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(Float.hashCode(floatValues[i])); + getValueBitmap(groupByResultHolder, groupKeyArray[i]).add(Float.hashCode(floatValues[i])); } break; case DOUBLE: double[] doubleValues = blockValSet.getDoubleValuesSV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(Double.hashCode(doubleValues[i])); + getValueBitmap(groupByResultHolder, groupKeyArray[i]).add(Double.hashCode(doubleValues[i])); } break; case STRING: String[] stringValues = blockValSet.getStringValuesSV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(stringValues[i].hashCode()); + getValueBitmap(groupByResultHolder, groupKeyArray[i]).add(stringValues[i].hashCode()); } break; default: @@ -235,10 +226,9 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[] dictIds = blockValSet.getDictionaryIdsSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], dictIds[i]); + setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]); } return; } @@ -283,33 +273,33 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre @Override public RoaringBitmap extractAggregationResult(AggregationResultHolder aggregationResultHolder) { - RoaringBitmap bitmap = aggregationResultHolder.getResult(); - if (bitmap == null) { + Object result = aggregationResultHolder.getResult(); + if (result == null) { return new RoaringBitmap(); } - if (_dictionary != null) { + if (result instanceof DictIdsWrapper) { // For dictionary-encoded expression, convert dictionary ids to hash code of the values - return convertToValueBitmap(bitmap, _dictionary); + return convertToValueBitmap((DictIdsWrapper) result); } else { - // For serialized RoaringBitmap and non-dictionary-encoded expression, directly return the bitmap - return bitmap; + // For serialized RoaringBitmap and non-dictionary-encoded expression, directly return the value bitmap + return (RoaringBitmap) result; } } @Override public RoaringBitmap extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { - RoaringBitmap bitmap = groupByResultHolder.getResult(groupKey); - if (bitmap == null) { + Object result = groupByResultHolder.getResult(groupKey); + if (result == null) { return new RoaringBitmap(); } - if (_dictionary != null) { + if (result instanceof DictIdsWrapper) { // For dictionary-encoded expression, convert dictionary ids to hash code of the values - return convertToValueBitmap(bitmap, _dictionary); + return convertToValueBitmap((DictIdsWrapper) result); } else { - // For serialized RoaringBitmap and non-dictionary-encoded expression, directly return the bitmap - return bitmap; + // For serialized RoaringBitmap and non-dictionary-encoded expression, directly return the value bitmap + return (RoaringBitmap) result; } } @@ -340,9 +330,22 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre } /** - * Returns the bitmap from the result holder or creates a new one if it does not exist. + * Returns the dictionary id bitmap from the result holder or creates a new one if it does not exist. + */ + protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder, + Dictionary dictionary) { + DictIdsWrapper dictIdsWrapper = aggregationResultHolder.getResult(); + if (dictIdsWrapper == null) { + dictIdsWrapper = new DictIdsWrapper(dictionary); + aggregationResultHolder.setValue(dictIdsWrapper); + } + return dictIdsWrapper._dictIdBitmap; + } + + /** + * Returns the value bitmap from the result holder or creates a new one if it does not exist. */ - protected static RoaringBitmap getBitmap(AggregationResultHolder aggregationResultHolder) { + protected static RoaringBitmap getValueBitmap(AggregationResultHolder aggregationResultHolder) { RoaringBitmap bitmap = aggregationResultHolder.getResult(); if (bitmap == null) { bitmap = new RoaringBitmap(); @@ -352,9 +355,22 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre } /** - * Returns the bitmap for the given group key or creates a new one if it does not exist. + * Returns the dictionary id bitmap for the given group key or creates a new one if it does not exist. */ - protected static RoaringBitmap getBitmap(GroupByResultHolder groupByResultHolder, int groupKey) { + protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey, + Dictionary dictionary) { + DictIdsWrapper dictIdsWrapper = groupByResultHolder.getResult(groupKey); + if (dictIdsWrapper == null) { + dictIdsWrapper = new DictIdsWrapper(dictionary); + groupByResultHolder.setValueForKey(groupKey, dictIdsWrapper); + } + return dictIdsWrapper._dictIdBitmap; + } + + /** + * Returns the value bitmap for the given group key or creates a new one if it does not exist. + */ + protected static RoaringBitmap getValueBitmap(GroupByResultHolder groupByResultHolder, int groupKey) { RoaringBitmap bitmap = groupByResultHolder.getResult(groupKey); if (bitmap == null) { bitmap = new RoaringBitmap(); @@ -364,11 +380,21 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre } /** + * Helper method to set dictionary id for the given group keys into the result holder. + */ + private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, + Dictionary dictionary, int dictId) { + for (int groupKey : groupKeys) { + getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId); + } + } + + /** * Helper method to set value for the given group keys into the result holder. */ private void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, int value) { for (int groupKey : groupKeys) { - getBitmap(groupByResultHolder, groupKey).add(value); + getValueBitmap(groupByResultHolder, groupKey).add(value); } } @@ -376,7 +402,9 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre * Helper method to read dictionary and convert dictionary ids to hash code of the values for dictionary-encoded * expression. */ - private static RoaringBitmap convertToValueBitmap(RoaringBitmap dictIdBitmap, Dictionary dictionary) { + private static RoaringBitmap convertToValueBitmap(DictIdsWrapper dictIdsWrapper) { + Dictionary dictionary = dictIdsWrapper._dictionary; + RoaringBitmap dictIdBitmap = dictIdsWrapper._dictIdBitmap; RoaringBitmap valueBitmap = new RoaringBitmap(); PeekableIntIterator iterator = dictIdBitmap.getIntIterator(); DataType valueType = dictionary.getValueType(); @@ -412,4 +440,14 @@ public class DistinctCountBitmapAggregationFunction extends BaseSingleInputAggre } return valueBitmap; } + + private static final class DictIdsWrapper { + final Dictionary _dictionary; + final RoaringBitmap _dictIdBitmap; + + private DictIdsWrapper(Dictionary dictionary) { + _dictionary = dictionary; + _dictIdBitmap = new RoaringBitmap(); + } + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapMVAggregationFunction.java index 2f5dbba..795e07d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapMVAggregationFunction.java @@ -46,41 +46,36 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - RoaringBitmap bitmap = getBitmap(aggregationResultHolder); // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; + RoaringBitmap dictIdBitmap = getDictIdBitmap(aggregationResultHolder, dictionary); int[][] dictIds = blockValSet.getDictionaryIdsMV(); for (int i = 0; i < length; i++) { - bitmap.add(dictIds[i]); + dictIdBitmap.add(dictIds[i]); } return; } // For non-dictionary-encoded expression, store hash code of the values into the bitmap + RoaringBitmap valueBitmap = getValueBitmap(aggregationResultHolder); DataType valueType = blockValSet.getValueType(); switch (valueType) { case INT: int[][] intValues = blockValSet.getIntValuesMV(); for (int i = 0; i < length; i++) { - bitmap.add(intValues[i]); + valueBitmap.add(intValues[i]); } break; case LONG: long[][] longValues = blockValSet.getLongValuesMV(); for (int i = 0; i < length; i++) { for (long value : longValues[i]) { - bitmap.add(Long.hashCode(value)); + valueBitmap.add(Long.hashCode(value)); } } break; @@ -88,14 +83,14 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma float[][] floatValues = blockValSet.getFloatValuesMV(); for (int i = 0; i < length; i++) { for (float value : floatValues[i]) { - bitmap.add(Float.hashCode(value)); + valueBitmap.add(Float.hashCode(value)); } } case DOUBLE: double[][] doubleValues = blockValSet.getDoubleValuesMV(); for (int i = 0; i < length; i++) { for (double value : doubleValues[i]) { - bitmap.add(Double.hashCode(value)); + valueBitmap.add(Double.hashCode(value)); } } break; @@ -103,7 +98,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma String[][] stringValues = blockValSet.getStringValuesMV(); for (int i = 0; i < length; i++) { for (String value : stringValues[i]) { - bitmap.add(value.hashCode()); + valueBitmap.add(value.hashCode()); } } break; @@ -121,10 +116,9 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[][] dictIds = blockValSet.getDictionaryIdsMV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(dictIds[i]); + getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]); } return; } @@ -135,13 +129,13 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma case INT: int[][] intValues = blockValSet.getIntValuesMV(); for (int i = 0; i < length; i++) { - getBitmap(groupByResultHolder, groupKeyArray[i]).add(intValues[i]); + getValueBitmap(groupByResultHolder, groupKeyArray[i]).add(intValues[i]); } break; case LONG: long[][] longValues = blockValSet.getLongValuesMV(); for (int i = 0; i < length; i++) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKeyArray[i]); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKeyArray[i]); for (long value : longValues[i]) { bitmap.add(Long.hashCode(value)); } @@ -150,7 +144,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma case FLOAT: float[][] floatValues = blockValSet.getFloatValuesMV(); for (int i = 0; i < length; i++) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKeyArray[i]); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKeyArray[i]); for (float value : floatValues[i]) { bitmap.add(Float.hashCode(value)); } @@ -159,7 +153,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma case DOUBLE: double[][] doubleValues = blockValSet.getDoubleValuesMV(); for (int i = 0; i < length; i++) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKeyArray[i]); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKeyArray[i]); for (double value : doubleValues[i]) { bitmap.add(Double.hashCode(value)); } @@ -168,7 +162,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma case STRING: String[][] stringValues = blockValSet.getStringValuesMV(); for (int i = 0; i < length; i++) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKeyArray[i]); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKeyArray[i]); for (String value : stringValues[i]) { bitmap.add(value.hashCode()); } @@ -188,11 +182,10 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[][] dictIds = blockValSet.getDictionaryIdsMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - getBitmap(groupByResultHolder, groupKey).add(dictIds[i]); + getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictIds[i]); } } return; @@ -205,7 +198,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma int[][] intValues = blockValSet.getIntValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - getBitmap(groupByResultHolder, groupKey).add(intValues[i]); + getValueBitmap(groupByResultHolder, groupKey).add(intValues[i]); } } break; @@ -213,7 +206,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma long[][] longValues = blockValSet.getLongValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKey); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKey); for (long value : longValues[i]) { bitmap.add(Long.hashCode(value)); } @@ -224,7 +217,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma float[][] floatValues = blockValSet.getFloatValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKey); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKey); for (float value : floatValues[i]) { bitmap.add(Float.hashCode(value)); } @@ -235,7 +228,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma double[][] doubleValues = blockValSet.getDoubleValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKey); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKey); for (double value : doubleValues[i]) { bitmap.add(Double.hashCode(value)); } @@ -246,7 +239,7 @@ public class DistinctCountBitmapMVAggregationFunction extends DistinctCountBitma String[][] stringValues = blockValSet.getStringValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - RoaringBitmap bitmap = getBitmap(groupByResultHolder, groupKey); + RoaringBitmap bitmap = getValueBitmap(groupByResultHolder, groupKey); for (String value : stringValues[i]) { bitmap.add(value.hashCode()); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java index f802799..362bcb3 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java @@ -57,11 +57,6 @@ public class DistinctCountHLLAggregationFunction extends BaseSingleInputAggregat } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java index fbb5ec4..e68b046 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java @@ -41,11 +41,6 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { HyperLogLog hyperLogLog = getDefaultHyperLogLog(aggregationResultHolder); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java index 7d8d5d4..fb6b2e3 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java @@ -27,6 +27,7 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.segment.index.readers.Dictionary; import org.apache.pinot.spi.data.FieldSpec; +import org.roaringbitmap.RoaringBitmap; public class DistinctCountMVAggregationFunction extends DistinctCountAggregationFunction { @@ -41,30 +42,23 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - IntOpenHashSet valueSet = getValueSet(aggregationResultHolder); - // For dictionary-encoded expression, store dictionary ids into the value set + // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; + RoaringBitmap dictIdBitmap = getDictIdBitmap(aggregationResultHolder, dictionary); int[][] dictIds = blockValSet.getDictionaryIdsMV(); for (int i = 0; i < length; i++) { - for (int dictId : dictIds[i]) { - valueSet.add(dictId); - } + dictIdBitmap.add(dictIds[i]); } return; } // For non-dictionary-encoded expression, store hash code of the values into the value set + IntOpenHashSet valueSet = getValueSet(aggregationResultHolder); FieldSpec.DataType valueType = blockValSet.getValueType(); switch (valueType) { case INT: @@ -116,16 +110,12 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - // For dictionary-encoded expression, store dictionary ids into the value set + // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[][] dictIds = blockValSet.getDictionaryIdsMV(); for (int i = 0; i < length; i++) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKeyArray[i]); - for (int dictId : dictIds[i]) { - valueSet.add(dictId); - } + getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]); } return; } @@ -188,17 +178,13 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation Map<ExpressionContext, BlockValSet> blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - // For dictionary-encoded expression, store dictionary ids into the value set + // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - _dictionary = dictionary; int[][] dictIds = blockValSet.getDictionaryIdsMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKey); - for (int dictId : dictIds[i]) { - valueSet.add(dictId); - } + getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictIds[i]); } } return; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java index 768f203..47040a5 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java @@ -49,11 +49,6 @@ public class DistinctCountRawHLLAggregationFunction extends BaseSingleInputAggre } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - _distinctCountHLLAggregationFunction.accept(visitor); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return _distinctCountHLLAggregationFunction.createAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawThetaSketchAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawThetaSketchAggregationFunction.java index 9cf3cb9..c4185b0 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawThetaSketchAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawThetaSketchAggregationFunction.java @@ -70,11 +70,6 @@ public class DistinctCountRawThetaSketchAggregationFunction implements Aggregati } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - _thetaSketchAggregationFunction.accept(visitor); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return _thetaSketchAggregationFunction.createAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java index d598d3a..639d862 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Stack; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.commons.collections.MapUtils; import org.apache.datasketches.memory.Memory; import org.apache.datasketches.theta.Intersection; import org.apache.datasketches.theta.SetOperation; @@ -55,7 +56,6 @@ import org.apache.pinot.sql.parsers.CalciteSqlParser; * Theta Sketches. * <p>TODO: For performance concern, use {@code List<Sketch>} as the intermediate result. */ -@SuppressWarnings("Duplicates") public class DistinctCountThetaSketchAggregationFunction implements AggregationFunction<Map<String, Sketch>, Long> { private final ExpressionContext _thetaSketchColumn; private final ThetaSketchParams _thetaSketchParams; @@ -160,11 +160,6 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } @@ -387,7 +382,7 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF @Override public Map<String, Sketch> extractAggregationResult(AggregationResultHolder aggregationResultHolder) { Map<Predicate, Union> unionMap = aggregationResultHolder.getResult(); - if (unionMap == null || unionMap.isEmpty()) { + if (unionMap == null) { return Collections.emptyMap(); } @@ -406,7 +401,7 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF @Override public Map<String, Sketch> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { Map<Predicate, Union> unionMap = groupByResultHolder.getResult(groupKey); - if (unionMap == null || unionMap.isEmpty()) { + if (unionMap == null) { return Collections.emptyMap(); } @@ -424,9 +419,10 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF @Override public Map<String, Sketch> merge(Map<String, Sketch> intermediateResult1, Map<String, Sketch> intermediateResult2) { - if (intermediateResult1 == null || intermediateResult1.isEmpty()) { + if (MapUtils.isEmpty(intermediateResult1)) { return intermediateResult2; - } else if (intermediateResult2 == null || intermediateResult2.isEmpty()) { + } + if (MapUtils.isEmpty(intermediateResult2)) { return intermediateResult1; } @@ -435,14 +431,14 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF for (Map.Entry<String, Sketch> entry : intermediateResult1.entrySet()) { String predicate = entry.getKey(); Sketch sketch = intermediateResult2.get(predicate); - - // Merge the overlapping ones if (sketch != null) { + // Merge the overlapping ones Union union = getSetOperationBuilder().buildUnion(); union.update(entry.getValue()); union.update(sketch); mergedResult.put(predicate, union.getResult()); - } else { // Collect the non-overlapping ones + } else { + // Collect the non-overlapping ones mergedResult.put(predicate, entry.getValue()); } } @@ -456,7 +452,6 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF return mergedResult; } - @Override public boolean isIntermediateResultComparable() { return false; @@ -563,9 +558,6 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF * @return Final Sketch obtained by computing the post-aggregation expression on intermediate result */ protected Sketch extractFinalSketch(Map<String, Sketch> intermediateResult) { - // NOTE: Here we parse the map keys to Predicate to handle the non-standard predicate string returned from server - // side for backward-compatibility. - // TODO: Remove the extra parsing after releasing 0.5.0 Map<Predicate, Sketch> sketchMap = new HashMap<>(); for (Map.Entry<String, Sketch> entry : intermediateResult.entrySet()) { Predicate predicate = getPredicate(entry.getKey()); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java index e7e3689..43954d7 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java @@ -50,11 +50,6 @@ public class FastHLLAggregationFunction extends BaseSingleInputAggregationFuncti } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java index 21cb160..5bb9a5d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java @@ -42,11 +42,6 @@ public class MaxAggregationFunction extends BaseSingleInputAggregationFunction<D } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new DoubleAggregationResultHolder(DEFAULT_INITIAL_VALUE); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java index 1962310..cf94fa0 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java @@ -38,11 +38,6 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction { } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java index 0c2a6c3..0598244 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java @@ -42,11 +42,6 @@ public class MinAggregationFunction extends BaseSingleInputAggregationFunction<D } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new DoubleAggregationResultHolder(DEFAULT_VALUE); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java index 9b6890c..45e0471 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java @@ -38,11 +38,6 @@ public class MinMVAggregationFunction extends MinAggregationFunction { } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java index ac48136..ad827d8 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java @@ -44,11 +44,6 @@ public class MinMaxRangeAggregationFunction extends BaseSingleInputAggregationFu } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java index 5ac96f5..c594c2b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java @@ -38,11 +38,6 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java index 474e5ef..e777410 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java @@ -57,11 +57,6 @@ public class PercentileAggregationFunction extends BaseSingleInputAggregationFun } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java index b44eb74..3d002da 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java @@ -58,11 +58,6 @@ public class PercentileEstAggregationFunction extends BaseSingleInputAggregation } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java index d7bb415..644a304 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java @@ -49,11 +49,6 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java index 2ac412c..23cf597 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java @@ -49,11 +49,6 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java index cf78dd2..5d638e2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java @@ -61,11 +61,6 @@ public class PercentileTDigestAggregationFunction extends BaseSingleInputAggrega } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java index 6ae8a5d..405f21f 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java @@ -49,11 +49,6 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java index 221969d..4859c85 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java @@ -58,11 +58,6 @@ public class SegmentPartitionedDistinctCountAggregationFunction extends BaseSing } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/StUnionAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/StUnionAggregationFunction.java index 8ae1d97..eff99db 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/StUnionAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/StUnionAggregationFunction.java @@ -50,11 +50,6 @@ public class StUnionAggregationFunction extends BaseSingleInputAggregationFuncti } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new ObjectAggregationResultHolder(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java index 4e84ee4..30f7a9a 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java @@ -42,11 +42,6 @@ public class SumAggregationFunction extends BaseSingleInputAggregationFunction<D } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public AggregationResultHolder createAggregationResultHolder() { return new DoubleAggregationResultHolder(DEFAULT_VALUE); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java index ca8bd13..f7664d5 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java @@ -38,11 +38,6 @@ public class SumMVAggregationFunction extends SumAggregationFunction { } @Override - public void accept(AggregationFunctionVisitorBase visitor) { - visitor.visit(this); - } - - @Override public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java index 904e51f..66c6313 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java @@ -20,10 +20,8 @@ package org.apache.pinot.core.query.reduce; import org.apache.pinot.common.function.AggregationFunctionType; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.aggregation.function.DistinctAggregationFunction; import org.apache.pinot.core.query.request.context.QueryContext; -import org.apache.pinot.core.query.request.context.utils.QueryContextUtils; /** @@ -36,12 +34,12 @@ public final class ResultReducerFactory { * Constructs the right result reducer based on the given query context. */ public static DataTableReducer getResultReducer(QueryContext queryContext) { - if (!QueryContextUtils.isAggregationQuery(queryContext)) { + AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); + if (aggregationFunctions == null) { // Selection query return new SelectionDataTableReducer(queryContext); } else { // Aggregation query - AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); if (queryContext.getGroupByExpressions() == null) { // Aggregation only query if (aggregationFunctions.length == 1 && aggregationFunctions[0].getType() == AggregationFunctionType.DISTINCT) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index 2d35d59..4f8e725 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Map; import javax.annotation.Nullable; import org.apache.pinot.common.request.BrokerRequest; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; /** @@ -51,6 +53,7 @@ import org.apache.pinot.common.request.BrokerRequest; * </li> * </ul> */ +@SuppressWarnings("rawtypes") public class QueryContext { private final List<ExpressionContext> _selectExpressions; private final Map<ExpressionContext, String> _aliasMap; @@ -67,6 +70,9 @@ public class QueryContext { // TODO: Remove it once the whole query engine is using the QueryContext private final BrokerRequest _brokerRequest; + // Pre-generate the aggregation functions for the query so that it can be shared among all the segments + private AggregationFunction[] _aggregationFunctions; + private QueryContext(List<ExpressionContext> selectExpressions, Map<ExpressionContext, String> aliasMap, @Nullable FilterContext filter, @Nullable List<ExpressionContext> groupByExpressions, @Nullable List<OrderByExpressionContext> orderByExpressions, @Nullable FilterContext havingFilter, int limit, @@ -169,6 +175,14 @@ public class QueryContext { } /** + * Returns the aggregation functions for the query, or {@code null} if the query does not have any aggregation. + */ + @Nullable + public AggregationFunction[] getAggregationFunctions() { + return _aggregationFunctions; + } + + /** * NOTE: For debugging only. */ @Override @@ -250,8 +264,17 @@ public class QueryContext { public QueryContext build() { // TODO: Add validation logic here - return new QueryContext(_selectExpressions, _aliasMap, _filter, _groupByExpressions, _orderByExpressions, - _havingFilter, _limit, _offset, _queryOptions, _debugOptions, _brokerRequest); + QueryContext queryContext = + new QueryContext(_selectExpressions, _aliasMap, _filter, _groupByExpressions, _orderByExpressions, + _havingFilter, _limit, _offset, _queryOptions, _debugOptions, _brokerRequest); + + // Pre-generate the aggregation functions for the query + List<AggregationFunction> aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + if (!aggregationFunctions.isEmpty()) { + queryContext._aggregationFunctions = aggregationFunctions.toArray(new AggregationFunction[0]); + } + + return queryContext; } } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/QueryContextUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/QueryContextUtils.java index c9c186d..4e5a2c4 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/QueryContextUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/QueryContextUtils.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.Set; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.query.request.context.FilterContext; -import org.apache.pinot.core.query.request.context.FunctionContext; import org.apache.pinot.core.query.request.context.OrderByExpressionContext; import org.apache.pinot.core.query.request.context.QueryContext; @@ -67,24 +66,8 @@ public class QueryContextUtils { /** * Returns {@code true} if the given query is an aggregation query, {@code false} otherwise. - * <p>A query is an aggregation query if there are aggregation functions in the SELECT clause or ORDER-BY clause. */ public static boolean isAggregationQuery(QueryContext query) { - for (ExpressionContext selectExpression : query.getSelectExpressions()) { - FunctionContext function = selectExpression.getFunction(); - if (function != null && function.getType() == FunctionContext.Type.AGGREGATION) { - return true; - } - } - List<OrderByExpressionContext> orderByExpressions = query.getOrderByExpressions(); - if (orderByExpressions != null) { - for (OrderByExpressionContext orderByExpression : orderByExpressions) { - FunctionContext function = orderByExpression.getExpression().getFunction(); - if (function != null && function.getType() == FunctionContext.Type.AGGREGATION) { - return true; - } - } - } - return false; + return query.getAggregationFunctions() != null; } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java index 6ee9235..7b46277 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java @@ -184,7 +184,8 @@ abstract class BaseStarTreeV2Test<R, A> { QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromPQL(query); // Aggregations - AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(queryContext); + AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); + assert aggregationFunctions != null; int numAggregations = aggregationFunctions.length; List<AggregationFunctionColumnPair> functionColumnPairs = new ArrayList<>(numAggregations); for (AggregationFunction aggregationFunction : aggregationFunctions) { diff --git a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java index d4160cc..618514f 100644 --- a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java @@ -40,7 +40,6 @@ import org.apache.pinot.core.plan.DocIdSetPlanNode; import org.apache.pinot.core.query.aggregation.AggregationExecutor; import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils; @@ -135,7 +134,8 @@ public class DefaultAggregationExecutorTest { ProjectionOperator projectionOperator = new ProjectionOperator(dataSourceMap, docIdSetOperator); TransformOperator transformOperator = new TransformOperator(projectionOperator, expressions); TransformBlock transformBlock = transformOperator.nextBlock(); - AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_queryContext); + AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); + assert aggregationFunctions != null; AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions); aggregationExecutor.aggregate(transformBlock); List<Object> result = aggregationExecutor.getResult(); diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java index 442a205..d768282 100644 --- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java @@ -41,7 +41,6 @@ import org.apache.pinot.core.data.table.ConcurrentIndexedTable; import org.apache.pinot.core.data.table.IndexedTable; import org.apache.pinot.core.data.table.Record; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService; import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; import org.apache.pinot.core.query.request.context.QueryContext; @@ -101,7 +100,8 @@ public class BenchmarkCombineGroupBy { _queryContext = QueryContextConverterUtils .getQueryContextFromPQL("SELECT sum(m1), max(m2) FROM testTable GROUP BY d1, d2 ORDER BY sum(m1) TOP 500"); - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_queryContext); + _aggregationFunctions = _queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; _dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE}); diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java index 55468f8..3d9d0b7 100644 --- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java @@ -38,7 +38,6 @@ import org.apache.pinot.core.data.table.IndexedTable; import org.apache.pinot.core.data.table.Record; import org.apache.pinot.core.data.table.SimpleIndexedTable; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils; import org.apache.pinot.core.util.trace.TraceRunnable; @@ -91,7 +90,8 @@ public class BenchmarkIndexedTable { _queryContext = QueryContextConverterUtils .getQueryContextFromPQL("SELECT sum(m1), max(m2) FROM testTable GROUP BY d1, d2 ORDER BY sum(m1) TOP 500"); - _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_queryContext); + _aggregationFunctions = _queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; _dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE}); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org