This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 303b1a7cbe [Issue 7519] Adds support for multiple filtered/unfiltered aggregations with GROUP BY (#10000)
303b1a7cbe is described below

commit 303b1a7cbe78244491f0580eb88e966a41b56b25
Author: Evan Galpin <egal...@users.noreply.github.com>
AuthorDate: Wed Jan 4 19:01:40 2023 -0800

    [Issue 7519] Adds support for multiple filtered/unfiltered aggregations with GROUP BY (#10000)
---
 .../operator/query/FilteredGroupByOperator.java    | 215 +++++++++++++++++++++
 .../pinot/core/plan/AggregationPlanNode.java       |  87 +--------
 .../apache/pinot/core/plan/GroupByPlanNode.java    |  30 ++-
 .../function/AggregationFunctionUtils.java         |  94 +++++++++
 .../groupby/DefaultGroupByExecutor.java            |  56 ++++--
 .../query/aggregation/groupby/GroupByExecutor.java |   4 +
 .../core/query/request/context/QueryContext.java   |   6 +-
 .../query/aggregation/groupby/GroupByTrimTest.java |   9 +-
 .../pinot/queries/FilteredAggregationsTest.java    |  57 +++++-
 9 files changed, 445 insertions(+), 113 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
new file mode 100644
index 0000000000..e895d817dd
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
@@ -0,0 +1,215 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.query;
+
+import java.util.Collection;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.table.IntermediateRecord;
+import org.apache.pinot.core.data.table.TableResizer;
+import org.apache.pinot.core.operator.BaseOperator;
+import org.apache.pinot.core.operator.ExecutionStatistics;
+import org.apache.pinot.core.operator.blocks.TransformBlock;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
+import org.apache.pinot.core.operator.transform.TransformOperator;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.util.GroupByUtils;
+import org.apache.pinot.spi.trace.Tracing;
+
+
+/**
+ * The <code>FilteredGroupByOperator</code> class provides the operator for group-by query on a single segment when
+ * at least one aggregation has a FILTER clause. All aggregations share a single group key generator.
+ */
+@SuppressWarnings("rawtypes")
+public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
+  private static final String EXPLAIN_NAME = "GROUP_BY_FILTERED";
+
+  private final AggregationFunction[] _aggregationFunctions;
+  private final List<Pair<AggregationFunction[], TransformOperator>> _aggFunctionsWithTransformOperator;
+  private final ExpressionContext[] _groupByExpressions;
+  private final long _numTotalDocs;
+  private long _numDocsScanned;
+  private long _numEntriesScannedInFilter;
+  private long _numEntriesScannedPostFilter;
+  private final DataSchema _dataSchema;
+  private final QueryContext _queryContext;
+
+  public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions,
+      List<Pair<AggregationFunction[], TransformOperator>> aggFunctionsWithTransformOperator,
+      ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext queryContext) {
+    _aggregationFunctions = aggregationFunctions;
+    _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator;
+    _groupByExpressions = groupByExpressions;
+    _numTotalDocs = numTotalDocs;
+    _queryContext = queryContext;
+
+    // NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
+    int numGroupByExpressions = groupByExpressions.length;
+    int numAggregationFunctions = aggregationFunctions.length;
+    int numColumns = numGroupByExpressions + numAggregationFunctions;
+    String[] columnNames = new String[numColumns];
+    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
+
+    // Extract column names and data types for group-by columns (every transform operator projects all of them)
+    for (int i = 0; i < numGroupByExpressions; i++) {
+      ExpressionContext groupByExpression = groupByExpressions[i];
+      columnNames[i] = groupByExpression.toString();
+      columnDataTypes[i] = DataSchema.ColumnDataType.fromDataTypeSV(
+          aggFunctionsWithTransformOperator.get(0).getRight().getResultMetadata(groupByExpression).getDataType());
+    }
+
+    // Extract column names and data types for aggregation functions
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      AggregationFunction aggregationFunction = aggregationFunctions[i];
+      int index = numGroupByExpressions + i;
+      columnNames[index] = aggregationFunction.getResultColumnName();
+      columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
+    }
+
+    _dataSchema = new DataSchema(columnNames, columnDataTypes);
+  }
+
+  @Override
+  protected GroupByResultsBlock getNextBlock() {
+    // TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions
+    int numAggregations = _aggregationFunctions.length;
+
+    GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations];
+    IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new IdentityHashMap<>(numAggregations);
+    for (int i = 0; i < numAggregations; i++) {
+      resultHolderIndexMap.put(_aggregationFunctions[i], i);
+    }
+
+    GroupKeyGenerator groupKeyGenerator = null;
+    for (Pair<AggregationFunction[], TransformOperator> filteredAggregation : _aggFunctionsWithTransformOperator) {
+      TransformOperator transformOperator = filteredAggregation.getRight();
+      AggregationFunction[] filteredAggFunctions = filteredAggregation.getLeft();
+
+      // Perform aggregation group-by on all the blocks
+      DefaultGroupByExecutor groupByExecutor;
+      if (groupKeyGenerator == null) {
+        // The group key generator should be shared across all AggregationFunctions so that agg results can be
+        // aligned. Given that filtered aggregations are stored as an iterable of iterables so that all filtered aggs
+        // with the same filter can share transform blocks, rather than a singular flat iterable in the case where
+        // aggs are all non-filtered, sharing a GroupKeyGenerator across all aggs cannot be accomplished by allowing
+        // the GroupByExecutor to have sole ownership of the GroupKeyGenerator. Therefore, we allow constructing a
+        // GroupByExecutor with a pre-existing GroupKeyGenerator so that the GroupKeyGenerator can be shared across
+        // loop iterations i.e. across all aggs.
+        groupByExecutor =
+            new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions, transformOperator);
+        groupKeyGenerator = groupByExecutor.getGroupKeyGenerator();
+      } else {
+        groupByExecutor =
+            new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions, transformOperator,
+                groupKeyGenerator);
+      }
+
+      int numDocsScanned = 0;
+      TransformBlock transformBlock;
+      while ((transformBlock = transformOperator.nextBlock()) != null) {
+        numDocsScanned += transformBlock.getNumDocs();
+        groupByExecutor.process(transformBlock);
+      }
+
+      _numDocsScanned += numDocsScanned;
+      _numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter();
+      _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected();
+      GroupByResultHolder[] filterGroupByResults = groupByExecutor.getGroupByResultHolders();
+      for (int i = 0; i < filteredAggFunctions.length; i++) {
+        groupByResultHolders[resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i];
+      }
+    }
+    assert groupKeyGenerator != null;
+    for (GroupByResultHolder groupByResultHolder : groupByResultHolders) {
+      groupByResultHolder.ensureCapacity(groupKeyGenerator.getNumKeys());
+    }
+
+    // Check if the groups limit is reached
+    boolean numGroupsLimitReached = groupKeyGenerator.getNumKeys() >= _queryContext.getNumGroupsLimit();
+    Tracing.activeRecording().setNumGroups(_queryContext.getNumGroupsLimit(), groupKeyGenerator.getNumKeys());
+
+    // Trim the groups iff:
+    // - Query has ORDER BY clause
+    // - Segment group trim is enabled
+    // - There are more groups than the trim size
+    // TODO: Currently the groups are not trimmed if there is no ordering specified. Consider ordering on group-by
+    //       columns if no ordering is specified.
+    int minGroupTrimSize = _queryContext.getMinSegmentGroupTrimSize();
+    if (_queryContext.getOrderByExpressions() != null && minGroupTrimSize > 0) {
+      int trimSize = GroupByUtils.getTableCapacity(_queryContext.getLimit(), minGroupTrimSize);
+      if (groupKeyGenerator.getNumKeys() > trimSize) {
+        TableResizer tableResizer = new TableResizer(_dataSchema, _queryContext);
+        Collection<IntermediateRecord> intermediateRecords =
+            tableResizer.trimInSegmentResults(groupKeyGenerator, groupByResultHolders, trimSize);
+        GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, intermediateRecords);
+        resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
+        return resultsBlock;
+      }
+    }
+
+    AggregationGroupByResult aggGroupByResult =
+        new AggregationGroupByResult(groupKeyGenerator, _aggregationFunctions, groupByResultHolders);
+    GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, aggGroupByResult);
+    resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
+    return resultsBlock;
+  }
+
+  @Override
+  public List<Operator> getChildOperators() {
+    return _aggFunctionsWithTransformOperator.stream().map(Pair::getRight).collect(Collectors.toList());
+  }
+
+  @Override
+  public ExecutionStatistics getExecutionStatistics() {
+    return new ExecutionStatistics(_numDocsScanned, _numEntriesScannedInFilter, _numEntriesScannedPostFilter,
+        _numTotalDocs);
+  }
+
+  @Override
+  public String toExplainString() {
+    StringBuilder stringBuilder = new StringBuilder(EXPLAIN_NAME).append("(groupKeys:");
+    if (_groupByExpressions.length > 0) {
+      stringBuilder.append(_groupByExpressions[0].toString());
+      for (int i = 1; i < _groupByExpressions.length; i++) {
+        stringBuilder.append(", ").append(_groupByExpressions[i].toString());
+      }
+    }
+
+    stringBuilder.append(", aggregations:");
+    if (_aggregationFunctions.length > 0) {
+      stringBuilder.append(_aggregationFunctions[0].toExplainString());
+      for (int i = 1; i < _aggregationFunctions.length; i++) {
+        stringBuilder.append(", ").append(_aggregationFunctions[i].toExplainString());
+      }
+    }
+
+    return stringBuilder.append(')').toString();
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index 58d74fb00f..148911897e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -18,19 +18,15 @@
  */
 package org.apache.pinot.core.plan;
 
-import java.util.ArrayList;
 import java.util.EnumSet;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
-import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
 import org.apache.pinot.core.operator.filter.BaseFilterOperator;
-import org.apache.pinot.core.operator.filter.CombinedFilterOperator;
 import org.apache.pinot.core.operator.query.AggregationOperator;
 import org.apache.pinot.core.operator.query.FastFilteredCountOperator;
 import org.apache.pinot.core.operator.query.FilteredAggregationOperator;
@@ -77,7 +73,7 @@ public class AggregationPlanNode implements PlanNode {
   @Override
   public Operator<AggregationResultsBlock> run() {
     assert _queryContext.getAggregationFunctions() != null;
-    return _queryContext.isHasFilteredAggregations() ? buildFilteredAggOperator() : buildNonFilteredAggOperator();
+    return _queryContext.hasFilteredAggregations() ? buildFilteredAggOperator() : buildNonFilteredAggOperator();
   }
 
   /**
@@ -86,83 +82,18 @@ public class AggregationPlanNode implements PlanNode {
   private FilteredAggregationOperator buildFilteredAggOperator() {
     int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     // Build the operator chain for the main predicate
-    Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = buildFilterOperator(_queryContext.getFilter());
-    TransformOperator transformOperator = buildTransformOperatorForFilteredAggregates(filterOperatorPair.getRight());
-
-    return buildFilterOperatorInternal(filterOperatorPair.getRight(), transformOperator, numTotalDocs);
-  }
-
-  /**
-   * Build a FilteredAggregationOperator given the parameters.
-   * @param mainPredicateFilterOperator Filter operator corresponding to the main predicate
-   * @param mainTransformOperator Transform operator corresponding to the main predicate
-   * @param numTotalDocs Number of total docs
-   */
-  private FilteredAggregationOperator buildFilterOperatorInternal(BaseFilterOperator mainPredicateFilterOperator,
-      TransformOperator mainTransformOperator, int numTotalDocs) {
-    Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>();
-    List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>();
-    List<Pair<AggregationFunction, FilterContext>> aggregationFunctions =
-        _queryContext.getFilteredAggregationFunctions();
-
-    // For each aggregation function, check if the aggregation function is a filtered agg.
-    // If it is, populate the corresponding filter operator and corresponding transform operator
-    for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) {
-      if (inputPair.getLeft() != null) {
-        FilterContext currentFilterExpression = inputPair.getRight();
-        if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) {
-          filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft());
-          continue;
-        }
-        Pair<FilterPlanNode, BaseFilterOperator> pair = buildFilterOperator(currentFilterExpression);
-        BaseFilterOperator wrappedFilterOperator =
-            new CombinedFilterOperator(mainPredicateFilterOperator, pair.getRight(), _queryContext.getQueryOptions());
-        TransformOperator newTransformOperator = buildTransformOperatorForFilteredAggregates(wrappedFilterOperator);
-        // For each transform operator, associate it with the underlying expression. This allows
-        // fetching the relevant TransformOperator when resolving blocks during aggregation
-        // execution
-        List<AggregationFunction> aggFunctionList = new ArrayList<>();
-        aggFunctionList.add(inputPair.getLeft());
-        filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator));
-      } else {
-        nonFilteredAggregationFunctions.add(inputPair.getLeft());
-      }
-    }
-    List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>();
-    // Convert to array since FilteredAggregationOperator expects it
-    for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) {
-      List<AggregationFunction> aggregationFunctionList = pair.getLeft();
-      if (aggregationFunctionList == null) {
-        throw new IllegalStateException("Null aggregation list seen");
-      }
-      aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new AggregationFunction[0]), pair.getRight()));
-    }
-    aggToTransformOpList.add(
-        Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), mainTransformOperator));
+    Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair =
+        AggregationFunctionUtils.buildFilterOperator(_indexSegment, _queryContext);
+    TransformOperator transformOperator =
+        AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates(_indexSegment, _queryContext,
+            filterOperatorPair.getRight(), null);
 
+    List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
+        AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext,
+            filterOperatorPair.getRight(), transformOperator, null);
     return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, numTotalDocs);
   }
 
-  /**
-   * Build a filter operator from the given FilterContext.
-   *
-   * It returns the FilterPlanNode to allow reusing plan level components such as predicate
-   * evaluator map
-   */
-  private Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(FilterContext filterContext) {
-    FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext, filterContext);
-    return Pair.of(filterPlanNode, filterPlanNode.run());
-  }
-
-  private TransformOperator buildTransformOperatorForFilteredAggregates(BaseFilterOperator filterOperator) {
-    AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions();
-    Set<ExpressionContext> expressionsToTransform =
-        AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, null);
-
-    return new TransformPlanNode(_indexSegment, _queryContext, expressionsToTransform,
-        DocIdSetPlanNode.MAX_DOC_PER_CALL, filterOperator).run();
-  }
-
   /**
    * Processing workhorse for non filtered aggregates. Note that this code path is invoked only
    * if the query has no filtered aggregates at all. If a query has mixed aggregates, filtered
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
index 2b5da7896b..99fdec9746 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
@@ -21,8 +21,12 @@ package org.apache.pinot.core.plan;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
 import org.apache.pinot.core.operator.filter.BaseFilterOperator;
+import org.apache.pinot.core.operator.query.FilteredGroupByOperator;
 import org.apache.pinot.core.operator.query.GroupByOperator;
 import org.apache.pinot.core.operator.transform.TransformOperator;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
@@ -50,10 +54,34 @@ public class GroupByPlanNode implements PlanNode {
   }
 
   @Override
-  public GroupByOperator run() {
+  public Operator<GroupByResultsBlock> run() {
     assert _queryContext.getAggregationFunctions() != null;
     assert _queryContext.getGroupByExpressions() != null;
 
+    if (_queryContext.hasFilteredAggregations()) {
+      return buildFilteredGroupByPlan();
+    }
+    return buildNonFilteredGroupByPlan();
+  }
+
+  private FilteredGroupByOperator buildFilteredGroupByPlan() {
+    int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
+    // Build the operator chain for the main predicate so the filter plan can be run only one time
+    Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair =
+        AggregationFunctionUtils.buildFilterOperator(_indexSegment, _queryContext);
+    ExpressionContext[] groupByExpressions = _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
+    TransformOperator transformOperator =
+        AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates(_indexSegment, _queryContext,
+            filterOperatorPair.getRight(), groupByExpressions);
+
+    List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
+        AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext,
+            filterOperatorPair.getRight(), transformOperator, groupByExpressions);
+    return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList,
+        groupByExpressions, numTotalDocs, _queryContext);
+  }
+
+  private GroupByOperator buildNonFilteredGroupByPlan() {
     int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions();
     ExpressionContext[] groupByExpressions = _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 8ef21fa1b4..0dcecb046d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -26,13 +27,23 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.common.ObjectSerDeUtils;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
+import org.apache.pinot.core.operator.filter.BaseFilterOperator;
+import org.apache.pinot.core.operator.filter.CombinedFilterOperator;
+import org.apache.pinot.core.operator.transform.TransformOperator;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.core.plan.FilterPlanNode;
+import org.apache.pinot.core.plan.TransformPlanNode;
+import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
 
 
@@ -165,4 +176,87 @@ public class AggregationFunctionUtils {
         throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
     }
   }
+
+  /**
+   * Build a filter operator from the given FilterContext.
+   *
+   * It returns the FilterPlanNode to allow reusing plan level components such as predicate
+   * evaluator map
+   */
+  public static Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(IndexSegment indexSegment,
+      QueryContext queryContext, FilterContext filterContext) {
+    FilterPlanNode filterPlanNode = new FilterPlanNode(indexSegment, queryContext, filterContext);
+    return Pair.of(filterPlanNode, filterPlanNode.run());
+  }
+
+  public static Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(IndexSegment indexSegment,
+      QueryContext queryContext) {
+    return buildFilterOperator(indexSegment, queryContext, queryContext.getFilter());
+  }
+
+  public static TransformOperator buildTransformOperatorForFilteredAggregates(IndexSegment indexSegment,
+      QueryContext queryContext, BaseFilterOperator filterOperator, @Nullable ExpressionContext[] groupByExpressions) {
+    AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions();
+    assert aggregationFunctions != null;
+    Set<ExpressionContext> expressionsToTransform =
+        collectExpressionsToTransform(aggregationFunctions, groupByExpressions);
+    return new TransformPlanNode(indexSegment, queryContext, expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL,
+        filterOperator).run();
+  }
+
+  /**
+   * Build (aggs, transform operator) pairs: one per distinct FILTER clause, then one (possibly empty) for the rest
+   * @param mainPredicateFilterOperator Main predicate filter, combined with each FILTER clause
+   * @param mainTransformOperator Transform operator for the main predicate, paired with the non-filtered aggs
+   */
+  public static List<Pair<AggregationFunction[], TransformOperator>> buildFilteredAggTransformPairs(
+      IndexSegment indexSegment, QueryContext queryContext, BaseFilterOperator mainPredicateFilterOperator,
+      TransformOperator mainTransformOperator, @Nullable ExpressionContext[] groupByExpressions) {
+    Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>();
+    List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>();
+    List<Pair<AggregationFunction, FilterContext>> aggregationFunctions =
+        queryContext.getFilteredAggregationFunctions();
+    List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>();
+
+    // For each aggregation function, check if it is a filtered agg, i.e. its FilterContext (right) is non-null.
+    // If it is, populate the corresponding filter operator and corresponding transform operator
+    assert aggregationFunctions != null;
+    for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) {
+      if (inputPair.getRight() != null) {
+        FilterContext currentFilterExpression = inputPair.getRight();
+        if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) {
+          filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft());
+          continue;
+        }
+        Pair<FilterPlanNode, BaseFilterOperator> filterPlanOpPair =
+            buildFilterOperator(indexSegment, queryContext, currentFilterExpression);
+        BaseFilterOperator wrappedFilterOperator =
+            new CombinedFilterOperator(mainPredicateFilterOperator, filterPlanOpPair.getRight(),
+                queryContext.getQueryOptions());
+        TransformOperator newTransformOperator =
+            buildTransformOperatorForFilteredAggregates(indexSegment, queryContext, wrappedFilterOperator,
+                groupByExpressions);
+        // For each transform operator, associate it with the underlying expression. This allows
+        // fetching the relevant TransformOperator when resolving blocks during aggregation
+        // execution
+        List<AggregationFunction> aggFunctionList = new ArrayList<>();
+        aggFunctionList.add(inputPair.getLeft());
+        filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator));
+      } else {
+        nonFilteredAggregationFunctions.add(inputPair.getLeft());
+      }
+    }
+    // Convert to arrays since FilteredAggregationOperator and FilteredGroupByOperator expect them
+    for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) {
+      List<AggregationFunction> aggregationFunctionList = pair.getLeft();
+      if (aggregationFunctionList == null) {
+        throw new IllegalStateException("Null aggregation list seen");
+      }
+      aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new AggregationFunction[0]), pair.getRight()));
+    }
+    aggToTransformOpList.add(
+        Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), mainTransformOperator));
+
+    return aggToTransformOpList;
+  }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
index e0af94070c..38ebd3706c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.groupby;
 
 import java.util.Collection;
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.data.table.IntermediateRecord;
@@ -58,16 +59,28 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
   protected final int[] _svGroupKeys;
   protected final int[][] _mvGroupKeys;
 
+  public DefaultGroupByExecutor(QueryContext queryContext, ExpressionContext[] groupByExpressions,
+      TransformOperator transformOperator) {
+    this(queryContext, queryContext.getAggregationFunctions(), groupByExpressions, transformOperator, null);
+  }
+
+  public DefaultGroupByExecutor(QueryContext queryContext, AggregationFunction[] aggregationFunctions,
+      ExpressionContext[] groupByExpressions, TransformOperator transformOperator) {
+    this(queryContext, aggregationFunctions, groupByExpressions, transformOperator, null);
+  }
+
   /**
    * Constructor for the class.
    *
    * @param queryContext Query context
+   * @param aggregationFunctions Aggregation functions
    * @param groupByExpressions Array of group-by expressions
    * @param transformOperator Transform operator
    */
-  public DefaultGroupByExecutor(QueryContext queryContext, ExpressionContext[] groupByExpressions,
-      TransformOperator transformOperator) {
-    _aggregationFunctions = queryContext.getAggregationFunctions();
+  public DefaultGroupByExecutor(QueryContext queryContext, AggregationFunction[] aggregationFunctions,
+      ExpressionContext[] groupByExpressions, TransformOperator transformOperator,
+      @Nullable GroupKeyGenerator groupKeyGenerator) {
+    _aggregationFunctions = aggregationFunctions;
     assert _aggregationFunctions != null;
     _nullHandlingEnabled = queryContext.isNullHandlingEnabled();
 
@@ -83,19 +96,23 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
     // Initialize group key generator
     int numGroupsLimit = queryContext.getNumGroupsLimit();
     int maxInitialResultHolderCapacity = queryContext.getMaxInitialResultHolderCapacity();
-    if (hasNoDictionaryGroupByExpression || _nullHandlingEnabled) {
-      if (groupByExpressions.length == 1) {
-        // TODO(nhejazi): support MV and dictionary based when null handling is enabled.
-        _groupKeyGenerator =
-            new NoDictionarySingleColumnGroupKeyGenerator(transformOperator, groupByExpressions[0], numGroupsLimit,
-                _nullHandlingEnabled);
+    if (groupKeyGenerator != null) {
+      _groupKeyGenerator = groupKeyGenerator;
+    } else {
+      if (hasNoDictionaryGroupByExpression || _nullHandlingEnabled) {
+        if (groupByExpressions.length == 1) {
+          // TODO(nhejazi): support MV and dictionary based when null handling is enabled.
+          _groupKeyGenerator =
+              new NoDictionarySingleColumnGroupKeyGenerator(transformOperator, groupByExpressions[0], numGroupsLimit,
+                  _nullHandlingEnabled);
+        } else {
+          _groupKeyGenerator =
+              new NoDictionaryMultiColumnGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit);
+        }
       } else {
-        _groupKeyGenerator =
-            new NoDictionaryMultiColumnGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit);
+        _groupKeyGenerator = new DictionaryBasedGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit,
+            maxInitialResultHolderCapacity);
       }
-    } else {
-      _groupKeyGenerator = new DictionaryBasedGroupKeyGenerator(transformOperator, groupByExpressions, numGroupsLimit,
-          maxInitialResultHolderCapacity);
     }
 
     // Initialize result holders
@@ -141,7 +158,6 @@ public class DefaultGroupByExecutor implements 
GroupByExecutor {
     AggregationFunction aggregationFunction = 
_aggregationFunctions[functionIndex];
     Map<ExpressionContext, BlockValSet> blockValSetMap =
         AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, 
transformBlock);
-
     GroupByResultHolder groupByResultHolder = 
_groupByResultHolders[functionIndex];
     if (_hasMVGroupByExpression) {
       aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys, 
groupByResultHolder, blockValSetMap);
@@ -164,4 +180,14 @@ public class DefaultGroupByExecutor implements 
GroupByExecutor {
   public Collection<IntermediateRecord> trimGroupByResult(int trimSize, 
TableResizer tableResizer) {
     return tableResizer.trimInSegmentResults(_groupKeyGenerator, 
_groupByResultHolders, trimSize);
   }
+
+  @Override
+  public GroupKeyGenerator getGroupKeyGenerator() {
+    return _groupKeyGenerator;
+  }
+
+  @Override
+  public GroupByResultHolder[] getGroupByResultHolders() {
+    return _groupByResultHolders;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
index 869ef5dbe9..db5ff16b18 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
@@ -58,4 +58,8 @@ public interface GroupByExecutor {
    *
    */
   Collection<IntermediateRecord> trimGroupByResult(int trimSize, TableResizer 
tableResizer);
+
+  GroupKeyGenerator getGroupKeyGenerator();
+
+  GroupByResultHolder[] getGroupByResultHolders();
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
index 5c1bd2fe84..fcc97dd6fd 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
@@ -260,7 +260,7 @@ public class QueryContext {
   /**
    * Returns the filtered aggregation expressions for the query.
    */
-  public boolean isHasFilteredAggregations() {
+  public boolean hasFilteredAggregations() {
     return _hasFilteredAggregations;
   }
 
@@ -536,10 +536,6 @@ public class QueryContext {
         FunctionContext aggregation = pair.getLeft();
         FilterContext filter = pair.getRight();
         if (filter != null) {
-          // Filtered aggregation
-          if (_groupByExpressions != null) {
-            throw new IllegalStateException("GROUP BY with FILTER clauses is 
not supported");
-          }
           queryContext._hasFilteredAggregations = true;
         }
         int functionIndex = filteredAggregationFunctions.size();
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
index 62236f3a4b..dba3faefe6 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.java
@@ -29,11 +29,11 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.data.table.Table;
 import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
 import org.apache.pinot.core.operator.combine.GroupByCombineOperator;
-import org.apache.pinot.core.operator.query.GroupByOperator;
 import org.apache.pinot.core.plan.GroupByPlanNode;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import 
org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
@@ -50,13 +50,12 @@ import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.ReadMode;
 import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
-import static org.testng.Assert.assertEquals;
-
 
 /**
  * Unit test for GroupBy Trim functionalities.
@@ -120,7 +119,7 @@ public class GroupByTrimTest {
     queryContext.setMinServerGroupTrimSize(minServerGroupTrimSize);
 
     // Create a query operator
-    GroupByOperator groupByOperator = new GroupByPlanNode(_indexSegment, 
queryContext).run();
+    Operator<GroupByResultsBlock> groupByOperator = new 
GroupByPlanNode(_indexSegment, queryContext).run();
     GroupByCombineOperator combineOperator =
         new GroupByCombineOperator(Collections.singletonList(groupByOperator), 
queryContext, _executorService);
 
@@ -130,7 +129,7 @@ public class GroupByTrimTest {
     // Extract the execution result
     List<Pair<Double, Double>> extractedResult = 
extractTestResult(resultsBlock.getTable());
 
-    assertEquals(extractedResult, expectedResult);
+    Assert.assertEquals(extractedResult, expectedResult);
   }
 
   /**
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
index 2fc9ad1fa6..9d772abc3f 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java
@@ -202,10 +202,10 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
     nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true 
AND STARTSWITH(STRING_COL, 'abc')";
     testQuery(filterQuery, nonFilterQuery);
 
-    filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND 
STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM "
-        + "MyTable";
-    nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true 
AND STARTSWITH(REVERSE(STRING_COL), "
-        + "'abc')";
+    filterQuery =
+        "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND 
STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM " + "MyTable";
+    nonFilterQuery =
+        "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND 
STARTSWITH(REVERSE(STRING_COL), " + "'abc')";
     testQuery(filterQuery, nonFilterQuery);
   }
 
@@ -335,10 +335,49 @@ public class FilteredAggregationsTest extends 
BaseQueriesTest {
     testQuery(filterQuery, nonFilterQuery);
   }
 
-  @Test(expectedExceptions = IllegalStateException.class)
-  public void testGroupBySupport() {
-    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 2), 
MAX(INT_COL) FILTER(WHERE INT_COL > 2) "
-        + "FROM MyTable WHERE INT_COL < 1000 GROUP BY INT_COL";
-    getBrokerResponse(filterQuery);
+  @Test
+  public void testGroupBy() {
+    String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000) 
FROM MyTable GROUP BY BOOLEAN_COL";
+    String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 
25000 GROUP BY BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
+  }
+
+  @Test
+  public void testGroupByCaseAlternative() {
+    String filterQuery =
+        "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE INT_COL > 25000) AS 
total_sum FROM MyTable GROUP BY "
+            + "BOOLEAN_COL";
+    String nonFilterQuery =
+        "SELECT SUM(INT_COL), SUM(CASE WHEN INT_COL > 25000 THEN INT_COL ELSE 
0 END) AS total_sum FROM MyTable GROUP "
+            + "BY BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
+  }
+
+  @Test
+  public void testGroupBySameFilter() {
+    String filterQuery =
+        "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000), SUM(INT_COL) 
FILTER(WHERE INT_COL > 25000) FROM MyTable "
+            + "GROUP BY BOOLEAN_COL";
+    String nonFilterQuery = "SELECT AVG(INT_COL), SUM(INT_COL) FROM MyTable 
WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
+  }
+
+  @Test
+  public void testMultipleAggregationsOnSameFilterGroupBy() {
+    String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 
29990), "
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable GROUP BY 
BOOLEAN_COL";
+    String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable 
WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
+
+    filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS 
total_min, "
+        + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) AS total_max, "
+        + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_sum, "
+        + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_max2 
FROM MyTable GROUP BY BOOLEAN_COL";
+    nonFilterQuery = "SELECT MIN(CASE WHEN (NO_INDEX_COL > 29990) THEN INT_COL 
ELSE 99999 END) AS total_min, "
+        + "MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 END) AS 
total_max, "
+        + "SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL ELSE 0 END) AS 
total_sum, "
+        + "MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL ELSE 0 END) 
AS total_max2 FROM MyTable GROUP BY "
+        + "BOOLEAN_COL";
+    testQuery(filterQuery, nonFilterQuery);
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to