This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 303b1a7cbe [Issue 7519] Adds support for multiple filtered/unfiltered aggregations with GROUP BY (#10000) 303b1a7cbe is described below commit 303b1a7cbe78244491f0580eb88e966a41b56b25 Author: Evan Galpin <egal...@users.noreply.github.com> AuthorDate: Wed Jan 4 19:01:40 2023 -0800 [Issue 7519] Adds support for multiple filtered/unfiltered aggregations with GROUP BY (#10000) --- .../operator/query/FilteredGroupByOperator.java | 215 +++++++++++++++++++++ .../pinot/core/plan/AggregationPlanNode.java | 87 +-------- .../apache/pinot/core/plan/GroupByPlanNode.java | 30 ++- .../function/AggregationFunctionUtils.java | 94 +++++++++ .../groupby/DefaultGroupByExecutor.java | 56 ++++-- .../query/aggregation/groupby/GroupByExecutor.java | 4 + .../core/query/request/context/QueryContext.java | 6 +- .../query/aggregation/groupby/GroupByTrimTest.java | 9 +- .../pinot/queries/FilteredAggregationsTest.java | 57 +++++- 9 files changed, 445 insertions(+), 113 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java new file mode 100644 index 0000000000..e895d817dd --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java @@ -0,0 +1,215 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.query; + +import java.util.Collection; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.data.table.IntermediateRecord; +import org.apache.pinot.core.data.table.TableResizer; +import org.apache.pinot.core.operator.BaseOperator; +import org.apache.pinot.core.operator.ExecutionStatistics; +import org.apache.pinot.core.operator.blocks.TransformBlock; +import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock; +import org.apache.pinot.core.operator.transform.TransformOperator; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; +import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor; +import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; +import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; +import org.apache.pinot.core.query.request.context.QueryContext; +import org.apache.pinot.core.util.GroupByUtils; +import org.apache.pinot.spi.trace.Tracing; + + +/** + * The <code>FilteredGroupByOperator</code> class provides the operator for group-by query on a single segment when + * there are 1 or more filter expressions on aggregations. 
+ */ +@SuppressWarnings("rawtypes") +public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> { + private static final String EXPLAIN_NAME = "GROUP_BY_FILTERED"; + + private final AggregationFunction[] _aggregationFunctions; + private final List<Pair<AggregationFunction[], TransformOperator>> _aggFunctionsWithTransformOperator; + private final ExpressionContext[] _groupByExpressions; + private final long _numTotalDocs; + private long _numDocsScanned; + private long _numEntriesScannedInFilter; + private long _numEntriesScannedPostFilter; + private final DataSchema _dataSchema; + private final QueryContext _queryContext; + + public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, + List<Pair<AggregationFunction[], TransformOperator>> aggFunctionsWithTransformOperator, + ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext queryContext) { + _aggregationFunctions = aggregationFunctions; + _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator; + _groupByExpressions = groupByExpressions; + _numTotalDocs = numTotalDocs; + _queryContext = queryContext; + + // NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns + int numGroupByExpressions = groupByExpressions.length; + int numAggregationFunctions = aggregationFunctions.length; + int numColumns = numGroupByExpressions + numAggregationFunctions; + String[] columnNames = new String[numColumns]; + DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns]; + + // Extract column names and data types for group-by columns + for (int i = 0; i < numGroupByExpressions; i++) { + ExpressionContext groupByExpression = groupByExpressions[i]; + columnNames[i] = groupByExpression.toString(); + columnDataTypes[i] = DataSchema.ColumnDataType.fromDataTypeSV( + aggFunctionsWithTransformOperator.get(i).getRight().getResultMetadata(groupByExpression).getDataType()); + } + + // Extract 
column names and data types for aggregation functions + for (int i = 0; i < numAggregationFunctions; i++) { + AggregationFunction aggregationFunction = aggregationFunctions[i]; + int index = numGroupByExpressions + i; + columnNames[index] = aggregationFunction.getResultColumnName(); + columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType(); + } + + _dataSchema = new DataSchema(columnNames, columnDataTypes); + } + + @Override + protected GroupByResultsBlock getNextBlock() { + // TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions + int numAggregations = _aggregationFunctions.length; + + GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations]; + IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new IdentityHashMap<>(numAggregations); + for (int i = 0; i < numAggregations; i++) { + resultHolderIndexMap.put(_aggregationFunctions[i], i); + } + + GroupKeyGenerator groupKeyGenerator = null; + for (Pair<AggregationFunction[], TransformOperator> filteredAggregation : _aggFunctionsWithTransformOperator) { + TransformOperator transformOperator = filteredAggregation.getRight(); + AggregationFunction[] filteredAggFunctions = filteredAggregation.getLeft(); + + // Perform aggregation group-by on all the blocks + DefaultGroupByExecutor groupByExecutor; + if (groupKeyGenerator == null) { + // The group key generator should be shared across all AggregationFunctions so that agg results can be + // aligned. Given that filtered aggregations are stored as an iterable of iterables so that all filtered aggs + // with the same filter can share transform blocks, rather than a singular flat iterable in the case where + // aggs are all non-filtered, sharing a GroupKeyGenerator across all aggs cannot be accomplished by allowing + // the GroupByExecutor to have sole ownership of the GroupKeyGenerator. 
Therefore, we allow constructing a + // GroupByExecutor with a pre-existing GroupKeyGenerator so that the GroupKeyGenerator can be shared across + // loop iterations i.e. across all aggs. + groupByExecutor = + new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions, transformOperator); + groupKeyGenerator = groupByExecutor.getGroupKeyGenerator(); + } else { + groupByExecutor = + new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions, transformOperator, + groupKeyGenerator); + } + + int numDocsScanned = 0; + TransformBlock transformBlock; + while ((transformBlock = transformOperator.nextBlock()) != null) { + numDocsScanned += transformBlock.getNumDocs(); + groupByExecutor.process(transformBlock); + } + + _numDocsScanned += numDocsScanned; + _numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter(); + _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected(); + GroupByResultHolder[] filterGroupByResults = groupByExecutor.getGroupByResultHolders(); + for (int i = 0; i < filteredAggFunctions.length; i++) { + groupByResultHolders[resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i]; + } + } + assert groupKeyGenerator != null; + for (GroupByResultHolder groupByResultHolder : groupByResultHolders) { + groupByResultHolder.ensureCapacity(groupKeyGenerator.getNumKeys()); + } + + // Check if the groups limit is reached + boolean numGroupsLimitReached = groupKeyGenerator.getNumKeys() >= _queryContext.getNumGroupsLimit(); + Tracing.activeRecording().setNumGroups(_queryContext.getNumGroupsLimit(), groupKeyGenerator.getNumKeys()); + + // Trim the groups iff: + // - Query has ORDER BY clause + // - Segment group trim is enabled + // - There are more groups than the trim size + // TODO: Currently the groups are not trimmed if there is no ordering specified. 
Consider ordering on group-by + // columns if no ordering is specified. + int minGroupTrimSize = _queryContext.getMinSegmentGroupTrimSize(); + if (_queryContext.getOrderByExpressions() != null && minGroupTrimSize > 0) { + int trimSize = GroupByUtils.getTableCapacity(_queryContext.getLimit(), minGroupTrimSize); + if (groupKeyGenerator.getNumKeys() > trimSize) { + TableResizer tableResizer = new TableResizer(_dataSchema, _queryContext); + Collection<IntermediateRecord> intermediateRecords = + tableResizer.trimInSegmentResults(groupKeyGenerator, groupByResultHolders, trimSize); + GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, intermediateRecords); + resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached); + return resultsBlock; + } + } + + AggregationGroupByResult aggGroupByResult = + new AggregationGroupByResult(groupKeyGenerator, _aggregationFunctions, groupByResultHolders); + GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, aggGroupByResult); + resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached); + return resultsBlock; + } + + @Override + public List<Operator> getChildOperators() { + return _aggFunctionsWithTransformOperator.stream().map(Pair::getRight).collect(Collectors.toList()); + } + + @Override + public ExecutionStatistics getExecutionStatistics() { + return new ExecutionStatistics(_numDocsScanned, _numEntriesScannedInFilter, _numEntriesScannedPostFilter, + _numTotalDocs); + } + + @Override + public String toExplainString() { + StringBuilder stringBuilder = new StringBuilder(EXPLAIN_NAME).append("(groupKeys:"); + if (_groupByExpressions.length > 0) { + stringBuilder.append(_groupByExpressions[0].toString()); + for (int i = 1; i < _groupByExpressions.length; i++) { + stringBuilder.append(", ").append(_groupByExpressions[i].toString()); + } + } + + stringBuilder.append(", aggregations:"); + if (_aggregationFunctions.length > 0) { + stringBuilder.append(_aggregationFunctions[0].toExplainString()); + 
for (int i = 1; i < _aggregationFunctions.length; i++) { + stringBuilder.append(", ").append(_aggregationFunctions[i].toExplainString()); + } + } + + return stringBuilder.append(')').toString(); + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java index 58d74fb00f..148911897e 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java @@ -18,19 +18,15 @@ */ package org.apache.pinot.core.plan; -import java.util.ArrayList; import java.util.EnumSet; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; -import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock; import org.apache.pinot.core.operator.filter.BaseFilterOperator; -import org.apache.pinot.core.operator.filter.CombinedFilterOperator; import org.apache.pinot.core.operator.query.AggregationOperator; import org.apache.pinot.core.operator.query.FastFilteredCountOperator; import org.apache.pinot.core.operator.query.FilteredAggregationOperator; @@ -77,7 +73,7 @@ public class AggregationPlanNode implements PlanNode { @Override public Operator<AggregationResultsBlock> run() { assert _queryContext.getAggregationFunctions() != null; - return _queryContext.isHasFilteredAggregations() ? buildFilteredAggOperator() : buildNonFilteredAggOperator(); + return _queryContext.hasFilteredAggregations() ? 
buildFilteredAggOperator() : buildNonFilteredAggOperator(); } /** @@ -86,83 +82,18 @@ public class AggregationPlanNode implements PlanNode { private FilteredAggregationOperator buildFilteredAggOperator() { int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs(); // Build the operator chain for the main predicate - Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = buildFilterOperator(_queryContext.getFilter()); - TransformOperator transformOperator = buildTransformOperatorForFilteredAggregates(filterOperatorPair.getRight()); - - return buildFilterOperatorInternal(filterOperatorPair.getRight(), transformOperator, numTotalDocs); - } - - /** - * Build a FilteredAggregationOperator given the parameters. - * @param mainPredicateFilterOperator Filter operator corresponding to the main predicate - * @param mainTransformOperator Transform operator corresponding to the main predicate - * @param numTotalDocs Number of total docs - */ - private FilteredAggregationOperator buildFilterOperatorInternal(BaseFilterOperator mainPredicateFilterOperator, - TransformOperator mainTransformOperator, int numTotalDocs) { - Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>(); - List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>(); - List<Pair<AggregationFunction, FilterContext>> aggregationFunctions = - _queryContext.getFilteredAggregationFunctions(); - - // For each aggregation function, check if the aggregation function is a filtered agg. 
- // If it is, populate the corresponding filter operator and corresponding transform operator - for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) { - if (inputPair.getLeft() != null) { - FilterContext currentFilterExpression = inputPair.getRight(); - if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) { - filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft()); - continue; - } - Pair<FilterPlanNode, BaseFilterOperator> pair = buildFilterOperator(currentFilterExpression); - BaseFilterOperator wrappedFilterOperator = - new CombinedFilterOperator(mainPredicateFilterOperator, pair.getRight(), _queryContext.getQueryOptions()); - TransformOperator newTransformOperator = buildTransformOperatorForFilteredAggregates(wrappedFilterOperator); - // For each transform operator, associate it with the underlying expression. This allows - // fetching the relevant TransformOperator when resolving blocks during aggregation - // execution - List<AggregationFunction> aggFunctionList = new ArrayList<>(); - aggFunctionList.add(inputPair.getLeft()); - filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator)); - } else { - nonFilteredAggregationFunctions.add(inputPair.getLeft()); - } - } - List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>(); - // Convert to array since FilteredAggregationOperator expects it - for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) { - List<AggregationFunction> aggregationFunctionList = pair.getLeft(); - if (aggregationFunctionList == null) { - throw new IllegalStateException("Null aggregation list seen"); - } - aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new AggregationFunction[0]), pair.getRight())); - } - aggToTransformOpList.add( - Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), 
mainTransformOperator)); + Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = + AggregationFunctionUtils.buildFilterOperator(_indexSegment, _queryContext); + TransformOperator transformOperator = + AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates(_indexSegment, _queryContext, + filterOperatorPair.getRight(), null); + List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = + AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext, + filterOperatorPair.getRight(), transformOperator, null); return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, numTotalDocs); } - /** - * Build a filter operator from the given FilterContext. - * - * It returns the FilterPlanNode to allow reusing plan level components such as predicate - * evaluator map - */ - private Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(FilterContext filterContext) { - FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext, filterContext); - return Pair.of(filterPlanNode, filterPlanNode.run()); - } - - private TransformOperator buildTransformOperatorForFilteredAggregates(BaseFilterOperator filterOperator) { - AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); - Set<ExpressionContext> expressionsToTransform = - AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, null); - - return new TransformPlanNode(_indexSegment, _queryContext, expressionsToTransform, - DocIdSetPlanNode.MAX_DOC_PER_CALL, filterOperator).run(); - } - /** * Processing workhorse for non filtered aggregates. Note that this code path is invoked only * if the query has no filtered aggregates at all. 
If a query has mixed aggregates, filtered diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java index 2b5da7896b..99fdec9746 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java @@ -21,8 +21,12 @@ package org.apache.pinot.core.plan; import java.util.List; import java.util.Map; import java.util.Set; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock; import org.apache.pinot.core.operator.filter.BaseFilterOperator; +import org.apache.pinot.core.operator.query.FilteredGroupByOperator; import org.apache.pinot.core.operator.query.GroupByOperator; import org.apache.pinot.core.operator.transform.TransformOperator; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; @@ -50,10 +54,34 @@ public class GroupByPlanNode implements PlanNode { } @Override - public GroupByOperator run() { + public Operator<GroupByResultsBlock> run() { assert _queryContext.getAggregationFunctions() != null; assert _queryContext.getGroupByExpressions() != null; + if (_queryContext.hasFilteredAggregations()) { + return buildFilteredGroupByPlan(); + } + return buildNonFilteredGroupByPlan(); + } + + private FilteredGroupByOperator buildFilteredGroupByPlan() { + int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs(); + // Build the operator chain for the main predicate so the filter plan can be run only one time + Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = + AggregationFunctionUtils.buildFilterOperator(_indexSegment, _queryContext); + ExpressionContext[] groupByExpressions = _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]); + 
TransformOperator transformOperator = + AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates(_indexSegment, _queryContext, + filterOperatorPair.getRight(), groupByExpressions); + + List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = + AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext, + filterOperatorPair.getRight(), transformOperator, groupByExpressions); + return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, + _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]), numTotalDocs, _queryContext); + } + + private GroupByOperator buildNonFilteredGroupByPlan() { int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs(); AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); ExpressionContext[] groupByExpressions = _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java index 8ef21fa1b4..0dcecb046d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.core.query.aggregation.function; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -26,13 +27,23 @@ import java.util.List; import java.util.Map; import java.util.Set; import javax.annotation.Nullable; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.datatable.DataTable; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; import 
org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.common.ObjectSerDeUtils; import org.apache.pinot.core.operator.blocks.TransformBlock; +import org.apache.pinot.core.operator.filter.BaseFilterOperator; +import org.apache.pinot.core.operator.filter.CombinedFilterOperator; +import org.apache.pinot.core.operator.transform.TransformOperator; +import org.apache.pinot.core.plan.DocIdSetPlanNode; +import org.apache.pinot.core.plan.FilterPlanNode; +import org.apache.pinot.core.plan.TransformPlanNode; +import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.segment.spi.IndexSegment; import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair; @@ -165,4 +176,87 @@ public class AggregationFunctionUtils { throw new IllegalStateException("Illegal column data type in final result: " + columnDataType); } } + + /** + * Build a filter operator from the given FilterContext. 
+ * + * It returns the FilterPlanNode to allow reusing plan level components such as predicate + * evaluator map + */ + public static Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(IndexSegment indexSegment, + QueryContext queryContext, FilterContext filterContext) { + FilterPlanNode filterPlanNode = new FilterPlanNode(indexSegment, queryContext, filterContext); + return Pair.of(filterPlanNode, filterPlanNode.run()); + } + + public static Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(IndexSegment indexSegment, + QueryContext queryContext) { + return buildFilterOperator(indexSegment, queryContext, queryContext.getFilter()); + } + + public static TransformOperator buildTransformOperatorForFilteredAggregates(IndexSegment indexSegment, + QueryContext queryContext, BaseFilterOperator filterOperator, @Nullable ExpressionContext[] groupByExpressions) { + AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); + assert aggregationFunctions != null; + Set<ExpressionContext> expressionsToTransform = + collectExpressionsToTransform(aggregationFunctions, groupByExpressions); + return new TransformPlanNode(indexSegment, queryContext, expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL, + filterOperator).run(); + } + + /** + * Build pairs of filtered aggregation functions and corresponding transform operator + * @param mainPredicateFilterOperator Filter operator corresponding to the main predicate + * @param mainTransformOperator Transform operator corresponding to the main predicate + */ + public static List<Pair<AggregationFunction[], TransformOperator>> buildFilteredAggTransformPairs( + IndexSegment indexSegment, QueryContext queryContext, BaseFilterOperator mainPredicateFilterOperator, + TransformOperator mainTransformOperator, @Nullable ExpressionContext[] groupByExpressions) { + Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>(); + 
List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>(); + List<Pair<AggregationFunction, FilterContext>> aggregationFunctions = + queryContext.getFilteredAggregationFunctions(); + List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>(); + + // For each aggregation function, check if the aggregation function is a filtered agg. + // If it is, populate the corresponding filter operator and corresponding transform operator + assert aggregationFunctions != null; + for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) { + if (inputPair.getLeft() != null) { + FilterContext currentFilterExpression = inputPair.getRight(); + if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) { + filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft()); + continue; + } + Pair<FilterPlanNode, BaseFilterOperator> filterPlanOpPair = + buildFilterOperator(indexSegment, queryContext, currentFilterExpression); + BaseFilterOperator wrappedFilterOperator = + new CombinedFilterOperator(mainPredicateFilterOperator, filterPlanOpPair.getRight(), + queryContext.getQueryOptions()); + TransformOperator newTransformOperator = + buildTransformOperatorForFilteredAggregates(indexSegment, queryContext, wrappedFilterOperator, + groupByExpressions); + // For each transform operator, associate it with the underlying expression. 
This allows + // fetching the relevant TransformOperator when resolving blocks during aggregation + // execution + List<AggregationFunction> aggFunctionList = new ArrayList<>(); + aggFunctionList.add(inputPair.getLeft()); + filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator)); + } else { + nonFilteredAggregationFunctions.add(inputPair.getLeft()); + } + } + // Convert to array since FilteredGroupByOperator expects it + for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) { + List<AggregationFunction> aggregationFunctionList = pair.getLeft(); + if (aggregationFunctionList == null) { + throw new IllegalStateException("Null aggregation list seen"); + } + aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new AggregationFunction[0]), pair.getRight())); + } + aggToTransformOpList.add( + Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), mainTransformOperator)); + + return aggToTransformOpList; + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java index e0af94070c..38ebd3706c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java @@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.groupby; import java.util.Collection; import java.util.Map; +import javax.annotation.Nullable; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.data.table.IntermediateRecord; @@ -58,16 +59,28 @@ public class DefaultGroupByExecutor implements GroupByExecutor { protected final int[] _svGroupKeys; protected final int[][] 
_mvGroupKeys; + public DefaultGroupByExecutor(QueryContext queryContext, ExpressionContext[] groupByExpressions, + TransformOperator transformOperator) { + this(queryContext, queryContext.getAggregationFunctions(), groupByExpressions, transformOperator, null); + } + + public DefaultGroupByExecutor(QueryContext queryContext, AggregationFunction[] aggregationFunctions, + ExpressionContext[] groupByExpressions, TransformOperator transformOperator) { + this(queryContext, aggregationFunctions, groupByExpressions, transformOperator, null); + } + /** * Constructor for the class. * * @param queryContext Query context + * @param aggregationFunctions Aggregation functions * @param groupByExpressions Array of group-by expressions * @param transformOperator Transform operator */ - public DefaultGroupByExecutor(QueryContext queryContext, ExpressionContext[] groupByExpressions, - TransformOperator transformOperator) { - _aggregationFunctions = queryContext.getAggregationFunctions(); + public DefaultGroupByExecutor(QueryContext queryContext, AggregationFunction[] aggregationFunctions, + ExpressionContext[] groupByExpressions, TransformOperator transformOperator, + @Nullable GroupKeyGenerator groupKeyGenerator) { + _aggregationFunctions = aggregationFunctions; assert _aggregationFunctions != null; _nullHandlingEnabled = queryContext.isNullHandlingEnabled(); @@ -83,19 +96,23 @@ public class DefaultGroupByExecutor implements GroupByExecutor { // Initialize group key generator int numGroupsLimit = queryContext.getNumGroupsLimit(); int maxInitialResultHolderCapacity = queryContext.getMaxInitialResultHolderCapacity(); - if (hasNoDictionaryGroupByExpression || _nullHandlingEnabled) { - if (groupByExpressions.length == 1) { - // TODO(nhejazi): support MV and dictionary based when null handling is enabled. 
- _groupKeyGenerator = - new NoDictionarySingleColumnGroupKeyGenerator(transformOperator, groupByExpressions[0], numGroupsLimit, - _nullHandlingEnabled); + if (groupKeyGenerator != null) { + _groupKeyGenerator = groupKeyGenerator; + } else { + if (hasNoDictionaryGroupByExpression || _nullHandlingEnabled) { + if (groupByExpressions.length == 1) { + // TODO(nhejazi): support MV and dictionary based when null handling is enabled. + _groupKeyGenerator = + new NoDictionarySingleColumnGroupKeyGenerator(transformOperator, groupByExpressions[0], numGroupsLimit, + _nullHandlingEnabled); + } else { + _groupKeyGenerator = + new NoDictionaryMultiColumnGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit); + } } else { - _groupKeyGenerator = - new NoDictionaryMultiColumnGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit); + _groupKeyGenerator = new DictionaryBasedGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit, + maxInitialResultHolderCapacity); } - } else { - _groupKeyGenerator = new DictionaryBasedGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit, - maxInitialResultHolderCapacity); } // Initialize result holders @@ -141,7 +158,6 @@ public class DefaultGroupByExecutor implements GroupByExecutor { AggregationFunction aggregationFunction = _aggregationFunctions[functionIndex]; Map<ExpressionContext, BlockValSet> blockValSetMap = AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, transformBlock); - GroupByResultHolder groupByResultHolder = _groupByResultHolders[functionIndex]; if (_hasMVGroupByExpression) { aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys, groupByResultHolder, blockValSetMap); @@ -164,4 +180,14 @@ public class DefaultGroupByExecutor implements GroupByExecutor { public Collection<IntermediateRecord> trimGroupByResult(int trimSize, TableResizer tableResizer) { return tableResizer.trimInSegmentResults(_groupKeyGenerator, _groupByResultHolders, trimSize); } + 
+ @Override + public GroupKeyGenerator getGroupKeyGenerator() { + return _groupKeyGenerator; + } + + @Override + public GroupByResultHolder[] getGroupByResultHolders() { + return _groupByResultHolders; + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java index 869ef5dbe9..db5ff16b18 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java @@ -58,4 +58,8 @@ public interface GroupByExecutor { * */ Collection<IntermediateRecord> trimGroupByResult(int trimSize, TableResizer tableResizer); + + GroupKeyGenerator getGroupKeyGenerator(); + + GroupByResultHolder[] getGroupByResultHolders(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index 5c1bd2fe84..fcc97dd6fd 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -260,7 +260,7 @@ public class QueryContext { /** * Returns the filtered aggregation expressions for the query. 
*/ - public boolean isHasFilteredAggregations() { + public boolean hasFilteredAggregations() { return _hasFilteredAggregations; } @@ -536,10 +536,6 @@ public class QueryContext { FunctionContext aggregation = pair.getLeft(); FilterContext filter = pair.getRight(); if (filter != null) { - // Filtered aggregation - if (_groupByExpressions != null) { - throw new IllegalStateException("GROUP BY with FILTER clauses is not supported"); - } queryContext._hasFilteredAggregations = true; } int functionIndex = filteredAggregationFunctions.size(); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java index 62236f3a4b..dba3faefe6 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java @@ -29,11 +29,11 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.tuple.Pair; +import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.data.table.Record; import org.apache.pinot.core.data.table.Table; import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock; import org.apache.pinot.core.operator.combine.GroupByCombineOperator; -import org.apache.pinot.core.operator.query.GroupByOperator; import org.apache.pinot.core.plan.GroupByPlanNode; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils; @@ -50,13 +50,12 @@ import org.apache.pinot.spi.data.readers.GenericRow; import org.apache.pinot.spi.utils.CommonConstants; import org.apache.pinot.spi.utils.ReadMode; import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.testng.Assert; import 
org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import static org.testng.Assert.assertEquals; - /** * Unit test for GroupBy Trim functionalities. @@ -120,7 +119,7 @@ public class GroupByTrimTest { queryContext.setMinServerGroupTrimSize(minServerGroupTrimSize); // Create a query operator - GroupByOperator groupByOperator = new GroupByPlanNode(_indexSegment, queryContext).run(); + Operator<GroupByResultsBlock> groupByOperator = new GroupByPlanNode(_indexSegment, queryContext).run(); GroupByCombineOperator combineOperator = new GroupByCombineOperator(Collections.singletonList(groupByOperator), queryContext, _executorService); @@ -130,7 +129,7 @@ public class GroupByTrimTest { // Extract the execution result List<Pair<Double, Double>> extractedResult = extractTestResult(resultsBlock.getTable()); - assertEquals(extractedResult, expectedResult); + Assert.assertEquals(extractedResult, expectedResult); } /** diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java index 2fc9ad1fa6..9d772abc3f 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java @@ -202,10 +202,10 @@ public class FilteredAggregationsTest extends BaseQueriesTest { nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(STRING_COL, 'abc')"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM " - + "MyTable"; - nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(REVERSE(STRING_COL), " - + "'abc')"; + filterQuery = + "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND STARTSWITH(REVERSE(STRING_COL), 
'abc')) FROM " + "MyTable"; + nonFilterQuery = + "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(REVERSE(STRING_COL), " + "'abc')"; testQuery(filterQuery, nonFilterQuery); } @@ -335,10 +335,49 @@ public class FilteredAggregationsTest extends BaseQueriesTest { testQuery(filterQuery, nonFilterQuery); } - @Test(expectedExceptions = IllegalStateException.class) - public void testGroupBySupport() { - String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 2), MAX(INT_COL) FILTER(WHERE INT_COL > 2) " - + "FROM MyTable WHERE INT_COL < 1000 GROUP BY INT_COL"; - getBrokerResponse(filterQuery); + @Test + public void testGroupBy() { + String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000) FROM MyTable GROUP BY BOOLEAN_COL"; + String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL"; + testQuery(filterQuery, nonFilterQuery); + } + + @Test + public void testGroupByCaseAlternative() { + String filterQuery = + "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE INT_COL > 25000) AS total_sum FROM MyTable GROUP BY " + + "BOOLEAN_COL"; + String nonFilterQuery = + "SELECT SUM(INT_COL), SUM(CASE WHEN INT_COL > 25000 THEN INT_COL ELSE 0 END) AS total_sum FROM MyTable GROUP " + + "BY BOOLEAN_COL"; + testQuery(filterQuery, nonFilterQuery); + } + + @Test + public void testGroupBySameFilter() { + String filterQuery = + "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000), SUM(INT_COL) FILTER(WHERE INT_COL > 25000) FROM MyTable " + + "GROUP BY BOOLEAN_COL"; + String nonFilterQuery = "SELECT AVG(INT_COL), SUM(INT_COL) FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL"; + testQuery(filterQuery, nonFilterQuery); + } + + @Test + public void testMultipleAggregationsOnSameFilterGroupBy() { + String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990), " + + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable GROUP BY BOOLEAN_COL"; + String nonFilterQuery = "SELECT 
MIN(INT_COL), MAX(INT_COL) FROM MyTable WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL"; + testQuery(filterQuery, nonFilterQuery); + + filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS total_min, " + + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) AS total_max, " + + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_sum, " + + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_max2 FROM MyTable GROUP BY BOOLEAN_COL"; + nonFilterQuery = "SELECT MIN(CASE WHEN (NO_INDEX_COL > 29990) THEN INT_COL ELSE 99999 END) AS total_min, " + + "MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 END) AS total_max, " + + "SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL ELSE 0 END) AS total_sum, " + + "MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL ELSE 0 END) AS total_max2 FROM MyTable GROUP BY " + + "BOOLEAN_COL"; + testQuery(filterQuery, nonFilterQuery); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org