This is an automated email from the ASF dual-hosted git repository.

atri pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 6a82102 FILTER Clauses for Aggregates (#7916) 6a82102 is described below commit 6a8210284a3de617d87083759f3b641e006dd980 Author: Atri Sharma <atri.j...@gmail.com> AuthorDate: Mon Jan 31 14:24:01 2022 +0530 FILTER Clauses for Aggregates (#7916) This PR implements support for FILTER clauses in aggregations: SELECT SUM(COL1) FILTER(WHERE COL2 > 300), AVG(COL2) FILTER (WHERE COL2 < 50) FROM MyTable WHERE COL1 > 50; The approach implements the swim lane design highlighted in the design document by splitting at the filter operator. The implementation gets the filter block for main predicate and each filter predicate, ANDs them together and returns a combined filter operator. The main predicate is scanned only once and reused for all filter clauses. The implementation allows each filter swim lane to use any available indices independently. If two or more filter clauses have the same predicate, the result will be computed only once and fed to each of the aggregates. https://docs.google.com/document/d/1ZM-2c0jJkbeJ61m8sJF0qj19t5UYLhnTFvIAz-HCJmk/edit?usp=sharing Performance benchmark: 3 warm up iterations per run, 5 runs in total. Data set size -- 1.5 million documents. Apple M1 Pro, 32GB RAM X axis represents number of iterations and Y axis represents latency in MS. FILTER query, compared to its equivalent CASE query, is 120-140% faster on average. --- .../pinot/core/operator/blocks/FilterBlock.java | 13 +- .../core/operator/docidsets/BitmapDocIdSet.java | 12 +- .../operator/docidsets/FilterBlockDocIdSet.java | 40 ++ ...pDocIdSet.java => RangelessBitmapDocIdSet.java} | 20 +- .../operator/filter/CombinedFilterOperator.java | 66 ++++ .../query/FilteredAggregationOperator.java | 116 ++++++ .../pinot/core/plan/AggregationPlanNode.java | 98 +++++ .../org/apache/pinot/core/plan/FilterPlanNode.java | 10 +- .../core/query/reduce/PostAggregationHandler.java | 13 + .../core/query/request/context/QueryContext.java | 117 ++++-- .../BrokerRequestToQueryContextConverter.java | 2 + .../BrokerRequestToQueryContextConverterTest.java | 2 +- .../pinot/core/startree/v2/BaseStarTreeV2Test.java | 31 +- .../pinot/queries/FilteredAggregationsTest.java | 429 +++++++++++++++++++++ ...erSegmentAggregationSingleValueQueriesTest.java | 47 +++ ...terSegmentAggregationMultiValueQueriesTest.java | 11 + .../org/apache/pinot/queries/QueriesTestUtils.java | 12 + .../pinot/perf/BenchmarkFilteredAggregations.java | 197 ++++++++++ 18 files changed, 1175 insertions(+), 61 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/FilterBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/FilterBlock.java index 1f87255..8afdd8c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/FilterBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/FilterBlock.java @@ -30,14 +30,25 @@ import org.apache.pinot.core.operator.docidsets.FilterBlockDocIdSet; */ public class FilterBlock implements Block { private final FilterBlockDocIdSet _filterBlockDocIdSet; + private FilterBlockDocIdSet _nonScanFilterBlockDocIdSet; public FilterBlock(FilterBlockDocIdSet filterBlockDocIdSet) { _filterBlockDocIdSet = filterBlockDocIdSet; } + /** + * Pre-scans the documents if needed, and returns a non-scan-based FilterBlockDocIdSet. 
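+ * The result is cached in _nonScanFilterBlockDocIdSet, so the underlying documents are scanned at most once + * even when several filter swim lanes reuse the same main-predicate FilterBlock.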
+ */ + public FilterBlockDocIdSet getNonScanFilterBLockDocIdSet() { + if (_nonScanFilterBlockDocIdSet == null) { + _nonScanFilterBlockDocIdSet = _filterBlockDocIdSet.toNonScanDocIdSet(); + } + return _nonScanFilterBlockDocIdSet; + } + @Override public FilterBlockDocIdSet getBlockDocIdSet() { - return _filterBlockDocIdSet; + return _nonScanFilterBlockDocIdSet != null ? _nonScanFilterBlockDocIdSet : _filterBlockDocIdSet; } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java index eacd4e3..a69ac00 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java @@ -23,17 +23,19 @@ import org.roaringbitmap.buffer.ImmutableRoaringBitmap; public class BitmapDocIdSet implements FilterBlockDocIdSet { - private final ImmutableRoaringBitmap _docIds; - private final int _numDocs; + private final BitmapDocIdIterator _iterator; public BitmapDocIdSet(ImmutableRoaringBitmap docIds, int numDocs) { - _docIds = docIds; - _numDocs = numDocs; + _iterator = new BitmapDocIdIterator(docIds, numDocs); + } + + public BitmapDocIdSet(BitmapDocIdIterator iterator) { + _iterator = iterator; } @Override public BitmapDocIdIterator iterator() { - return new BitmapDocIdIterator(_docIds, _numDocs); + return _iterator; } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/FilterBlockDocIdSet.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/FilterBlockDocIdSet.java index 92b6ac7..18dab2b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/FilterBlockDocIdSet.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/FilterBlockDocIdSet.java @@ -18,7 +18,16 @@ */ package org.apache.pinot.core.operator.docidsets; +import org.apache.pinot.core.common.BlockDocIdIterator; import org.apache.pinot.core.common.BlockDocIdSet; +import org.apache.pinot.core.operator.dociditerators.AndDocIdIterator; +import org.apache.pinot.core.operator.dociditerators.BitmapDocIdIterator; +import org.apache.pinot.core.operator.dociditerators.OrDocIdIterator; +import org.apache.pinot.core.operator.dociditerators.RangelessBitmapDocIdIterator; +import org.apache.pinot.core.operator.dociditerators.ScanBasedDocIdIterator; +import org.apache.pinot.segment.spi.Constants; +import org.roaringbitmap.RoaringBitmapWriter; +import org.roaringbitmap.buffer.MutableRoaringBitmap; /** @@ -32,4 +41,35 @@ public interface FilterBlockDocIdSet extends BlockDocIdSet { * filtering phase. This method should be called after the filtering is done. */ long getNumEntriesScannedInFilter(); + + /** + * For scan-based FilterBlockDocIdSet, pre-scans the documents and returns a non-scan-based FilterBlockDocIdSet. 
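+ * Combined with the caching in FilterBlock, this lets the main predicate be evaluated once and its result + * be fed, as a bitmap, to every filtered aggregation swim lane.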
+ */ + default FilterBlockDocIdSet toNonScanDocIdSet() { + BlockDocIdIterator docIdIterator = iterator(); + + // NOTE: AND and OR DocIdIterator might contain scan-based DocIdIterator + // TODO: This scan is not counted in the execution stats + if (docIdIterator instanceof ScanBasedDocIdIterator || docIdIterator instanceof AndDocIdIterator + || docIdIterator instanceof OrDocIdIterator) { + RoaringBitmapWriter<MutableRoaringBitmap> bitmapWriter = + RoaringBitmapWriter.bufferWriter().runCompress(false).get(); + int docId; + while ((docId = docIdIterator.next()) != Constants.EOF) { + bitmapWriter.add(docId); + } + return new RangelessBitmapDocIdSet(bitmapWriter.get()); + } + + // NOTE: AND and OR DocIdSet might return BitmapBasedDocIdIterator after processing the iterators. Create a new + // DocIdSet to prevent processing the iterators again + if (docIdIterator instanceof RangelessBitmapDocIdIterator) { + return new RangelessBitmapDocIdSet((RangelessBitmapDocIdIterator) docIdIterator); + } + if (docIdIterator instanceof BitmapDocIdIterator) { + return new BitmapDocIdSet((BitmapDocIdIterator) docIdIterator); + } + + return this; + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/RangelessBitmapDocIdSet.java similarity index 66% copy from pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java copy to pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/RangelessBitmapDocIdSet.java index eacd4e3..463a2df 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/BitmapDocIdSet.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/docidsets/RangelessBitmapDocIdSet.java @@ -18,22 +18,24 @@ */ package org.apache.pinot.core.operator.docidsets; -import org.apache.pinot.core.operator.dociditerators.BitmapDocIdIterator; +import org.apache.pinot.core.operator.dociditerators.RangelessBitmapDocIdIterator; import org.roaringbitmap.buffer.ImmutableRoaringBitmap; -public class BitmapDocIdSet implements FilterBlockDocIdSet { - private final ImmutableRoaringBitmap _docIds; - private final int _numDocs; +public class RangelessBitmapDocIdSet implements FilterBlockDocIdSet { + private final RangelessBitmapDocIdIterator _iterator; - public BitmapDocIdSet(ImmutableRoaringBitmap docIds, int numDocs) { - _docIds = docIds; - _numDocs = numDocs; + public RangelessBitmapDocIdSet(ImmutableRoaringBitmap docIds) { + _iterator = new RangelessBitmapDocIdIterator(docIds); + } + + public RangelessBitmapDocIdSet(RangelessBitmapDocIdIterator iterator) { + _iterator = iterator; } @Override - public BitmapDocIdIterator iterator() { - return new BitmapDocIdIterator(_docIds, _numDocs); + public RangelessBitmapDocIdIterator iterator() { + return _iterator; } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/CombinedFilterOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/CombinedFilterOperator.java new file mode 100644 index 0000000..54c26dc --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/CombinedFilterOperator.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.filter; + +import java.util.Arrays; +import java.util.List; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.blocks.FilterBlock; +import org.apache.pinot.core.operator.docidsets.AndDocIdSet; +import org.apache.pinot.core.operator.docidsets.FilterBlockDocIdSet; + + +/** + * A combined filter operator consisting of one main filter operator and one sub filter operator. The result block is + * the AND result of the main and sub filter. + */ +public class CombinedFilterOperator extends BaseFilterOperator { + private static final String OPERATOR_NAME = "CombinedFilterOperator"; + private static final String EXPLAIN_NAME = "FILTER_COMBINED"; + + private final BaseFilterOperator _mainFilterOperator; + private final BaseFilterOperator _subFilterOperator; + + public CombinedFilterOperator(BaseFilterOperator mainFilterOperator, BaseFilterOperator subFilterOperator) { + _mainFilterOperator = mainFilterOperator; + _subFilterOperator = subFilterOperator; + } + + @Override + public String getOperatorName() { + return OPERATOR_NAME; + } + + @Override + public List<Operator> getChildOperators() { + return Arrays.asList(_mainFilterOperator, _subFilterOperator); + } + + @Override + public String toExplainString() { + return EXPLAIN_NAME; + } + + @Override + protected FilterBlock getNextBlock() { + FilterBlockDocIdSet mainFilterDocIdSet = _mainFilterOperator.nextBlock().getNonScanFilterBLockDocIdSet(); + FilterBlockDocIdSet subFilterDocIdSet = _subFilterOperator.nextBlock().getBlockDocIdSet(); + return new FilterBlock(new AndDocIdSet(Arrays.asList(mainFilterDocIdSet, subFilterDocIdSet))); + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java new file mode 100644 index 0000000..de9c380 --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java @@ -0,0 +1,116 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.operator.query; + +import java.util.Arrays; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.BaseOperator; +import org.apache.pinot.core.operator.ExecutionStatistics; +import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock; +import org.apache.pinot.core.operator.blocks.TransformBlock; +import org.apache.pinot.core.operator.transform.TransformOperator; +import org.apache.pinot.core.query.aggregation.AggregationExecutor; +import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; + + +/** + * This operator processes a collection of filtered (and potentially non filtered) aggregations. + * + * For a query with either all aggregations being filtered or a mix of filtered and non filtered aggregations, + * FilteredAggregationOperator will come into execution. + */ +@SuppressWarnings("rawtypes") +public class FilteredAggregationOperator extends BaseOperator<IntermediateResultsBlock> { + private static final String OPERATOR_NAME = "FilteredAggregationOperator"; + private static final String EXPLAIN_NAME = "AGGREGATE_FILTERED"; + + private final AggregationFunction[] _aggregationFunctions; + private final List<Pair<AggregationFunction[], TransformOperator>> _aggFunctionsWithTransformOperator; + private final long _numTotalDocs; + + private long _numDocsScanned; + private long _numEntriesScannedInFilter; + private long _numEntriesScannedPostFilter; + + // We can potentially do away with aggregationFunctions parameter, but its cleaner to pass it in than to construct + // it from aggFunctionsWithTransformOperator + public FilteredAggregationOperator(AggregationFunction[] aggregationFunctions, + List<Pair<AggregationFunction[], TransformOperator>> aggFunctionsWithTransformOperator, long numTotalDocs) { + _aggregationFunctions = aggregationFunctions; + _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator; + _numTotalDocs = numTotalDocs; + } + + @Override + protected IntermediateResultsBlock getNextBlock() { + int numAggregations = _aggregationFunctions.length; + Object[] result = new Object[numAggregations]; + IdentityHashMap<AggregationFunction, Integer> resultIndexMap = new IdentityHashMap<>(numAggregations); + for (int i = 0; i < numAggregations; i++) { + resultIndexMap.put(_aggregationFunctions[i], i); + } + + for (Pair<AggregationFunction[], TransformOperator> filteredAggregation : _aggFunctionsWithTransformOperator) { + AggregationFunction[] aggregationFunctions = filteredAggregation.getLeft(); + AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions); + TransformOperator transformOperator = filteredAggregation.getRight(); + TransformBlock transformBlock; + int numDocsScanned = 0; + while ((transformBlock = transformOperator.nextBlock()) != null) { + aggregationExecutor.aggregate(transformBlock); + numDocsScanned += transformBlock.getNumDocs(); + } + List<Object> filteredResult = aggregationExecutor.getResult(); + + for (int i = 0; i < aggregationFunctions.length; i++) { + result[resultIndexMap.get(aggregationFunctions[i])] = filteredResult.get(i); + } + _numDocsScanned += numDocsScanned; + _numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter(); + _numEntriesScannedPostFilter 
+= (long) numDocsScanned * transformOperator.getNumColumnsProjected(); + } + return new IntermediateResultsBlock(_aggregationFunctions, Arrays.asList(result), false); + } + + @Override + public String getOperatorName() { + return OPERATOR_NAME; + } + + @Override + public List<Operator> getChildOperators() { + return _aggFunctionsWithTransformOperator.stream().map(Pair::getRight).collect(Collectors.toList()); + } + + @Override + public ExecutionStatistics getExecutionStatistics() { + return new ExecutionStatistics(_numDocsScanned, _numEntriesScannedInFilter, _numEntriesScannedPostFilter, + _numTotalDocs); + } + + @Override + public String toExplainString() { + return EXPLAIN_NAME; + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java index d7318d9..94c57f0 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java @@ -18,17 +18,23 @@ */ package org.apache.pinot.core.plan; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.BaseOperator; import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock; import org.apache.pinot.core.operator.filter.BaseFilterOperator; +import org.apache.pinot.core.operator.filter.CombinedFilterOperator; import org.apache.pinot.core.operator.query.AggregationOperator; import org.apache.pinot.core.operator.query.DictionaryBasedAggregationOperator; +import org.apache.pinot.core.operator.query.FilteredAggregationOperator; import org.apache.pinot.core.operator.query.MetadataBasedAggregationOperator; import org.apache.pinot.core.operator.transform.TransformOperator; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; @@ -61,6 +67,98 @@ public class AggregationPlanNode implements PlanNode { @Override public Operator<IntermediateResultsBlock> run() { assert _queryContext.getAggregationFunctions() != null; + return _queryContext.isHasFilteredAggregations() ? buildFilteredAggOperator() : buildNonFilteredAggOperator(); + } + + /** + * Build the operator to be used for filtered aggregations + */ + private BaseOperator<IntermediateResultsBlock> buildFilteredAggOperator() { + int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs(); + // Build the operator chain for the main predicate + Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = buildFilterOperator(_queryContext.getFilter()); + TransformOperator transformOperator = buildTransformOperatorForFilteredAggregates(filterOperatorPair.getRight()); + + return buildFilterOperatorInternal(filterOperatorPair.getRight(), transformOperator, numTotalDocs); + } + + /** + * Build a FilteredAggregationOperator given the parameters. 
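+ * Aggregation functions that share an identical FilterContext are grouped together, so each distinct filter + * predicate is planned and evaluated only once; non-filtered aggregations run against the main transform operator.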
+ * @param mainPredicateFilterOperator Filter operator corresponding to the main predicate + * @param mainTransformOperator Transform operator corresponding to the main predicate + * @param numTotalDocs Number of total docs + */ + private BaseOperator<IntermediateResultsBlock> buildFilterOperatorInternal( + BaseFilterOperator mainPredicateFilterOperator, TransformOperator mainTransformOperator, int numTotalDocs) { + Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>(); + List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>(); + List<Pair<AggregationFunction, FilterContext>> aggregationFunctions = _queryContext.getFilteredAggregations(); + + // For each aggregation function, check if the aggregation function is a filtered agg. + // If it is, populate the corresponding filter operator and corresponding transform operator + for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) { + if (inputPair.getLeft() != null) { + FilterContext currentFilterExpression = inputPair.getRight(); + if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) { + filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft()); + continue; + } + Pair<FilterPlanNode, BaseFilterOperator> pair = buildFilterOperator(currentFilterExpression); + BaseFilterOperator wrappedFilterOperator = + new CombinedFilterOperator(mainPredicateFilterOperator, pair.getRight()); + TransformOperator newTransformOperator = buildTransformOperatorForFilteredAggregates(wrappedFilterOperator); + // For each transform operator, associate it with the underlying expression. This allows + // fetching the relevant TransformOperator when resolving blocks during aggregation + // execution + List<AggregationFunction> aggFunctionList = new ArrayList<>(); + aggFunctionList.add(inputPair.getLeft()); + filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator)); + } else { + nonFilteredAggregationFunctions.add(inputPair.getLeft()); + } + } + List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>(); + // Convert to array since FilteredAggregationOperator expects it + for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) { + List<AggregationFunction> aggregationFunctionList = pair.getLeft(); + if (aggregationFunctionList == null) { + throw new IllegalStateException("Null aggregation list seen"); + } + aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new AggregationFunction[0]), pair.getRight())); + } + aggToTransformOpList.add( + Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), mainTransformOperator)); + + return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, numTotalDocs); + } + + /** + * Build a filter operator from the given FilterContext. 
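+ * The given context is either the main WHERE clause filter or the predicate of a single FILTER(WHERE ...) clause.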
+ * + * It returns the FilterPlanNode to allow reusing plan level components such as predicate + * evaluator map + */ + private Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(FilterContext filterContext) { + FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext, filterContext); + return Pair.of(filterPlanNode, filterPlanNode.run()); + } + + private TransformOperator buildTransformOperatorForFilteredAggregates(BaseFilterOperator filterOperator) { + AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); + Set<ExpressionContext> expressionsToTransform = + AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, null); + + return new TransformPlanNode(_indexSegment, _queryContext, expressionsToTransform, + DocIdSetPlanNode.MAX_DOC_PER_CALL, filterOperator).run(); + } + + /** + * Processing workhorse for non filtered aggregates. Note that this code path is invoked only + * if the query has no filtered aggregates at all. If a query has mixed aggregates, filtered + * aggregates code will be invoked + */ + public Operator<IntermediateResultsBlock> buildNonFilteredAggOperator() { + assert _queryContext.getAggregationFunctions() != null; int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs(); AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java index ce48367..014a35d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java @@ -58,6 +58,7 @@ public class FilterPlanNode implements PlanNode { private final IndexSegment _indexSegment; private final QueryContext _queryContext; private final int _numDocs; + private FilterContext _filterContext; // Cache the predicate evaluators private final Map<Predicate, PredicateEvaluator> _predicateEvaluatorMap = new HashMap<>(); @@ -70,9 +71,16 @@ public class FilterPlanNode implements PlanNode { _numDocs = _indexSegment.getSegmentMetadata().getTotalDocs(); } + public FilterPlanNode(IndexSegment indexSegment, QueryContext queryContext, + FilterContext filterContext) { + this(indexSegment, queryContext); + + _filterContext = filterContext; + } + @Override public BaseFilterOperator run() { - FilterContext filter = _queryContext.getFilter(); + FilterContext filter = _filterContext == null ? 
_queryContext.getFilter() : _filterContext; ThreadSafeMutableRoaringBitmap validDocIds = _indexSegment.getValidDocIds(); boolean applyValidDocIds = validDocIds != null && !QueryOptionsUtils.isSkipUpsert(_queryContext.getQueryOptions()); if (filter != null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java index 4705951..cde2ca8 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java @@ -22,8 +22,11 @@ import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.request.context.FunctionContext; +import org.apache.pinot.common.request.context.RequestContextUtils; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.query.postaggregation.PostAggregationFunction; @@ -37,6 +40,7 @@ import org.apache.pinot.core.util.GapfillUtils; */ public class PostAggregationHandler { private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap; + private final Map<Pair<FunctionContext, FilterContext>, Integer> _filteredAggregationsIndexMap; private final int _numGroupByExpressions; private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap; private final DataSchema _dataSchema; @@ -45,6 +49,7 @@ public class PostAggregationHandler { public PostAggregationHandler(QueryContext queryContext, DataSchema dataSchema) { _aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap(); + _filteredAggregationsIndexMap = queryContext.getFilteredAggregationsIndexMap(); assert _aggregationFunctionIndexMap != null; List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions(); if (groupByExpressions != null) { @@ -117,6 +122,14 @@ public class PostAggregationHandler { if (function.getType() == FunctionContext.Type.AGGREGATION) { // Aggregation function return new ColumnValueExtractor(_aggregationFunctionIndexMap.get(function) + _numGroupByExpressions); + } else if (function.getType() == FunctionContext.Type.TRANSFORM + && function.getFunctionName().equalsIgnoreCase("filter")) { + ExpressionContext filterExpression = function.getArguments().get(1); + FilterContext filter = RequestContextUtils.getFilter(filterExpression); + FunctionContext filterFunction = function.getArguments().get(0).getFunction(); + + return new ColumnValueExtractor(_filteredAggregationsIndexMap + .get(Pair.of(filterFunction, filter))); } else { // Post-aggregation function return new PostAggregationValueExtractor(function); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index e608208..f839901 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -90,9 +90,11 @@ public class QueryContext { // Pre-calculate the aggregation functions and columns for the query so that it can be shared across all the 
segments private AggregationFunction[] _aggregationFunctions; - private List<Pair<AggregationFunction, FilterContext>> _filteredAggregationFunctions; - // TODO: Use Pair<FunctionContext, FilterContext> as key to support filtered aggregations in order-by and post - // aggregation + + private List<Pair<AggregationFunction, FilterContext>> _filteredAggregations; + + private boolean _hasFilteredAggregations; + private Map<Pair<FunctionContext, FilterContext>, Integer> _filteredAggregationsIndexMap; private Map<FunctionContext, Integer> _aggregationFunctionIndexMap; private Set<String> _columns; @@ -232,11 +234,18 @@ public class QueryContext { } /** - * Returns the filtered aggregation expressions for the query. + * Returns the filtered aggregations for a query */ @Nullable - public List<Pair<AggregationFunction, FilterContext>> getFilteredAggregationFunctions() { - return _filteredAggregationFunctions; + public List<Pair<AggregationFunction, FilterContext>> getFilteredAggregations() { + return _filteredAggregations; + } + + /** + * Returns the filtered aggregation expressions for the query. + */ + public boolean isHasFilteredAggregations() { + return _hasFilteredAggregations; } /** @@ -249,6 +258,16 @@ public class QueryContext { } /** + * Returns a map from the expression of a filtered aggregation to the index of corresponding AggregationFunction + * in the aggregation functions array + * @return + */ + @Nullable + public Map<Pair<FunctionContext, FilterContext>, Integer> getFilteredAggregationsIndexMap() { + return _filteredAggregationsIndexMap; + } + + /** * Returns the columns (IDENTIFIER expressions) in the query. */ public Set<String> getColumns() { @@ -441,37 +460,51 @@ public class QueryContext { */ private void generateAggregationFunctions(QueryContext queryContext) { List<AggregationFunction> aggregationFunctions = new ArrayList<>(); - List<Pair<AggregationFunction, FilterContext>> filteredAggregationFunctions = new ArrayList<>(); + List<Pair<AggregationFunction, FilterContext>> filteredAggregations = new ArrayList<>(); Map<FunctionContext, Integer> aggregationFunctionIndexMap = new HashMap<>(); + Map<Pair<FunctionContext, FilterContext>, Integer> filterExpressionIndexMap = new HashMap<>(); // Add aggregation functions in the SELECT clause // NOTE: DO NOT deduplicate the aggregation functions in the SELECT clause because that involves protocol change. 
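+ // Each pair below holds an optional FilterContext (left, null for plain aggregations) and the aggregation + // FunctionContext (right).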
- List<FunctionContext> aggregationsInSelect = new ArrayList<>(); - List<Pair<FunctionContext, FilterContext>> filteredAggregations = new ArrayList<>(); + List<Pair<FilterContext, FunctionContext>> aggregationsInSelect = new ArrayList<>(); for (ExpressionContext selectExpression : queryContext._selectExpressions) { - getAggregations(selectExpression, aggregationsInSelect, filteredAggregations); + getAggregations(selectExpression, aggregationsInSelect); } - for (FunctionContext function : aggregationsInSelect) { - int functionIndex = aggregationFunctions.size(); + for (Pair<FilterContext, FunctionContext> pair : aggregationsInSelect) { + FunctionContext function = pair.getRight(); + int functionIndex = filteredAggregations.size(); AggregationFunction aggregationFunction = AggregationFunctionFactory.getAggregationFunction(function, queryContext); - aggregationFunctions.add(aggregationFunction); + + FilterContext filterContext = null; + // If the left pair is not null, implies a filtered aggregation + if (pair.getLeft() != null) { + if (_groupByExpressions != null) { + throw new IllegalStateException("GROUP BY with FILTER clauses is not supported"); + } + queryContext._hasFilteredAggregations = true; + filterContext = pair.getLeft(); + Pair<FunctionContext, FilterContext> filterContextPair = + Pair.of(function, filterContext); + if (!filterExpressionIndexMap.containsKey(filterContextPair)) { + int filterMapIndex = filterExpressionIndexMap.size(); + filterExpressionIndexMap.put(filterContextPair, filterMapIndex); + } + } + filteredAggregations.add(Pair.of(aggregationFunction, filterContext)); aggregationFunctionIndexMap.put(function, functionIndex); } - for (Pair<FunctionContext, FilterContext> pair : filteredAggregations) { - AggregationFunction aggregationFunction = - aggregationFunctions.get(aggregationFunctionIndexMap.get(pair.getLeft())); - filteredAggregationFunctions.add(Pair.of(aggregationFunction, pair.getRight())); - } // Add aggregation functions in the HAVING clause but not in the SELECT clause if (queryContext._havingFilter != null) { - List<FunctionContext> aggregationsInHaving = new ArrayList<>(); + List<Pair<FilterContext, FunctionContext>> aggregationsInHaving = new ArrayList<>(); getAggregations(queryContext._havingFilter, aggregationsInHaving); - for (FunctionContext function : aggregationsInHaving) { + for (Pair<FilterContext, FunctionContext> pair : aggregationsInHaving) { + FunctionContext function = pair.getRight(); if (!aggregationFunctionIndexMap.containsKey(function)) { - int functionIndex = aggregationFunctions.size(); - aggregationFunctions.add(AggregationFunctionFactory.getAggregationFunction(function, queryContext)); + int functionIndex = filteredAggregations.size(); + filteredAggregations.add(Pair.of( + AggregationFunctionFactory.getAggregationFunction(function, queryContext), null)); aggregationFunctionIndexMap.put(function, functionIndex); } } @@ -479,38 +512,47 @@ public class QueryContext { // Add aggregation functions in the ORDER-BY clause but not in the SELECT or HAVING clause if (queryContext._orderByExpressions != null) { - List<FunctionContext> aggregationsInOrderBy = new ArrayList<>(); + List<Pair<FilterContext, FunctionContext>> aggregationsInOrderBy = new ArrayList<>(); for (OrderByExpressionContext orderByExpression : queryContext._orderByExpressions) { - getAggregations(orderByExpression.getExpression(), aggregationsInOrderBy, null); + getAggregations(orderByExpression.getExpression(), aggregationsInOrderBy); } - for (FunctionContext function : 
aggregationsInOrderBy) { + for (Pair<FilterContext, FunctionContext> pair : aggregationsInOrderBy) { + FunctionContext function = pair.getRight(); if (!aggregationFunctionIndexMap.containsKey(function)) { - int functionIndex = aggregationFunctions.size(); - aggregationFunctions.add(AggregationFunctionFactory.getAggregationFunction(function, queryContext)); + int functionIndex = filteredAggregations.size(); + filteredAggregations.add(Pair.of( + AggregationFunctionFactory.getAggregationFunction(function, queryContext), null)); aggregationFunctionIndexMap.put(function, functionIndex); } } } - if (!aggregationFunctions.isEmpty()) { + if (!filteredAggregations.isEmpty()) { + for (Pair<AggregationFunction, FilterContext> pair : filteredAggregations) { + aggregationFunctions.add(pair.getLeft()); + } + queryContext._aggregationFunctions = aggregationFunctions.toArray(new AggregationFunction[0]); - queryContext._filteredAggregationFunctions = filteredAggregationFunctions; + queryContext._filteredAggregations = filteredAggregations; queryContext._aggregationFunctionIndexMap = aggregationFunctionIndexMap; + queryContext._filteredAggregationsIndexMap = filterExpressionIndexMap; } } /** * Helper method to extract AGGREGATION FunctionContexts from the given expression. + * + * NOTE: The left pair of aggregations should be set only for filtered aggregations */ - private static void getAggregations(ExpressionContext expression, List<FunctionContext> aggregations, - List<Pair<FunctionContext, FilterContext>> filteredAggregations) { + private static void getAggregations(ExpressionContext expression, + List<Pair<FilterContext, FunctionContext>> aggregations) { FunctionContext function = expression.getFunction(); if (function == null) { return; } if (function.getType() == FunctionContext.Type.AGGREGATION) { // Aggregation - aggregations.add(function); + aggregations.add(Pair.of(null, function)); } else { List<ExpressionContext> arguments = function.getArguments(); if (function.getFunctionName().equalsIgnoreCase("filter")) { @@ -524,12 +566,12 @@ public class QueryContext { && filterExpression.getFunction().getType() == FunctionContext.Type.TRANSFORM, "Second argument of FILTER must be a filter expression"); FilterContext filter = RequestContextUtils.getFilter(filterExpression); - aggregations.add(aggregation); - filteredAggregations.add(Pair.of(aggregation, filter)); + + aggregations.add(Pair.of(filter, aggregation)); } else { // Transform for (ExpressionContext argument : arguments) { - getAggregations(argument, aggregations, filteredAggregations); + getAggregations(argument, aggregations); } } } @@ -538,14 +580,15 @@ public class QueryContext { /** * Helper method to extract AGGREGATION FunctionContexts from the given filter. 
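* This overload recurses into a filter tree (e.g. the HAVING clause), collecting aggregations from each leaf predicate's left-hand side.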
*/ - private static void getAggregations(FilterContext filter, List<FunctionContext> aggregations) { + private static void getAggregations(FilterContext filter, + List<Pair<FilterContext, FunctionContext>> aggregations) { List<FilterContext> children = filter.getChildren(); if (children != null) { for (FilterContext child : children) { getAggregations(child, aggregations); } } else { - getAggregations(filter.getPredicate().getLhs(), aggregations, null); + getAggregations(filter.getPredicate().getLhs(), aggregations); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverter.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverter.java index 35ac4e2..70df3db 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverter.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverter.java @@ -102,9 +102,11 @@ public class BrokerRequestToQueryContextConverter { // WHERE FilterContext filter = null; + ExpressionContext filterExpressionContext = null; Expression filterExpression = pinotQuery.getFilterExpression(); if (filterExpression != null) { filter = RequestContextUtils.getFilter(pinotQuery.getFilterExpression()); + filterExpressionContext = RequestContextUtils.getExpression(filterExpression); } // GROUP BY diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java index 1f690c1..b666a09 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java @@ -582,7 +582,7 @@ public class BrokerRequestToQueryContextConverterTest { String query = "SELECT COUNT(*) FILTER(WHERE foo > 5), COUNT(*) FILTER(WHERE foo < 6) FROM testTable WHERE bar > 0"; QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromSQL(query); List<Pair<AggregationFunction, FilterContext>> filteredAggregationList = - queryContext.getFilteredAggregationFunctions(); + queryContext.getFilteredAggregations(); assertNotNull(filteredAggregationList); assertEquals(filteredAggregationList.size(), 2); assertTrue(filteredAggregationList.get(0).getLeft() instanceof CountAggregationFunction); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java index 8cb2e7a..c28a885 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java @@ -103,6 +103,7 @@ abstract class BaseStarTreeV2Test<R, A> { " WHERE (d2 > 95 OR d2 < 25) AND (d1 > 10 OR d1 < 50)"; private static final String QUERY_FILTER_COMPLEX_OR_SINGLE_DIMENSION = " WHERE d1 = 95 AND (d1 > 90 OR d1 < 100)"; private static final String QUERY_GROUP_BY = " GROUP BY d2"; + private static final String FILTER_AGG_CLAUSE = " FILTER(WHERE d1 > 10)"; private ValueAggregator _valueAggregator; private DataType _aggregatedValueType; @@ -163,8 +164,15 @@ abstract class BaseStarTreeV2Test<R, A> { @Test 
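// Runs the full query suite twice: once with plain aggregations and once with a FILTER clause appended to each aggregation.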
public void testQueries() throws IOException { + testQueriesHelper(false); + testQueriesHelper(true); + } + + private void testQueriesHelper(boolean runFilteredAggregate) + throws IOException { AggregationFunctionType aggregationType = _valueAggregator.getAggregationType(); String aggregation; + String filteredAggregation = null; if (aggregationType == AggregationFunctionType.COUNT) { aggregation = "COUNT(*)"; } else if (aggregationType == AggregationFunctionType.PERCENTILEEST @@ -176,13 +184,22 @@ abstract class BaseStarTreeV2Test<R, A> { } String baseQuery = String.format("SELECT %s FROM %s", aggregation, TABLE_NAME); - testQuery(baseQuery); - testQuery(baseQuery + QUERY_FILTER_AND); - testQuery(baseQuery + QUERY_FILTER_OR); - testQuery(baseQuery + QUERY_FILTER_COMPLEX_OR_MULTIPLE_DIMENSIONS); - testQuery(baseQuery + QUERY_FILTER_COMPLEX_AND_MULTIPLE_DIMENSIONS_THREE_PREDICATES); - testQuery(baseQuery + QUERY_FILTER_COMPLEX_OR_MULTIPLE_DIMENSIONS_THREE_PREDICATES); - testQuery(baseQuery + QUERY_FILTER_COMPLEX_OR_SINGLE_DIMENSION); + String filteredQuery; + + if (runFilteredAggregate) { + filteredAggregation = aggregation + FILTER_AGG_CLAUSE; + filteredQuery = String.format("SELECT %s FROM %s", filteredAggregation, TABLE_NAME); + } else { + filteredQuery = baseQuery; + } + + testQuery(filteredQuery); + testQuery(filteredQuery + QUERY_FILTER_AND); + testQuery(filteredQuery + QUERY_FILTER_OR); + testQuery(filteredQuery + QUERY_FILTER_COMPLEX_OR_MULTIPLE_DIMENSIONS); + testQuery(filteredQuery + QUERY_FILTER_COMPLEX_AND_MULTIPLE_DIMENSIONS_THREE_PREDICATES); + testQuery(filteredQuery + QUERY_FILTER_COMPLEX_OR_MULTIPLE_DIMENSIONS_THREE_PREDICATES); + testQuery(filteredQuery + QUERY_FILTER_COMPLEX_OR_SINGLE_DIMENSION); testQuery(baseQuery + QUERY_GROUP_BY); testQuery(baseQuery + QUERY_FILTER_AND + QUERY_GROUP_BY); testQuery(baseQuery + QUERY_FILTER_OR + QUERY_GROUP_BY); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java new file mode 100644 index 0000000..54c11cb --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java @@ -0,0 +1,429 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.queries; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.commons.io.FileUtils; +import org.apache.pinot.common.response.broker.BrokerResponseNative; +import org.apache.pinot.common.response.broker.ResultTable; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader; +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl; +import org.apache.pinot.segment.local.segment.index.loader.IndexLoadingConfig; +import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader; +import org.apache.pinot.segment.spi.ImmutableSegment; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig; +import org.apache.pinot.spi.config.table.FieldConfig; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.data.readers.RecordReader; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + + +public class FilteredAggregationsTest extends BaseQueriesTest { + private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "FilteredAggregationsTest"); + private static final String TABLE_NAME = "MyTable"; + private static final String FIRST_SEGMENT_NAME = "firstTestSegment"; + private static final String SECOND_SEGMENT_NAME = "secondTestSegment"; + private static final String INT_COL_NAME = "INT_COL"; + private static final String NO_INDEX_INT_COL_NAME = "NO_INDEX_COL"; + private static final String STATIC_INT_COL_NAME = "STATIC_INT_COL"; + private static final Integer INT_BASE_VALUE = 0; + private static final Integer NUM_ROWS = 30000; + + + private IndexSegment _indexSegment; + private List<IndexSegment> _indexSegments; + + @Override + protected String getFilter() { + return ""; + } + + @Override + protected IndexSegment getIndexSegment() { + return _indexSegment; + } + + @Override + protected List<IndexSegment> getIndexSegments() { + return _indexSegments; + } + + @BeforeClass + public void setUp() + throws Exception { + FileUtils.deleteQuietly(INDEX_DIR); + + buildSegment(FIRST_SEGMENT_NAME); + buildSegment(SECOND_SEGMENT_NAME); + IndexLoadingConfig indexLoadingConfig = new IndexLoadingConfig(); + + Set<String> invertedIndexCols = new HashSet<>(); + invertedIndexCols.add(INT_COL_NAME); + + indexLoadingConfig.setInvertedIndexColumns(invertedIndexCols); + ImmutableSegment firstImmutableSegment = + ImmutableSegmentLoader.load(new File(INDEX_DIR, FIRST_SEGMENT_NAME), indexLoadingConfig); + ImmutableSegment secondImmutableSegment = + ImmutableSegmentLoader.load(new File(INDEX_DIR, SECOND_SEGMENT_NAME), indexLoadingConfig); + _indexSegment = firstImmutableSegment; + _indexSegments = Arrays.asList(firstImmutableSegment, secondImmutableSegment); + } + + @AfterClass + public void tearDown() { + _indexSegment.destroy(); + FileUtils.deleteQuietly(INDEX_DIR); + } + + private List<GenericRow> createTestData(int numRows) { + List<GenericRow> rows = new ArrayList<>(); + + for (int i = 0; i < numRows; 
i++) { + GenericRow row = new GenericRow(); + row.putField(INT_COL_NAME, INT_BASE_VALUE + i); + row.putField(NO_INDEX_INT_COL_NAME, i); + row.putField(STATIC_INT_COL_NAME, 10); + + rows.add(row); + } + return rows; + } + + private void buildSegment(String segmentName) + throws Exception { + List<GenericRow> rows = createTestData(NUM_ROWS); + List<FieldConfig> fieldConfigs = new ArrayList<>(); + + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(TABLE_NAME) + .setInvertedIndexColumns(Arrays.asList(INT_COL_NAME)).setFieldConfigList(fieldConfigs).build(); + Schema schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME) + .addSingleValueDimension(NO_INDEX_INT_COL_NAME, FieldSpec.DataType.INT) + .addSingleValueDimension(STATIC_INT_COL_NAME, FieldSpec.DataType.INT) + .addMetric(INT_COL_NAME, FieldSpec.DataType.INT).build(); + SegmentGeneratorConfig config = new SegmentGeneratorConfig(tableConfig, schema); + config.setOutDir(INDEX_DIR.getPath()); + config.setTableName(TABLE_NAME); + config.setSegmentName(segmentName); + + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + try (RecordReader recordReader = new GenericRowRecordReader(rows)) { + driver.init(config, recordReader); + driver.build(); + } + } + + private void testInterSegmentAggregationQueryHelper(String firstQuery, String secondQuery) { + // SQL + BrokerResponseNative firstBrokerResponseNative = getBrokerResponseForSqlQuery(firstQuery); + BrokerResponseNative secondBrokerResponseNative = getBrokerResponseForSqlQuery(secondQuery); + ResultTable firstResultTable = firstBrokerResponseNative.getResultTable(); + ResultTable secondResultTable = secondBrokerResponseNative.getResultTable(); + DataSchema firstDataSchema = firstResultTable.getDataSchema(); + DataSchema secondDataSchema = secondResultTable.getDataSchema(); + + Assert.assertEquals(firstDataSchema.size(), secondDataSchema.size()); + + List<Object[]> firstSetOfRows = firstResultTable.getRows(); + List<Object[]> secondSetOfRows = secondResultTable.getRows(); + + Assert.assertEquals(firstSetOfRows.size(), secondSetOfRows.size()); + + for (int i = 0; i < firstSetOfRows.size(); i++) { + Object[] firstSetRow = firstSetOfRows.get(i); + Object[] secondSetRow = secondSetOfRows.get(i); + + Assert.assertEquals(firstSetRow.length, secondSetRow.length); + + for (int j = 0; j < firstSetRow.length; j++) { + //System.out.println("FIRST " + firstSetRow[j] + " SECOND " + secondSetRow[j] + " j " + j); + Assert.assertEquals(firstSetRow[j], secondSetRow[j]); + } + } + } + + @Test + public void testInterSegment() { + + String query = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999)" + + "FROM MyTable WHERE INT_COL < 1000000"; + + String nonFilterQuery = + "SELECT SUM(INT_COL)" + + "FROM MyTable WHERE INT_COL > 9999 AND INT_COL < 1000000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 1234 AND INT_COL < 22000)" + + "FROM MyTable"; + + nonFilterQuery = "SELECT SUM(CASE " + + "WHEN (INT_COL > 1234 AND INT_COL < 22000) THEN INT_COL ELSE 0 " + + "END) AS total_sum FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3)" + + "FROM MyTable WHERE INT_COL > 1"; + nonFilterQuery = + "SELECT SUM(INT_COL)" + + "FROM MyTable WHERE INT_COL > 1 AND INT_COL < 3"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT COUNT(*) FILTER(WHERE INT_COL = 4)" + + 
"FROM MyTable"; + nonFilterQuery = + "SELECT COUNT(*)" + + "FROM MyTable WHERE INT_COL = 4"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000)" + + "FROM MyTable "; + + nonFilterQuery = + "SELECT SUM(INT_COL)" + + "FROM MyTable WHERE INT_COL > 8000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1)" + + "FROM MyTable WHERE INT_COL > 1"; + + nonFilterQuery = + "SELECT SUM(INT_COL)" + + "FROM MyTable WHERE NO_INDEX_COL <= 1 AND INT_COL > 1"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT AVG(NO_INDEX_COL)" + + "FROM MyTable WHERE NO_INDEX_COL > -1"; + nonFilterQuery = + "SELECT AVG(NO_INDEX_COL)" + + "FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL % 10 = 0),SUM(NO_INDEX_COL),MAX(INT_COL) " + + "FROM MyTable"; + + nonFilterQuery = + "SELECT SUM(CASE WHEN (INT_COL % 10 = 0) THEN INT_COL ELSE 0 END) AS total_sum,SUM(NO_INDEX_COL)," + + "MAX(INT_COL) FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) FROM MyTable"; + nonFilterQuery = + "SELECT AVG(NO_INDEX_COL) FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL % 10 = 0),MAX(NO_INDEX_COL) FROM MyTable"; + + nonFilterQuery = + "SELECT SUM(CASE WHEN (INT_COL % 10 = 0) THEN INT_COL ELSE 0 END) AS total_sum,MAX(NO_INDEX_COL)" + + "FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL % 10 = 0),MAX(NO_INDEX_COL)" + + "FROM MyTable WHERE NO_INDEX_COL > 5"; + nonFilterQuery = + "SELECT SUM(CASE WHEN (INT_COL % 10 = 0) THEN INT_COL ELSE 0 END) AS total_sum,MAX(NO_INDEX_COL)" + + "FROM MyTable WHERE NO_INDEX_COL > 5"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT MAX(INT_COL) FILTER(WHERE INT_COL < 100) FROM MyTable"; + + nonFilterQuery = + "SELECT MAX(CASE WHEN (INT_COL < 100) THEN INT_COL ELSE 0 END) AS total_max FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT MIN(NO_INDEX_COL) FILTER(WHERE INT_COL < 100) FROM MyTable"; + + nonFilterQuery = + "SELECT MIN(CASE WHEN (INT_COL < 100) THEN NO_INDEX_COL ELSE 0 END) AS total_min " + + "FROM MyTable"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = + "SELECT MIN(NO_INDEX_COL) FILTER(WHERE INT_COL > 29990),MAX(INT_COL) FILTER(WHERE INT_COL > 29990)" + + "FROM MyTable"; + + nonFilterQuery = + "SELECT MIN(NO_INDEX_COL), MAX(INT_COL) FROM MyTable WHERE INT_COL > 29990"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + } + + @Test + public void testCaseVsFilter() { + String query = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 3)," + + "SUM(INT_COL) FILTER(WHERE INT_COL < 4)" + + "FROM MyTable WHERE INT_COL > 2"; + + String nonFilterQuery = "SELECT SUM(CASE WHEN (INT_COL > 3) THEN INT_COL ELSE 0 " + + "END) AS total_sum,SUM(CASE WHEN (INT_COL < 4) THEN INT_COL ELSE 0 END) AS total_sum2 " + + "FROM MyTable WHERE INT_COL > 2"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 12345),SUM(INT_COL) FILTER(WHERE INT_COL < 59999)," + + 
"MIN(INT_COL) FILTER(WHERE INT_COL > 5000) FROM MyTable WHERE INT_COL > 1000"; + + nonFilterQuery = "SELECT SUM( CASE WHEN (INT_COL > 12345) THEN INT_COL ELSE 0 " + + "END) AS total_sum,SUM(CASE WHEN (INT_COL < 59999) THEN INT_COL ELSE 0 " + + "END) AS total_sum2,MIN(CASE WHEN (INT_COL > 5000) THEN INT_COL " + + "ELSE 9999999 END) AS total_min FROM MyTable WHERE INT_COL > 1000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 12345)," + + "SUM(NO_INDEX_COL) FILTER(WHERE INT_COL < 59999)," + + "MIN(INT_COL) FILTER(WHERE INT_COL > 5000) " + + "FROM MyTable WHERE INT_COL > 1000"; + + nonFilterQuery = "SELECT SUM(CASE WHEN (INT_COL > 12345) THEN INT_COL " + + "ELSE 0 END) AS total_sum,SUM(CASE WHEN (INT_COL < 59999) THEN NO_INDEX_COL " + + "ELSE 0 END) AS total_sum2,MIN(CASE WHEN (INT_COL > 5000) THEN INT_COL " + + "ELSE 9999999 END) AS total_min FROM MyTable WHERE INT_COL > 1000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 12345)," + + "SUM(NO_INDEX_COL) FILTER(WHERE INT_COL < 59999)," + + "MIN(INT_COL) FILTER(WHERE INT_COL > 5000) " + + "FROM MyTable WHERE INT_COL < 28000 AND NO_INDEX_COL > 3000 "; + + nonFilterQuery = "SELECT SUM(CASE WHEN (INT_COL > 12345) THEN INT_COL ELSE 0 " + + "END) AS total_sum,SUM(CASE WHEN (INT_COL < 59999) THEN NO_INDEX_COL " + + "ELSE 0 END) AS total_sum2,MIN(CASE WHEN (INT_COL > 5000) THEN INT_COL " + + "ELSE 9999999 END) AS total_min FROM MyTable WHERE INT_COL < 28000 AND NO_INDEX_COL > 3000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE ABS(INT_COL) > 12345)," + + "SUM(NO_INDEX_COL) FILTER(WHERE LN(INT_COL) < 59999)," + + "MIN(INT_COL) FILTER(WHERE INT_COL > 5000) " + + "FROM MyTable WHERE INT_COL < 28000 AND NO_INDEX_COL > 3000 "; + + nonFilterQuery = "SELECT SUM(" + + "CASE WHEN (ABS(INT_COL) > 12345) THEN INT_COL ELSE 0 " + + "END) AS total_sum,SUM(CASE WHEN (LN(INT_COL) < 59999) THEN NO_INDEX_COL " + + "ELSE 0 END) AS total_sum2,MIN(CASE WHEN (INT_COL > 5000) THEN INT_COL " + + "ELSE 9999999 END) AS total_min FROM MyTable WHERE INT_COL < 28000 AND NO_INDEX_COL > 3000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE MOD(INT_COL, STATIC_INT_COL) = 0)," + + "MIN(INT_COL) FILTER(WHERE INT_COL > 5000) " + + "FROM MyTable WHERE INT_COL < 28000 AND NO_INDEX_COL > 3000 "; + + nonFilterQuery = "SELECT SUM(CASE WHEN (MOD(INT_COL, STATIC_INT_COL) = 0) THEN INT_COL " + + "ELSE 0 END) AS total_sum,MIN(CASE WHEN (INT_COL > 5000) THEN INT_COL " + + "ELSE 9999999 END) AS total_min FROM MyTable WHERE INT_COL < 28000 AND NO_INDEX_COL > 3000"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + + query = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 123 AND INT_COL < 25000)," + + "MAX(INT_COL) FILTER(WHERE INT_COL > 123 AND INT_COL < 25000) " + + "FROM MyTable WHERE NO_INDEX_COL > 5 AND NO_INDEX_COL < 29999"; + + nonFilterQuery = "SELECT SUM(CASE WHEN (INT_COL > 123 AND INT_COL < 25000) THEN INT_COL " + + "ELSE 0 END) AS total_sum,MAX(CASE WHEN (INT_COL > 123 AND INT_COL < 25000) THEN INT_COL " + + "ELSE 0 END) AS total_avg FROM MyTable WHERE NO_INDEX_COL > 5 AND NO_INDEX_COL < 29999"; + + testInterSegmentAggregationQueryHelper(query, nonFilterQuery); + } + + @Test + public void testMultipleAggregationsOnSameFilter() { + String query = + "SELECT MIN(NO_INDEX_COL) FILTER(WHERE 
+
+  @Test
+  public void testMultipleAggregationsOnSameFilter() {
+    String query = "SELECT MIN(NO_INDEX_COL) FILTER(WHERE INT_COL > 29990),"
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990)"
+        + "FROM MyTable";
+
+    String nonFilterQuery = "SELECT MIN(NO_INDEX_COL), MAX(INT_COL) FROM MyTable "
+        + "WHERE INT_COL > 29990";
+
+    testInterSegmentAggregationQueryHelper(query, nonFilterQuery);
+
+    query = "SELECT MIN(NO_INDEX_COL) FILTER(WHERE INT_COL > 29990),"
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990),"
+        + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000),"
+        + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) "
+        + "FROM MyTable";
+
+    nonFilterQuery = "SELECT MIN(CASE WHEN (INT_COL > 29990) THEN NO_INDEX_COL ELSE 99999 "
+        + "END) AS total_min,MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 "
+        + "END) AS total_max,SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL "
+        + "ELSE 0 END) AS total_sum,MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL "
+        + "ELSE 0 END) AS total_count FROM MyTable";
+
+    testInterSegmentAggregationQueryHelper(query, nonFilterQuery);
+  }
+
+  @Test(expectedExceptions = IllegalStateException.class)
+  public void testGroupBySupport() {
+    String query = "SELECT MIN(NO_INDEX_COL) FILTER(WHERE INT_COL > 2),"
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 2)"
+        + "FROM MyTable WHERE INT_COL < 1000"
+        + " GROUP BY INT_COL";
+
+    String nonFilterQuery = "SELECT MIN(NO_INDEX_COL), MAX(INT_COL) FROM MyTable "
+        + "GROUP BY INT_COL";
+
+    testInterSegmentAggregationQueryHelper(query, nonFilterQuery);
+  }
+}
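Two behaviors pinned down by the tests above are worth calling out. First, testMultipleAggregationsOnSameFilter checks that several aggregates sharing one identical filter predicate produce the same answer as a single query that folds the predicate into the main WHERE. Second, testGroupBySupport asserts that combining FILTER aggregates with GROUP BY is rejected with an IllegalStateException rather than executed incorrectly; a sketch of the rejected shape, again using the fixture names:

    SELECT MIN(NO_INDEX_COL) FILTER(WHERE INT_COL > 2)
    FROM MyTable WHERE INT_COL < 1000
    GROUP BY INT_COL;  -- expected to throw IllegalStateException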
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InnerSegmentAggregationSingleValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InnerSegmentAggregationSingleValueQueriesTest.java
index cf2d3be..684672a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InnerSegmentAggregationSingleValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InnerSegmentAggregationSingleValueQueriesTest.java
@@ -26,6 +26,7 @@ import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
 import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
 import org.apache.pinot.core.operator.query.AggregationOperator;
 import org.apache.pinot.core.operator.query.DistinctOperator;
+import org.apache.pinot.core.operator.query.FilteredAggregationOperator;
 import org.apache.pinot.core.query.distinct.DistinctTable;
 import org.testng.Assert;
 import org.testng.annotations.Test;
@@ -73,6 +74,52 @@ public class InnerSegmentAggregationSingleValueQueriesTest extends BaseSingleVal
   }
 
   @Test
+  public void testFilteredAggregations() {
+    String query = "SELECT SUM(column6) FILTER(WHERE column6 > 5), COUNT(*) FILTER(WHERE column1 IS NOT NULL),"
+        + "MAX(column3) FILTER(WHERE column3 IS NOT NULL), "
+        + "SUM(column3), AVG(column7) FILTER(WHERE column7 > 0) FROM testTable WHERE column3 > 0";
+
+    FilteredAggregationOperator aggregationOperator = getOperatorForSqlQuery(query);
+    IntermediateResultsBlock resultsBlock = aggregationOperator.nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(aggregationOperator.getExecutionStatistics(),
+            180000L, 0L, 540000L, 30000L);
+    QueriesTestUtils
+        .testInnerSegmentAggregationResultForFilteredAggs(resultsBlock.getAggregationResult(), 22266008882250L,
+            30000, 2147419555,
+            2147483647, 28175373944314L, 30000L);
+
+    query = "SELECT SUM(column6) FILTER(WHERE column6 > 5), COUNT(*) FILTER(WHERE column1 IS NOT NULL),"
+        + "MAX(column3) FILTER(WHERE column3 IS NOT NULL), "
+        + "SUM(column3), AVG(column7) FILTER(WHERE column7 > 0) FROM testTable";
+
+    aggregationOperator = getOperatorForSqlQuery(query);
+    resultsBlock = aggregationOperator.nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(aggregationOperator.getExecutionStatistics(),
+            180000L, 0L, 540000L, 30000L);
+    QueriesTestUtils
+        .testInnerSegmentAggregationResultForFilteredAggs(resultsBlock.getAggregationResult(), 22266008882250L,
+            30000, 2147419555,
+            2147483647, 28175373944314L, 30000L);
+
+    query = "SELECT SUM(column6) FILTER(WHERE column6 > 5 OR column6 < 15),"
+        + "COUNT(*) FILTER(WHERE column1 IS NOT NULL),"
+        + "MAX(column3) FILTER(WHERE column3 IS NOT NULL AND column3 > 0), "
+        + "SUM(column3), AVG(column7) FILTER(WHERE column7 > 0 AND column7 < 100) FROM testTable";
+
+    aggregationOperator = getOperatorForSqlQuery(query);
+    resultsBlock = aggregationOperator.nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(aggregationOperator.getExecutionStatistics(),
+            150000L, 0L, 450000L, 30000L);
+    QueriesTestUtils
+        .testInnerSegmentAggregationResultForFilteredAggs(resultsBlock.getAggregationResult(), 22266008882250L,
+            30000, 2147419555,
+            2147483647, 0L, 0L);
+  }
+
+  @Test
   public void testSmallAggregationGroupBy() {
     String query = "SELECT" + AGGREGATION + " FROM testTable" + SMALL_GROUP_BY;
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
index 6c2b730..c354b66 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
@@ -652,6 +652,17 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
   }
 
   @Test
+  public void testFilteredAggregations() {
+    String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable WHERE column3 > 0";
+    BrokerResponseNative brokerResponse = getBrokerResponseForSqlQuery(query);
+    assertEquals(brokerResponse.getResultTable().getRows().size(), 1);
+
+    long resultValue = (long) brokerResponse.getResultTable().getRows().get(0)[0];
+
+    assertEquals(resultValue, 370236);
+  }
+
+  @Test
   public void testGroupByMVColumns() {
     String query = "SELECT COUNT(*), column7 FROM testTable GROUP BY column7 LIMIT 1000";
     BrokerResponseNative brokerResponse = getBrokerResponseForSqlQuery(query);
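The broker-level multi-value test reduces to a single COUNT cell, which makes it easy to cross-check by hand: under FILTER semantics that cell must equal the count over the conjunction of the main predicate and the filter predicate. A sketch of the equivalent plain query (fixture names; the expected value is the 370236 asserted above):

    SELECT COUNT(*) FROM testTable
    WHERE column3 > 0 AND column1 > 5;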
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
index 2d1a5d1..ac0303d 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/QueriesTestUtils.java
@@ -62,6 +62,18 @@
     Assert.assertEquals(avgResult.getCount(), expectedAvgResultCount);
   }
 
+  public static void testInnerSegmentAggregationResultForFilteredAggs(List<Object> aggregationResult,
+      long expectedFilteredSumResult, long expectedCountResult, int expectedMaxResult,
+      int expectedNonFilteredSum, long expectedAvgResultSum, long expectedAvgResultCount) {
+    Assert.assertEquals(((Number) aggregationResult.get(0)).longValue(), expectedFilteredSumResult);
+    Assert.assertEquals(((Number) aggregationResult.get(1)).longValue(), expectedCountResult);
+    Assert.assertEquals(((Number) aggregationResult.get(2)).intValue(), expectedMaxResult);
+    Assert.assertEquals(((Number) aggregationResult.get(3)).intValue(), expectedNonFilteredSum);
+    AvgPair avgResult = (AvgPair) aggregationResult.get(4);
+    Assert.assertEquals((long) avgResult.getSum(), expectedAvgResultSum);
+    Assert.assertEquals(avgResult.getCount(), expectedAvgResultCount);
+  }
+
   public static void testInnerSegmentAggregationGroupByResult(AggregationGroupByResult aggregationGroupByResult,
       String expectedGroupKey, long expectedCountResult, long expectedSumResult, int expectedMaxResult,
       int expectedMinResult, long expectedAvgResultSum, long expectedAvgResultCount) {
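The last three assertions in the new helper reflect how AVG is represented at the inner-segment stage: not as a final quotient but as an AvgPair carrying a sum and a count, so the test pins both components. In SQL terms, the two components of the filtered AVG correspond to the following sketch (fixture names):

    -- numerator and denominator of AVG(column7) FILTER(WHERE column7 > 0)
    SELECT SUM(column7), COUNT(*)
    FROM testTable WHERE column7 > 0;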
diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkFilteredAggregations.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkFilteredAggregations.java
new file mode 100644
index 0000000..8019aa2
--- /dev/null
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkFilteredAggregations.java
@@ -0,0 +1,197 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.perf;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.response.broker.BrokerResponseNative;
+import org.apache.pinot.queries.BaseQueriesTest;
+import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.index.loader.IndexLoadingConfig;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.ImmutableSegment;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.config.table.FieldConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.RecordReader;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@Fork(1)
+@Warmup(iterations = 3, time = 10)
+@Measurement(iterations = 5, time = 10)
+@State(Scope.Benchmark)
+public class BenchmarkFilteredAggregations extends BaseQueriesTest {
+
+  private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "FilteredAggregationsTest");
+  private static final String TABLE_NAME = "MyTable";
+  private static final String FIRST_SEGMENT_NAME = "firstTestSegment";
+  private static final String SECOND_SEGMENT_NAME = "secondTestSegment";
+  private static final String INT_COL_NAME = "INT_COL";
+  private static final String NO_INDEX_INT_COL_NAME = "NO_INDEX_INT_COL";
+
+  @Param("1500000")
+  private int _numRows;
+  @Param("0")
+  int _intBaseValue;
+
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  public String _filteredQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 123 AND INT_COL < 599999),"
+      + "MAX(INT_COL) FILTER(WHERE INT_COL > 123 AND INT_COL < 599999) "
+      + "FROM MyTable WHERE NO_INDEX_INT_COL > 5 AND NO_INDEX_INT_COL < 1499999";
+
+  public String _nonFilteredQuery = "SELECT SUM("
+      + "CASE "
+      + "WHEN (INT_COL > 123 AND INT_COL < 599999) THEN INT_COL "
+      + "ELSE 0 "
+      + "END) AS total_sum,"
+      + "MAX("
+      + "CASE "
+      + "WHEN (INT_COL > 123 AND INT_COL < 599999) THEN INT_COL "
+      + "ELSE 0 "
+      + "END) AS total_avg "
+      + "FROM MyTable WHERE NO_INDEX_INT_COL > 5 AND NO_INDEX_INT_COL < 1499999";
+
+  @Setup
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteQuietly(INDEX_DIR);
+
+    buildSegment(FIRST_SEGMENT_NAME);
+    buildSegment(SECOND_SEGMENT_NAME);
+    IndexLoadingConfig indexLoadingConfig = new IndexLoadingConfig();
+
+    Set<String> invertedIndexCols = new HashSet<>();
+    invertedIndexCols.add(INT_COL_NAME);
+
+    indexLoadingConfig.setRangeIndexColumns(invertedIndexCols);
+    indexLoadingConfig.setInvertedIndexColumns(invertedIndexCols);
+
+    ImmutableSegment firstImmutableSegment =
+        ImmutableSegmentLoader.load(new File(INDEX_DIR, FIRST_SEGMENT_NAME), indexLoadingConfig);
+    ImmutableSegment secondImmutableSegment =
+        ImmutableSegmentLoader.load(new File(INDEX_DIR, SECOND_SEGMENT_NAME), indexLoadingConfig);
+    _indexSegment = firstImmutableSegment;
+    _indexSegments = Arrays.asList(firstImmutableSegment, secondImmutableSegment);
+  }
+
+  @TearDown
+  public void tearDown() {
+    for (IndexSegment indexSegment : _indexSegments) {
+      indexSegment.destroy();
+    }
+
+    FileUtils.deleteQuietly(INDEX_DIR);
+  }
+
+  private List<GenericRow> createTestData(int numRows) {
+    List<GenericRow> rows = new ArrayList<>();
+
+    for (int i = 0; i < numRows; i++) {
+      GenericRow row = new GenericRow();
+      row.putField(INT_COL_NAME, _intBaseValue + i);
+      row.putField(NO_INDEX_INT_COL_NAME, _intBaseValue + i);
+
+      rows.add(row);
+    }
+    return rows;
+  }
+
+  private void buildSegment(String segmentName)
+      throws Exception {
+    List<GenericRow> rows = createTestData(_numRows);
+    List<FieldConfig> fieldConfigs = new ArrayList<>();
+
+    TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(TABLE_NAME)
+        .setInvertedIndexColumns(Arrays.asList(INT_COL_NAME)).setFieldConfigList(fieldConfigs).build();
+    Schema schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME)
+        .addSingleValueDimension(NO_INDEX_INT_COL_NAME, FieldSpec.DataType.INT)
+        .addSingleValueDimension(INT_COL_NAME, FieldSpec.DataType.INT).build();
+    SegmentGeneratorConfig config = new SegmentGeneratorConfig(tableConfig, schema);
+    config.setOutDir(INDEX_DIR.getPath());
+    config.setTableName(TABLE_NAME);
+    config.setSegmentName(segmentName);
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    try (RecordReader recordReader = new GenericRowRecordReader(rows)) {
+      driver.init(config, recordReader);
+      driver.build();
+    }
+  }
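The two benchmark methods below time semantically equivalent queries over the same pair of segments. One plausible reading of the setup: _filteredQuery expresses the predicate as FILTER clauses, which can be evaluated against the inverted and range indexes that setUp builds on INT_COL, while _nonFilteredQuery buries the same predicate inside CASE expressions that must be evaluated row by row for every document passing the main WHERE. A reduced sketch of the pair being compared (the benchmark's own table and column names):

    SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 123 AND INT_COL < 599999) FROM MyTable;
    SELECT SUM(CASE WHEN INT_COL > 123 AND INT_COL < 599999 THEN INT_COL ELSE 0 END) FROM MyTable;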
+
+  @Benchmark
+  public BrokerResponseNative testFilteredAggregations() {
+    return getBrokerResponseForSqlQuery(_filteredQuery);
+  }
+
+  @Benchmark
+  public BrokerResponseNative testNonFilteredAggregations(Blackhole blackhole) {
+    return getBrokerResponseForSqlQuery(_nonFilteredQuery);
+  }
+
+  public static void main(String[] args)
+      throws Exception {
+    new Runner(new OptionsBuilder().include(BenchmarkFilteredAggregations.class.getSimpleName()).build()).run();
+  }
+
+  @Override
+  protected String getFilter() {
+    return null;
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org