This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 89cd958  Support post-aggregation in SELECT (#5867)
89cd958 is described below

commit 89cd958dbabce88befa7561d3d4a25a9f97afaaa
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Tue Aug 18 14:41:59 2020 -0700

    Support post-aggregation in SELECT (#5867)
    
    Add `PostAggregationHandler` to handle the post-aggregation calculation and column re-ordering for the aggregation result
    Enhance `AggregationDataTableReducer` and `GroupByDataTableReducer` to support post-aggregation in SELECT
---
 .../query/reduce/AggregationDataTableReducer.java  |  79 +++----
 .../core/query/reduce/GroupByDataTableReducer.java | 122 +++--------
 .../core/query/reduce/PostAggregationHandler.java  | 243 +++++++++++++++++++++
 .../core/query/reduce/ResultReducerFactory.java    |   4 +-
 .../query/reduce/PostAggregationHandlerTest.java   | 120 ++++++++++
 .../tests/BaseClusterIntegrationTestSet.java       |  10 +
 6 files changed, 450 insertions(+), 128 deletions(-)
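
In short, the broker-side reducers now build the result table through the new PostAggregationHandler instead of re-ordering columns inline. As a minimal sketch of the flow (mirroring the unit test added below; the query, schema, and row values are illustrative only):

    // Evaluate "SELECT SUM(m1) + MAX(m2) FROM testTable" on one reduced row of final aggregation values.
    QueryContext queryContext =
        QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1) + MAX(m2) FROM testTable");
    DataSchema dataSchema = new DataSchema(new String[]{"sum(m1)", "max(m2)"},
        new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
    PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema);
    handler.getResultDataSchema();               // single column "plus(sum(m1),max(m2))" of type DOUBLE
    handler.getResult(new Object[]{1.0, 2.0});   // {3.0}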

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index 855fbe0..3234a30 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -28,6 +28,7 @@ import org.apache.pinot.common.response.broker.AggregationResult;
 import org.apache.pinot.common.response.broker.BrokerResponseNative;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.DataTable;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
@@ -41,12 +42,14 @@ import org.apache.pinot.core.util.QueryOptions;
  */
 @SuppressWarnings({"rawtypes", "unchecked"})
 public class AggregationDataTableReducer implements DataTableReducer {
+  private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
   private final boolean _preserveType;
   private final boolean _responseFormatSql;
 
-  AggregationDataTableReducer(QueryContext queryContext, AggregationFunction[] aggregationFunctions) {
-    _aggregationFunctions = aggregationFunctions;
+  AggregationDataTableReducer(QueryContext queryContext) {
+    _queryContext = queryContext;
+    _aggregationFunctions = queryContext.getAggregationFunctions();
     QueryOptions queryOptions = new QueryOptions(queryContext.getQueryOptions());
     _preserveType = queryOptions.isPreserveType();
     _responseFormatSql = queryOptions.isResponseFormatSQL();
@@ -63,8 +66,9 @@ public class AggregationDataTableReducer implements DataTableReducer {
       BrokerMetrics brokerMetrics) {
     if (dataTableMap.isEmpty()) {
       if (_responseFormatSql) {
-        DataSchema finalDataSchema = getResultTableDataSchema();
-        brokerResponseNative.setResultTable(new ResultTable(finalDataSchema, Collections.emptyList()));
+        DataSchema resultTableSchema =
+            new PostAggregationHandler(_queryContext, getPrePostAggregationDataSchema()).getResultDataSchema();
+        brokerResponseNative.setResultTable(new ResultTable(resultTableSchema, Collections.emptyList()));
       }
       return;
     }
@@ -75,7 +79,7 @@ public class AggregationDataTableReducer implements DataTableReducer {
     for (DataTable dataTable : dataTableMap.values()) {
       for (int i = 0; i < numAggregationFunctions; i++) {
         Object intermediateResultToMerge;
-        DataSchema.ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
+        ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
         switch (columnDataType) {
           case LONG:
             intermediateResultToMerge = dataTable.getLong(0, i);
@@ -97,63 +101,62 @@ public class AggregationDataTableReducer implements DataTableReducer {
         }
       }
     }
+    Serializable[] finalResults = new Serializable[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      finalResults[i] = AggregationFunctionUtils
+          .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i]));
+    }
 
     if (_responseFormatSql) {
-      brokerResponseNative.setResultTable(reduceToResultTable(intermediateResults));
+      brokerResponseNative.setResultTable(reduceToResultTable(finalResults));
     } else {
-      brokerResponseNative.setAggregationResults(reduceToAggregationResult(intermediateResults, dataSchema));
+      brokerResponseNative.setAggregationResults(reduceToAggregationResults(finalResults, dataSchema));
     }
   }
 
   /**
    * Sets aggregation results into ResultsTable
    */
-  private ResultTable reduceToResultTable(Object[] intermediateResults) {
-    List<Object[]> rows = new ArrayList<>(1);
-    int numAggregationFunctions = _aggregationFunctions.length;
-    Object[] row = new Object[numAggregationFunctions];
-    for (int i = 0; i < numAggregationFunctions; i++) {
-      row[i] = AggregationFunctionUtils
-          .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i]));
-    }
-    rows.add(row);
-
-    DataSchema finalDataSchema = getResultTableDataSchema();
-    return new ResultTable(finalDataSchema, rows);
+  private ResultTable reduceToResultTable(Object[] finalResults) {
+    PostAggregationHandler postAggregationHandler =
+        new PostAggregationHandler(_queryContext, getPrePostAggregationDataSchema());
+    DataSchema resultTableSchema = postAggregationHandler.getResultDataSchema();
+    Object[] resultRow = postAggregationHandler.getResult(finalResults);
+    return new ResultTable(resultTableSchema, Collections.singletonList(resultRow));
   }
 
   /**
    * Sets aggregation results into AggregationResults
    */
-  private List<AggregationResult> reduceToAggregationResult(Object[] intermediateResults, DataSchema dataSchema) {
-    // Extract final results and set them into the broker response
+  private List<AggregationResult> reduceToAggregationResults(Serializable[] finalResults, DataSchema dataSchema) {
     int numAggregationFunctions = _aggregationFunctions.length;
-    List<AggregationResult> reducedAggregationResults = new ArrayList<>(numAggregationFunctions);
-    for (int i = 0; i < numAggregationFunctions; i++) {
-      Serializable resultValue = AggregationFunctionUtils
-          .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i]));
-
-      // Format the value into string if required
-      if (!_preserveType) {
-        resultValue = AggregationFunctionUtils.formatValue(resultValue);
+    List<AggregationResult> aggregationResults = new ArrayList<>(numAggregationFunctions);
+    if (_preserveType) {
+      for (int i = 0; i < numAggregationFunctions; i++) {
+        aggregationResults.add(new AggregationResult(dataSchema.getColumnName(i), finalResults[i]));
+      }
+    } else {
+      // Format the values into strings
+      for (int i = 0; i < numAggregationFunctions; i++) {
+        aggregationResults.add(
+            new AggregationResult(dataSchema.getColumnName(i), AggregationFunctionUtils.formatValue(finalResults[i])));
       }
-      reducedAggregationResults.add(new AggregationResult(dataSchema.getColumnName(i), resultValue));
     }
-    return reducedAggregationResults;
+    return aggregationResults;
   }
 
   /**
-   * Constructs the data schema for the final results table
+   * Constructs the DataSchema for the rows before the post-aggregation (SQL mode).
    */
-  private DataSchema getResultTableDataSchema() {
+  private DataSchema getPrePostAggregationDataSchema() {
     int numAggregationFunctions = _aggregationFunctions.length;
-    String[] finalColumnNames = new String[numAggregationFunctions];
-    DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[numAggregationFunctions];
+    String[] columnNames = new String[numAggregationFunctions];
+    ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions];
     for (int i = 0; i < numAggregationFunctions; i++) {
       AggregationFunction aggregationFunction = _aggregationFunctions[i];
-      finalColumnNames[i] = aggregationFunction.getResultColumnName();
-      finalColumnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
+      columnNames[i] = aggregationFunction.getResultColumnName();
+      columnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
     }
-    return new DataSchema(finalColumnNames, finalColumnDataTypes);
+    return new DataSchema(columnNames, columnDataTypes);
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index 066b755..ea8e9c7 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.query.reduce;
 
-import com.google.common.base.Preconditions;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -34,6 +33,7 @@ import org.apache.pinot.common.response.broker.BrokerResponseNative;
 import org.apache.pinot.common.response.broker.GroupByResult;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.DataTable;
 import org.apache.pinot.common.utils.HashUtil;
 import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
@@ -44,7 +44,6 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
 import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
 import org.apache.pinot.core.query.request.context.ExpressionContext;
-import org.apache.pinot.core.query.request.context.FunctionContext;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.core.transport.ServerRoutingInstance;
 import org.apache.pinot.core.util.GroupByUtils;
@@ -67,10 +66,11 @@ public class GroupByDataTableReducer implements DataTableReducer {
   private final boolean _responseFormatSql;
   private final boolean _sqlQuery;
 
-  GroupByDataTableReducer(QueryContext queryContext, AggregationFunction[] aggregationFunctions) {
+  GroupByDataTableReducer(QueryContext queryContext) {
     _queryContext = queryContext;
-    _aggregationFunctions = aggregationFunctions;
-    _numAggregationFunctions = aggregationFunctions.length;
+    _aggregationFunctions = queryContext.getAggregationFunctions();
+    assert _aggregationFunctions != null;
+    _numAggregationFunctions = _aggregationFunctions.length;
     _groupByExpressions = queryContext.getGroupByExpressions();
     assert _groupByExpressions != null;
     _numGroupByExpressions = _groupByExpressions.size();
@@ -162,103 +162,49 @@ public class GroupByDataTableReducer implements DataTableReducer {
    */
   private void setSQLGroupByInResultTable(BrokerResponseNative brokerResponseNative, DataSchema dataSchema,
       Collection<DataTable> dataTables) {
-    DataSchema resultTableSchema = getSQLResultTableSchema(dataSchema);
     IndexedTable indexedTable = getIndexedTable(dataSchema, dataTables);
     Iterator<Record> sortedIterator = indexedTable.iterator();
+    DataSchema prePostAggregationDataSchema = getPrePostAggregationDataSchema(dataSchema);
     int limit = _queryContext.getLimit();
     List<Object[]> rows = new ArrayList<>(limit);
+    for (int i = 0; i < limit && sortedIterator.hasNext(); i++) {
+      Object[] row = sortedIterator.next().getValues();
+      for (int j = 0; j < _numAggregationFunctions; j++) {
+        int valueIndex = j + _numGroupByExpressions;
+        row[valueIndex] =
+            AggregationFunctionUtils.getSerializableValue(_aggregationFunctions[j].extractFinalResult(row[valueIndex]));
+      }
+      rows.add(row);
+    }
 
     if (_sqlQuery) {
       // SQL query with SQL group-by mode and response format
-      // NOTE: For SQL query, need to reorder the columns in the data table based on the select expressions.
-
-      int[] selectExpressionIndexMap = getSelectExpressionIndexMap();
-      int numSelectExpressions = selectExpressionIndexMap.length;
-      String[] columnNames = resultTableSchema.getColumnNames();
-      DataSchema.ColumnDataType[] columnDataTypes = resultTableSchema.getColumnDataTypes();
-      String[] reorderedColumnNames = new String[numSelectExpressions];
-      DataSchema.ColumnDataType[] reorderedColumnDataTypes = new DataSchema.ColumnDataType[numSelectExpressions];
-      resultTableSchema = new DataSchema(reorderedColumnNames, reorderedColumnDataTypes);
-      for (int i = 0; i < numSelectExpressions; i++) {
-        reorderedColumnNames[i] = columnNames[selectExpressionIndexMap[i]];
-        reorderedColumnDataTypes[i] = columnDataTypes[selectExpressionIndexMap[i]];
-      }
-      while (rows.size() < limit && sortedIterator.hasNext()) {
-        Record nextRecord = sortedIterator.next();
-        Object[] values = nextRecord.getValues();
-        for (int i = 0; i < _numAggregationFunctions; i++) {
-          int valueIndex = i + _numGroupByExpressions;
-          values[valueIndex] = AggregationFunctionUtils
-              .getSerializableValue(_aggregationFunctions[i].extractFinalResult(values[valueIndex]));
-        }
 
-        Object[] reorderedValues = new Object[numSelectExpressions];
-        for (int i = 0; i < numSelectExpressions; i++) {
-          reorderedValues[i] = values[selectExpressionIndexMap[i]];
-        }
-        rows.add(reorderedValues);
-      }
+      PostAggregationHandler postAggregationHandler =
+          new PostAggregationHandler(_queryContext, prePostAggregationDataSchema);
+      DataSchema resultTableSchema = postAggregationHandler.getResultDataSchema();
+      rows.replaceAll(postAggregationHandler::getResult);
+      brokerResponseNative.setResultTable(new ResultTable(resultTableSchema, rows));
     } else {
       // PQL query with SQL group-by mode and response format
+      // NOTE: For PQL query, keep the order of columns as is (group-by expressions followed by aggregations), no need
+      //       to perform post-aggregation.
 
-      while (rows.size() < limit && sortedIterator.hasNext()) {
-        Record nextRecord = sortedIterator.next();
-        Object[] values = nextRecord.getValues();
-        for (int i = 0; i < _numAggregationFunctions; i++) {
-          int valueIndex = i + _numGroupByExpressions;
-          values[valueIndex] = AggregationFunctionUtils
-              .getSerializableValue(_aggregationFunctions[i].extractFinalResult(values[valueIndex]));
-        }
-        rows.add(values);
-      }
-    }
-
-    brokerResponseNative.setResultTable(new ResultTable(resultTableSchema, rows));
-  }
-
-  /**
-   * Helper method to generate a map from the expression index in the select expressions to the column index in the data
-   * schema. This map is used to reorder the expressions according to the select expressions.
-   */
-  private int[] getSelectExpressionIndexMap() {
-    List<ExpressionContext> selectExpressions = _queryContext.getSelectExpressions();
-    List<ExpressionContext> groupByExpressions = _queryContext.getGroupByExpressions();
-    assert groupByExpressions != null;
-    int numSelectExpressions = selectExpressions.size();
-    int[] selectExpressionIndexMap = new int[numSelectExpressions];
-    int aggregationExpressionIndex = _numGroupByExpressions;
-    for (int i = 0; i < numSelectExpressions; i++) {
-      ExpressionContext selectExpression = selectExpressions.get(i);
-      if (selectExpression.getType() == ExpressionContext.Type.FUNCTION
-          && selectExpression.getFunction().getType() == FunctionContext.Type.AGGREGATION) {
-        selectExpressionIndexMap[i] = aggregationExpressionIndex++;
-      } else {
-        int indexInGroupByExpressions = groupByExpressions.indexOf(selectExpression);
-        Preconditions.checkState(indexInGroupByExpressions >= 0,
-            "Select expression: %s is not an aggregation expression and not contained in the group-by expressions");
-        selectExpressionIndexMap[i] = indexInGroupByExpressions;
-      }
+      brokerResponseNative.setResultTable(new ResultTable(prePostAggregationDataSchema, rows));
     }
-    return selectExpressionIndexMap;
   }
 
   /**
-   * Constructs the final result table schema for sql mode execution
-   * The data type for the aggregations needs to be taken from the final result's data type
+   * Constructs the DataSchema for the rows before the post-aggregation (SQL mode).
    */
-  private DataSchema getSQLResultTableSchema(DataSchema dataSchema) {
-    String[] columns = dataSchema.getColumnNames();
-    DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[_numColumns];
-    int aggIdx = 0;
-    for (int i = 0; i < _numColumns; i++) {
-      if (i < _numGroupByExpressions) {
-        finalColumnDataTypes[i] = dataSchema.getColumnDataType(i);
-      } else {
-        finalColumnDataTypes[i] = _aggregationFunctions[aggIdx].getFinalResultColumnType();
-        aggIdx++;
-      }
+  private DataSchema getPrePostAggregationDataSchema(DataSchema dataSchema) {
+    String[] columnNames = dataSchema.getColumnNames();
+    ColumnDataType[] columnDataTypes = new ColumnDataType[_numColumns];
+    System.arraycopy(dataSchema.getColumnDataTypes(), 0, columnDataTypes, 0, _numGroupByExpressions);
+    for (int i = 0; i < _numAggregationFunctions; i++) {
+      columnDataTypes[i + _numGroupByExpressions] = _aggregationFunctions[i].getFinalResultColumnType();
     }
-    return new DataSchema(columns, finalColumnDataTypes);
+    return new DataSchema(columnNames, columnDataTypes);
   }
 
   private IndexedTable getIndexedTable(DataSchema dataSchema, Collection<DataTable> dataTables) {
@@ -268,7 +214,7 @@ public class GroupByDataTableReducer implements DataTableReducer {
     for (DataTable dataTable : dataTables) {
       BiFunction[] functions = new BiFunction[_numColumns];
       for (int i = 0; i < _numColumns; i++) {
-        DataSchema.ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
+        ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
         BiFunction<Integer, Integer, Object> function;
         switch (columnDataType) {
           case INT:
@@ -394,10 +340,10 @@ public class GroupByDataTableReducer implements DataTableReducer {
    */
   private DataSchema getPQLResultTableSchema(AggregationFunction aggregationFunction) {
     String[] columnNames = new String[_numColumns];
-    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[_numColumns];
+    ColumnDataType[] columnDataTypes = new ColumnDataType[_numColumns];
     for (int i = 0; i < _numGroupByExpressions; i++) {
       columnNames[i] = _groupByExpressions.get(i).toString();
-      columnDataTypes[i] = DataSchema.ColumnDataType.STRING;
+      columnDataTypes[i] = ColumnDataType.STRING;
     }
     columnNames[_numGroupByExpressions] = aggregationFunction.getResultColumnName();
     columnDataTypes[_numGroupByExpressions] = aggregationFunction.getFinalResultColumnType();
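
The column re-ordering that used to live in getSelectExpressionIndexMap() is now delegated to the same handler. A small sketch of the group-by case (taken from the unit test added later in this commit), where the SELECT order differs from the data-table order:

    // Data table layout is [d1, d2, sum(m1)]; the query selects "SUM(m1), d2".
    QueryContext queryContext =
        QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1), d2 FROM testTable GROUP BY d1, d2");
    DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)"},
        new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.DOUBLE});
    PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema);
    handler.getResultDataSchema();               // [sum(m1) DOUBLE, d2 LONG]
    handler.getResult(new Object[]{1, 2L, 3.0}); // {3.0, 2L}
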
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java
new file mode 100644
index 0000000..fdf3cd9
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java
@@ -0,0 +1,243 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.reduce;
+
+import com.google.common.base.Preconditions;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.query.postaggregation.PostAggregationFunction;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+import org.apache.pinot.core.query.request.context.FunctionContext;
+import org.apache.pinot.core.query.request.context.QueryContext;
+
+
+/**
+ * The {@code PostAggregationHandler} handles the post-aggregation calculation as well as the column re-ordering for the
+ * aggregation result.
+ */
+public class PostAggregationHandler {
+  private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
+  private final int _numGroupByExpressions;
+  private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
+  private final DataSchema _dataSchema;
+  private final ValueExtractor[] _valueExtractors;
+  private final DataSchema _resultDataSchema;
+
+  public PostAggregationHandler(QueryContext queryContext, DataSchema dataSchema) {
+    _aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap();
+    assert _aggregationFunctionIndexMap != null;
+    List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions();
+    if (groupByExpressions != null) {
+      _numGroupByExpressions = groupByExpressions.size();
+      _groupByExpressionIndexMap = new HashMap<>();
+      for (int i = 0; i < _numGroupByExpressions; i++) {
+        _groupByExpressionIndexMap.put(groupByExpressions.get(i), i);
+      }
+    } else {
+      _numGroupByExpressions = 0;
+      _groupByExpressionIndexMap = null;
+    }
+
+    // NOTE: The data schema will always have group-by expressions in the front, followed by aggregation functions of
+    //       the same order as in the query context. This is handled in AggregationGroupByOrderByOperator.
+    _dataSchema = dataSchema;
+
+    List<ExpressionContext> selectExpressions = queryContext.getSelectExpressions();
+    int numSelectExpressions = selectExpressions.size();
+    _valueExtractors = new ValueExtractor[numSelectExpressions];
+    String[] columnNames = new String[numSelectExpressions];
+    ColumnDataType[] columnDataTypes = new ColumnDataType[numSelectExpressions];
+    for (int i = 0; i < numSelectExpressions; i++) {
+      ValueExtractor valueExtractor = getValueExtractor(selectExpressions.get(i));
+      _valueExtractors[i] = valueExtractor;
+      columnNames[i] = valueExtractor.getColumnName();
+      columnDataTypes[i] = valueExtractor.getColumnDataType();
+    }
+    _resultDataSchema = new DataSchema(columnNames, columnDataTypes);
+  }
+
+  /**
+   * Returns the DataSchema of the post-aggregation result.
+   */
+  public DataSchema getResultDataSchema() {
+    return _resultDataSchema;
+  }
+
+  /**
+   * Returns the post-aggregation result for the given row.
+   */
+  public Object[] getResult(Object[] row) {
+    int numValues = _valueExtractors.length;
+    Object[] result = new Object[numValues];
+    for (int i = 0; i < numValues; i++) {
+      result[i] = _valueExtractors[i].extract(row);
+    }
+    return result;
+  }
+
+  /**
+   * Returns a ValueExtractor based on the given expression.
+   */
+  public ValueExtractor getValueExtractor(ExpressionContext expression) {
+    if (expression.getType() == ExpressionContext.Type.LITERAL) {
+      // Literal
+      return new LiteralValueExtractor(expression.getLiteral());
+    }
+    if (_numGroupByExpressions > 0) {
+      Integer groupByExpressionIndex = _groupByExpressionIndexMap.get(expression);
+      if (groupByExpressionIndex != null) {
+        // Group-by expression
+        return new ColumnValueExtractor(groupByExpressionIndex);
+      }
+    }
+    FunctionContext function = expression.getFunction();
+    Preconditions
+        .checkState(function != null, "Failed to find SELECT expression: %s in the GROUP-BY clause", expression);
+    if (function.getType() == FunctionContext.Type.AGGREGATION) {
+      // Aggregation function
+      return new ColumnValueExtractor(_aggregationFunctionIndexMap.get(function) + _numGroupByExpressions);
+    } else {
+      // Post-aggregation function
+      return new PostAggregationValueExtractor(function);
+    }
+  }
+
+  /**
+   * Value extractor for the post-aggregation function.
+   */
+  public interface ValueExtractor {
+
+    /**
+     * Returns the column name for the value extracted.
+     */
+    String getColumnName();
+
+    /**
+     * Returns the ColumnDataType of the value extracted.
+     */
+    ColumnDataType getColumnDataType();
+
+    /**
+     * Extracts the value from the given row.
+     */
+    Object extract(Object[] row);
+  }
+
+  /**
+   * Value extractor for a literal.
+   */
+  private static class LiteralValueExtractor implements ValueExtractor {
+    final String _literal;
+
+    LiteralValueExtractor(String literal) {
+      _literal = literal;
+    }
+
+    @Override
+    public String getColumnName() {
+      return '\'' + _literal + '\'';
+    }
+
+    @Override
+    public ColumnDataType getColumnDataType() {
+      return ColumnDataType.STRING;
+    }
+
+    @Override
+    public Object extract(Object[] row) {
+      return _literal;
+    }
+  }
+
+  /**
+   * Value extractor for a non-post-aggregation column (group-by expression or aggregation).
+   */
+  private class ColumnValueExtractor implements ValueExtractor {
+    final int _index;
+
+    ColumnValueExtractor(int index) {
+      _index = index;
+    }
+
+    @Override
+    public String getColumnName() {
+      return _dataSchema.getColumnName(_index);
+    }
+
+    @Override
+    public ColumnDataType getColumnDataType() {
+      return _dataSchema.getColumnDataType(_index);
+    }
+
+    @Override
+    public Object extract(Object[] row) {
+      return row[_index];
+    }
+  }
+
+  /**
+   * Value extractor for a post-aggregation column.
+   */
+  private class PostAggregationValueExtractor implements ValueExtractor {
+    final FunctionContext _function;
+    final Object[] _arguments;
+    final ValueExtractor[] _argumentExtractors;
+    final PostAggregationFunction _postAggregationFunction;
+
+    PostAggregationValueExtractor(FunctionContext function) {
+      assert function.getType() == FunctionContext.Type.TRANSFORM;
+
+      _function = function;
+      List<ExpressionContext> arguments = function.getArguments();
+      int numArguments = arguments.size();
+      _arguments = new Object[numArguments];
+      _argumentExtractors = new ValueExtractor[numArguments];
+      ColumnDataType[] argumentTypes = new ColumnDataType[numArguments];
+      for (int i = 0; i < numArguments; i++) {
+        ExpressionContext argument = arguments.get(i);
+        ValueExtractor argumentExtractor = getValueExtractor(argument);
+        _argumentExtractors[i] = argumentExtractor;
+        argumentTypes[i] = argumentExtractor.getColumnDataType();
+      }
+      _postAggregationFunction = new PostAggregationFunction(function.getFunctionName(), argumentTypes);
+    }
+
+    @Override
+    public String getColumnName() {
+      return _function.toString();
+    }
+
+    @Override
+    public ColumnDataType getColumnDataType() {
+      return _postAggregationFunction.getResultType();
+    }
+
+    @Override
+    public Object extract(Object[] row) {
+      int numArguments = _arguments.length;
+      for (int i = 0; i < numArguments; i++) {
+        _arguments[i] = _argumentExtractors[i].extract(row);
+      }
+      return _postAggregationFunction.invoke(_arguments);
+    }
+  }
+}
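
For a nested post-aggregation such as the last case in the test below, getValueExtractor recurses through the expression tree. As a rough sketch, "SELECT (SUM(m1) + MAX(m2) - d1) / 2, d2 FROM testTable GROUP BY d1, d2 ORDER BY MAX(m1)" over the pre-post-aggregation schema [d1, d2, sum(m1), max(m2), max(m1)] decomposes into:

    PostAggregationValueExtractor(divide)
      PostAggregationValueExtractor(minus)
        PostAggregationValueExtractor(plus)
          ColumnValueExtractor(2)    // sum(m1): aggregation index 0 + 2 group-by columns
          ColumnValueExtractor(3)    // max(m2): aggregation index 1 + 2 group-by columns
        ColumnValueExtractor(0)      // group-by expression d1
      LiteralValueExtractor("2")

which is evaluated bottom-up for each row and reports the result column name divide(minus(plus(sum(m1),max(m2)),d1),'2').
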
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java
index 66c6313..04f97d6 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java
@@ -46,11 +46,11 @@ public final class ResultReducerFactory {
           // Distinct query
           return new DistinctDataTableReducer(queryContext, (DistinctAggregationFunction) aggregationFunctions[0]);
         } else {
-          return new AggregationDataTableReducer(queryContext, aggregationFunctions);
+          return new AggregationDataTableReducer(queryContext);
         }
       } else {
         // Aggregation group-by query
-        return new GroupByDataTableReducer(queryContext, aggregationFunctions);
+        return new GroupByDataTableReducer(queryContext);
       }
     }
   }
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/reduce/PostAggregationHandlerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/reduce/PostAggregationHandlerTest.java
new file mode 100644
index 0000000..38da116
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/reduce/PostAggregationHandlerTest.java
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.reduce;
+
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+public class PostAggregationHandlerTest {
+
+  @Test
+  public void testPostAggregation() {
+    // Regular aggregation only
+    {
+      QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1) FROM testTable");
+      DataSchema dataSchema = new DataSchema(new String[]{"sum(m1)"}, new ColumnDataType[]{ColumnDataType.DOUBLE});
+      PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema);
+      assertEquals(handler.getResultDataSchema(), dataSchema);
+      assertEquals(handler.getResult(new Object[]{1.0}), new Object[]{1.0});
+      assertEquals(handler.getResult(new Object[]{2.0}), new Object[]{2.0});
+    }
+
+    // Regular aggregation group-by
+    {
+      QueryContext queryContext =
+          QueryContextConverterUtils.getQueryContextFromSQL("SELECT d1, 
SUM(m1) FROM testTable GROUP BY d1");
+      DataSchema dataSchema = new DataSchema(new String[]{"d1", "sum(m1)"},
+          new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.DOUBLE});
+      PostAggregationHandler handler = new 
PostAggregationHandler(queryContext, dataSchema);
+      assertEquals(handler.getResultDataSchema(), dataSchema);
+      assertEquals(handler.getResult(new Object[]{1, 2.0}), new Object[]{1, 
2.0});
+      assertEquals(handler.getResult(new Object[]{3, 4.0}), new Object[]{3, 
4.0});
+    }
+
+    // Aggregation group-by with partial columns selected
+    {
+      QueryContext queryContext =
+          QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1), 
d2 FROM testTable GROUP BY d1, d2");
+      DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", 
"sum(m1)"},
+          new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, 
ColumnDataType.DOUBLE});
+      PostAggregationHandler handler = new 
PostAggregationHandler(queryContext, dataSchema);
+      DataSchema resultDataSchema = handler.getResultDataSchema();
+      assertEquals(resultDataSchema.size(), 2);
+      assertEquals(resultDataSchema.getColumnNames(), new String[]{"sum(m1)", 
"d2"});
+      assertEquals(resultDataSchema.getColumnDataTypes(),
+          new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.LONG});
+      assertEquals(handler.getResult(new Object[]{1, 2L, 3.0}), new 
Object[]{3.0, 2L});
+      assertEquals(handler.getResult(new Object[]{4, 5L, 6.0}), new 
Object[]{6.0, 5L});
+    }
+
+    // Aggregation group-by with order-by
+    {
+      QueryContext queryContext = QueryContextConverterUtils
+          .getQueryContextFromSQL("SELECT SUM(m1), d2 FROM testTable GROUP BY 
d1, d2 ORDER BY MAX(m1)");
+      DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", 
"sum(m1)", "max(m1)"},
+          new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, 
ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+      PostAggregationHandler handler = new 
PostAggregationHandler(queryContext, dataSchema);
+      DataSchema resultDataSchema = handler.getResultDataSchema();
+      assertEquals(resultDataSchema.size(), 2);
+      assertEquals(resultDataSchema.getColumnNames(), new String[]{"sum(m1)", 
"d2"});
+      assertEquals(resultDataSchema.getColumnDataTypes(),
+          new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.LONG});
+      assertEquals(handler.getResult(new Object[]{1, 2L, 3.0, 4.0}), new 
Object[]{3.0, 2L});
+      assertEquals(handler.getResult(new Object[]{5, 6L, 7.0, 8.0}), new 
Object[]{7.0, 6L});
+    }
+
+    // Post aggregation
+    {
+      QueryContext queryContext =
+          QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1) + 
MAX(m2) FROM testTable");
+      DataSchema dataSchema = new DataSchema(new String[]{"sum(m1)", 
"max(m2)"},
+          new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+      PostAggregationHandler handler = new 
PostAggregationHandler(queryContext, dataSchema);
+      DataSchema resultDataSchema = handler.getResultDataSchema();
+      assertEquals(resultDataSchema.size(), 1);
+      assertEquals(resultDataSchema.getColumnName(0), "plus(sum(m1),max(m2))");
+      assertEquals(resultDataSchema.getColumnDataType(0), 
ColumnDataType.DOUBLE);
+      assertEquals(handler.getResult(new Object[]{1.0, 2.0}), new 
Object[]{3.0});
+      assertEquals(handler.getResult(new Object[]{3.0, 4.0}), new 
Object[]{7.0});
+    }
+
+    // Post aggregation with group-by and order-by
+    {
+      QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromSQL(
+          "SELECT (SUM(m1) + MAX(m2) - d1) / 2, d2 FROM testTable GROUP BY d1, d2 ORDER BY MAX(m1)");
+      DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)", "max(m1)"},
+          new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+      PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema);
+      DataSchema resultDataSchema = handler.getResultDataSchema();
+      assertEquals(resultDataSchema.size(), 2);
+      assertEquals(resultDataSchema.getColumnNames(),
+          new String[]{"divide(minus(plus(sum(m1),max(m2)),d1),'2')", "d2"});
+      assertEquals(resultDataSchema.getColumnDataTypes(),
+          new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.LONG});
+      assertEquals(handler.getResult(new Object[]{1, 2L, 3.0, 4.0, 5.0}), new Object[]{3.0, 2L});
+      assertEquals(handler.getResult(new Object[]{6, 7L, 8.0, 9.0, 10.0}), new Object[]{5.5, 7L});
+    }
+  }
+}
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
index d02e19d..53e85ce 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
@@ -231,6 +231,16 @@ public abstract class BaseClusterIntegrationTestSet extends BaseClusterIntegrati
     testSqlQuery(query, Collections.singletonList(query));
     query = "SELECT MAX(ArrDelay), Month FROM mytable GROUP BY Month ORDER BY 
ABS(Month - 6) + MAX(ArrDelay)";
     testSqlQuery(query, Collections.singletonList(query));
+
+    // Post-aggregation in SELECT
+    query = "SELECT MAX(ArrDelay) + MAX(AirTime) FROM mytable";
+    testSqlQuery(query, Collections.singletonList(query));
+    query =
+        "SELECT MAX(ArrDelay) - MAX(AirTime), DaysSinceEpoch FROM mytable 
GROUP BY DaysSinceEpoch ORDER BY MAX(ArrDelay) - MIN(AirTime) DESC";
+    testSqlQuery(query, Collections.singletonList(query));
+    query =
+        "SELECT DaysSinceEpoch, MAX(ArrDelay) * 2 - MAX(AirTime) - 3 FROM 
mytable GROUP BY DaysSinceEpoch ORDER BY MAX(ArrDelay) - MIN(AirTime) DESC";
+    testSqlQuery(query, Collections.singletonList(query));
   }
 
   /**

