This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new b6e81355b5 Allow using 'serverReturnFinalResult' to optimize server 
partitioned table (#13208)
b6e81355b5 is described below

commit b6e81355b57c66dfe53e8cb1717045a4c32cfd71
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Sat May 25 12:08:22 2024 -0700

    Allow using 'serverReturnFinalResult' to optimize server partitioned table 
(#13208)
---
 .../common/utils/config/QueryOptionsUtils.java     |   4 +
 .../pinot/common/utils/grpc/GrpcQueryClient.java   |   9 +-
 .../core/data/table/ConcurrentIndexedTable.java    |   7 +-
 .../apache/pinot/core/data/table/IndexedTable.java |  52 ++---
 .../pinot/core/data/table/SimpleIndexedTable.java  |   7 +-
 .../apache/pinot/core/data/table/TableResizer.java |  24 ++-
 .../table/UnboundedConcurrentIndexedTable.java     |   7 +-
 .../operator/combine/GroupByCombineOperator.java   |   8 +-
 .../streaming/StreamingGroupByCombineOperator.java |   8 +-
 .../aggregation/function/AggregationFunction.java  |   8 +
 .../function/AggregationFunctionUtils.java         |  43 +++-
 .../function/BaseBooleanAggregationFunction.java   |   5 +
 .../function/ChildAggregationFunction.java         |   5 +
 .../function/CountAggregationFunction.java         |   5 +
 .../function/DistinctCountAggregationFunction.java |   5 +
 .../DistinctCountBitmapAggregationFunction.java    |   5 +
 .../DistinctCountCPCSketchAggregationFunction.java |   5 +
 .../DistinctCountHLLAggregationFunction.java       |   5 +
 .../DistinctCountHLLPlusAggregationFunction.java   |   5 +
 ...CountIntegerTupleSketchAggregationFunction.java |   5 +
 .../DistinctCountMVAggregationFunction.java        |   5 +
 .../DistinctCountSmartHLLAggregationFunction.java  |   5 +
 ...istinctCountThetaSketchAggregationFunction.java |   5 +
 .../DistinctCountULLAggregationFunction.java       |   5 +
 .../function/DistinctSumAggregationFunction.java   |   5 +
 .../function/DistinctSumMVAggregationFunction.java |   5 +
 .../function/FastHLLAggregationFunction.java       |   5 +
 .../function/MaxAggregationFunction.java           |   5 +
 .../function/MinAggregationFunction.java           |   5 +
 .../function/MinMaxRangeAggregationFunction.java   |   8 +
 ...artitionedDistinctCountAggregationFunction.java |   5 +
 .../function/SumAggregationFunction.java           |   5 +
 .../function/SumPrecisionAggregationFunction.java  |   5 +
 ...aluesIntegerTupleSketchAggregationFunction.java |   5 +
 .../funnel/FunnelCountAggregationFunction.java     |  24 ++-
 .../funnel/FunnelMaxStepAggregationFunction.java   |   5 +
 .../query/reduce/AggregationDataTableReducer.java  |  56 ++++-
 .../core/query/reduce/GroupByDataTableReducer.java | 225 +++++++++++----------
 .../core/query/request/context/QueryContext.java   |  12 ++
 .../MultiNodesOfflineClusterIntegrationTest.java   |  95 +++++++++
 .../tests/OfflineGRPCServerIntegrationTest.java    |  37 ++--
 .../apache/pinot/spi/utils/CommonConstants.java    |  10 +
 42 files changed, 571 insertions(+), 188 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/config/QueryOptionsUtils.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/config/QueryOptionsUtils.java
index 75797dcd14..a7c45c45d3 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/config/QueryOptionsUtils.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/config/QueryOptionsUtils.java
@@ -236,6 +236,10 @@ public class QueryOptionsUtils {
     return 
Boolean.parseBoolean(queryOptions.get(QueryOptionKey.SERVER_RETURN_FINAL_RESULT));
   }
 
+  public static boolean isServerReturnFinalResultKeyUnpartitioned(Map<String, 
String> queryOptions) {
+    return 
Boolean.parseBoolean(queryOptions.get(QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED));
+  }
+
   @Nullable
   public static String getOrderByAlgorithm(Map<String, String> queryOptions) {
     return queryOptions.get(QueryOptionKey.ORDER_BY_ALGORITHM);
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
index ac05ec70d3..a41a30c5d4 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java
@@ -25,6 +25,7 @@ import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
+import java.io.Closeable;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
@@ -42,7 +43,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 
-public class GrpcQueryClient {
+public class GrpcQueryClient implements Closeable {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(GrpcQueryClient.class);
   private static final int DEFAULT_CHANNEL_SHUTDOWN_TIMEOUT_SECOND = 10;
   // the key is the hashCode of the TlsConfig, the value is the SslContext
@@ -74,9 +75,8 @@ public class GrpcQueryClient {
     LOGGER.info("Building gRPC SSL context");
     SslContext sslContext = 
CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), 
tlsConfigHashCode -> {
       try {
-        SSLFactory sslFactory =
-            
RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(
-                tlsConfig, PinotInsecureMode::isPinotInInsecureMode);
+        SSLFactory sslFactory = 
RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(tlsConfig,
+            PinotInsecureMode::isPinotInInsecureMode);
         SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
         
sslFactory.getKeyManagerFactory().ifPresent(sslContextBuilder::keyManager);
         
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
@@ -98,6 +98,7 @@ public class GrpcQueryClient {
     return _blockingStub.submit(request);
   }
 
+  @Override
   public void close() {
     if (!_managedChannel.isShutdown()) {
       try {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
index b5f8d6e0d0..119d47c79e 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
@@ -34,7 +34,12 @@ public class ConcurrentIndexedTable extends IndexedTable {
 
   public ConcurrentIndexedTable(DataSchema dataSchema, QueryContext 
queryContext, int resultSize, int trimSize,
       int trimThreshold) {
-    super(dataSchema, queryContext, resultSize, trimSize, trimThreshold, new 
ConcurrentHashMap<>());
+    this(dataSchema, false, queryContext, resultSize, trimSize, trimThreshold);
+  }
+
+  public ConcurrentIndexedTable(DataSchema dataSchema, boolean hasFinalInput, 
QueryContext queryContext, int resultSize,
+      int trimSize, int trimThreshold) {
+    super(dataSchema, hasFinalInput, queryContext, resultSize, trimSize, 
trimThreshold, new ConcurrentHashMap<>());
   }
 
   /**
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
index 012fdc1170..04598f424c 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
@@ -38,6 +38,7 @@ import 
org.apache.pinot.core.query.request.context.QueryContext;
 @SuppressWarnings({"rawtypes", "unchecked"})
 public abstract class IndexedTable extends BaseTable {
   protected final Map<Key, Record> _lookupMap;
+  protected final boolean _hasFinalInput;
   protected final int _resultSize;
   protected final int _numKeyColumns;
   protected final AggregationFunction[] _aggregationFunctions;
@@ -54,16 +55,18 @@ public abstract class IndexedTable extends BaseTable {
    * Constructor for the IndexedTable.
    *
    * @param dataSchema    Data schema of the table
+   * @param hasFinalInput Whether the input is the final aggregate result
    * @param queryContext  Query context
    * @param resultSize    Number of records to keep in the final result after 
calling {@link #finish(boolean, boolean)}
    * @param trimSize      Number of records to keep when trimming the table
    * @param trimThreshold Trim the table when the number of records exceeds 
the threshold
    * @param lookupMap     Map from keys to records
    */
-  protected IndexedTable(DataSchema dataSchema, QueryContext queryContext, int 
resultSize, int trimSize,
-      int trimThreshold, Map<Key, Record> lookupMap) {
+  protected IndexedTable(DataSchema dataSchema, boolean hasFinalInput, 
QueryContext queryContext, int resultSize,
+      int trimSize, int trimThreshold, Map<Key, Record> lookupMap) {
     super(dataSchema);
     _lookupMap = lookupMap;
+    _hasFinalInput = hasFinalInput;
     _resultSize = resultSize;
 
     List<ExpressionContext> groupByExpressions = 
queryContext.getGroupByExpressions();
@@ -74,7 +77,7 @@ public abstract class IndexedTable extends BaseTable {
     if (orderByExpressions != null) {
       // GROUP BY with ORDER BY
       _hasOrderBy = true;
-      _tableResizer = new TableResizer(dataSchema, queryContext);
+      _tableResizer = new TableResizer(dataSchema, hasFinalInput, 
queryContext);
       // NOTE: trimSize is bounded by trimThreshold/2 to protect the server 
from using too much memory.
      // TODO: Re-evaluate it as it can lead to inaccurate results
       _trimSize = Math.min(trimSize, trimThreshold / 2);
@@ -102,34 +105,32 @@ public abstract class IndexedTable extends BaseTable {
    * Adds a record with new key or updates a record with existing key.
    */
   protected void addOrUpdateRecord(Key key, Record newRecord) {
-    _lookupMap.compute(key, (k, v) -> {
-      if (v == null) {
-        return newRecord;
-      } else {
-        Object[] existingValues = v.getValues();
-        Object[] newValues = newRecord.getValues();
-        int aggNum = 0;
-        for (int i = _numKeyColumns; i < _numColumns; i++) {
-          existingValues[i] = 
_aggregationFunctions[aggNum++].merge(existingValues[i], newValues[i]);
-        }
-        return v;
-      }
-    });
+    _lookupMap.compute(key, (k, v) -> v == null ? newRecord : updateRecord(v, 
newRecord));
   }
 
   /**
    * Updates a record with existing key. Record with new key will be ignored.
    */
   protected void updateExistingRecord(Key key, Record newRecord) {
-    _lookupMap.computeIfPresent(key, (k, v) -> {
-      Object[] existingValues = v.getValues();
-      Object[] newValues = newRecord.getValues();
-      int aggNum = 0;
-      for (int i = _numKeyColumns; i < _numColumns; i++) {
-        existingValues[i] = 
_aggregationFunctions[aggNum++].merge(existingValues[i], newValues[i]);
+    _lookupMap.computeIfPresent(key, (k, v) -> updateRecord(v, newRecord));
+  }
+
+  private Record updateRecord(Record existingRecord, Record newRecord) {
+    Object[] existingValues = existingRecord.getValues();
+    Object[] newValues = newRecord.getValues();
+    int numAggregations = _aggregationFunctions.length;
+    int index = _numKeyColumns;
+    if (!_hasFinalInput) {
+      for (int i = 0; i < numAggregations; i++, index++) {
+        existingValues[index] = 
_aggregationFunctions[i].merge(existingValues[index], newValues[index]);
       }
-      return v;
-    });
+    } else {
+      for (int i = 0; i < numAggregations; i++, index++) {
+        existingValues[index] = 
_aggregationFunctions[i].mergeFinalResult((Comparable) existingValues[index],
+            (Comparable) newValues[index]);
+      }
+    }
+    return existingRecord;
   }
 
   /**
@@ -156,7 +157,8 @@ public abstract class IndexedTable extends BaseTable {
       _topRecords = _lookupMap.values();
     }
     // TODO: Directly return final result in _tableResizer.getTopRecords to 
avoid extracting final result multiple times
-    if (storeFinalResult) {
+    assert !(_hasFinalInput && !storeFinalResult);
+    if (storeFinalResult && !_hasFinalInput) {
       ColumnDataType[] columnDataTypes = _dataSchema.getColumnDataTypes();
       int numAggregationFunctions = _aggregationFunctions.length;
       for (int i = 0; i < numAggregationFunctions; i++) {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
index 800c649112..2163620225 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
@@ -32,7 +32,12 @@ public class SimpleIndexedTable extends IndexedTable {
 
   public SimpleIndexedTable(DataSchema dataSchema, QueryContext queryContext, 
int resultSize, int trimSize,
       int trimThreshold) {
-    super(dataSchema, queryContext, resultSize, trimSize, trimThreshold, new 
HashMap<>());
+    this(dataSchema, false, queryContext, resultSize, trimSize, trimThreshold);
+  }
+
+  public SimpleIndexedTable(DataSchema dataSchema, boolean hasFinalInput, 
QueryContext queryContext, int resultSize,
+      int trimSize, int trimThreshold) {
+    super(dataSchema, hasFinalInput, queryContext, resultSize, trimSize, 
trimThreshold, new HashMap<>());
   }
 
   /**
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
index 4299e5665e..45ded8f1e5 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
@@ -50,6 +50,7 @@ import org.apache.pinot.spi.utils.ByteArray;
 @SuppressWarnings({"rawtypes", "unchecked"})
 public class TableResizer {
   private final DataSchema _dataSchema;
+  private final boolean _hasFinalInput;
   private final int _numGroupByExpressions;
   private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
   private final AggregationFunction[] _aggregationFunctions;
@@ -61,7 +62,12 @@ public class TableResizer {
   private final Comparator<IntermediateRecord> _intermediateRecordComparator;
 
   public TableResizer(DataSchema dataSchema, QueryContext queryContext) {
+    this(dataSchema, false, queryContext);
+  }
+
+  public TableResizer(DataSchema dataSchema, boolean hasFinalInput, 
QueryContext queryContext) {
     _dataSchema = dataSchema;
+    _hasFinalInput = hasFinalInput;
 
     // NOTE: The data schema will always have group-by expressions in the 
front, followed by aggregation functions of
     //       the same order as in the query context. This is handled in 
AggregationGroupByOrderByOperator.
@@ -144,16 +150,20 @@ public class TableResizer {
         expression);
     if (function.getType() == FunctionContext.Type.AGGREGATION) {
       // Aggregation function
-      return new 
AggregationFunctionExtractor(_aggregationFunctionIndexMap.get(function));
-    } else if (function.getType() == FunctionContext.Type.TRANSFORM
-        && "FILTER".equalsIgnoreCase(function.getFunctionName())) {
+      int index = _aggregationFunctionIndexMap.get(function);
+      // For final aggregate result, we can handle it the same way as group key
+      return _hasFinalInput ? new 
GroupByExpressionExtractor(_numGroupByExpressions + index)
+          : new AggregationFunctionExtractor(index);
+    } else if (function.getType() == FunctionContext.Type.TRANSFORM && 
"FILTER".equalsIgnoreCase(
+        function.getFunctionName())) {
+      // Filtered aggregation
       FunctionContext aggregation = 
function.getArguments().get(0).getFunction();
       ExpressionContext filterExpression = function.getArguments().get(1);
       FilterContext filter = RequestContextUtils.getFilter(filterExpression);
-
-      int functionIndex = 
_filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
-      AggregationFunction aggregationFunction = 
_filteredAggregationFunctions.get(functionIndex).getLeft();
-      return new AggregationFunctionExtractor(functionIndex, 
aggregationFunction);
+      int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, 
filter));
+      // For final aggregate result, we can handle it the same way as group key
+      return _hasFinalInput ? new 
GroupByExpressionExtractor(_numGroupByExpressions + index)
+          : new AggregationFunctionExtractor(index, 
_filteredAggregationFunctions.get(index).getLeft());
     } else {
       // Post-aggregation function
       return new PostAggregationFunctionExtractor(function);
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/UnboundedConcurrentIndexedTable.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/UnboundedConcurrentIndexedTable.java
index 78788f5100..67f82b2011 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/data/table/UnboundedConcurrentIndexedTable.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/data/table/UnboundedConcurrentIndexedTable.java
@@ -36,7 +36,12 @@ import 
org.apache.pinot.core.query.request.context.QueryContext;
 public class UnboundedConcurrentIndexedTable extends ConcurrentIndexedTable {
 
   public UnboundedConcurrentIndexedTable(DataSchema dataSchema, QueryContext 
queryContext, int resultSize) {
-    super(dataSchema, queryContext, resultSize, Integer.MAX_VALUE, 
Integer.MAX_VALUE);
+    this(dataSchema, false, queryContext, resultSize);
+  }
+
+  public UnboundedConcurrentIndexedTable(DataSchema dataSchema, boolean 
hasFinalInput, QueryContext queryContext,
+      int resultSize) {
+    super(dataSchema, hasFinalInput, queryContext, resultSize, 
Integer.MAX_VALUE, Integer.MAX_VALUE);
   }
 
   @Override
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java
index a2c777ed7b..5faf3bf974 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/combine/GroupByCombineOperator.java
@@ -239,10 +239,12 @@ public class GroupByCombineOperator extends 
BaseSingleBlockCombineOperator<Group
     }
 
     IndexedTable indexedTable = _indexedTable;
-    if (!_queryContext.isServerReturnFinalResult()) {
-      indexedTable.finish(false);
-    } else {
+    if (_queryContext.isServerReturnFinalResult()) {
       indexedTable.finish(true, true);
+    } else if (_queryContext.isServerReturnFinalResultKeyUnpartitioned()) {
+      indexedTable.finish(false, true);
+    } else {
+      indexedTable.finish(false);
     }
     GroupByResultsBlock mergedBlock = new GroupByResultsBlock(indexedTable, 
_queryContext);
     mergedBlock.setNumGroupsLimitReached(_numGroupsLimitReached);
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/streaming/StreamingGroupByCombineOperator.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/streaming/StreamingGroupByCombineOperator.java
index e05dd84edd..759564c855 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/streaming/StreamingGroupByCombineOperator.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/streaming/StreamingGroupByCombineOperator.java
@@ -244,10 +244,12 @@ public class StreamingGroupByCombineOperator extends 
BaseStreamingCombineOperato
     }
 
     IndexedTable indexedTable = _indexedTable;
-    if (!_queryContext.isServerReturnFinalResult()) {
-      indexedTable.finish(false);
-    } else {
+    if (_queryContext.isServerReturnFinalResult()) {
       indexedTable.finish(true, true);
+    } else if (_queryContext.isServerReturnFinalResultKeyUnpartitioned()) {
+      indexedTable.finish(false, true);
+    } else {
+      indexedTable.finish(false);
     }
     GroupByResultsBlock mergedBlock = new GroupByResultsBlock(indexedTable, 
_queryContext);
     mergedBlock.setNumGroupsLimitReached(_numGroupsLimitReached);
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
index 8176308713..c172ef5b91 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
@@ -124,6 +124,14 @@ public interface AggregationFunction<IntermediateResult, 
FinalResult extends Com
    */
   FinalResult extractFinalResult(IntermediateResult intermediateResult);
 
+  /**
+   * Merges two final results. This can be used to optimize certain functions 
(e.g. DISTINCT_COUNT) when data is
+   * partitioned on each server, where we may directly request servers to 
return final result and merge them on broker.
+   */
+  default FinalResult mergeFinalResult(FinalResult finalResult1, FinalResult 
finalResult2) {
+    throw new UnsupportedOperationException("Cannot merge final results for 
function: " + getType());
+  }
+
   /** @return Description of this operator for Explain Plan */
   String toExplainString();
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 8d6cbf4aac..99a2165503 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -18,6 +18,11 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import it.unimi.dsi.fastutil.objects.ObjectArrayList;
 import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -141,7 +146,7 @@ public class AggregationFunctionUtils {
    * TODO: Move ser/de into AggregationFunction interface
    */
   public static Object getIntermediateResult(DataTable dataTable, 
ColumnDataType columnDataType, int rowId, int colId) {
-    switch (columnDataType) {
+    switch (columnDataType.getStoredType()) {
       case INT:
         return dataTable.getInt(rowId, colId);
       case LONG:
@@ -156,9 +161,43 @@ public class AggregationFunctionUtils {
     }
   }
 
+  /**
+   * Reads the final result from the {@link DataTable}.
+   */
+  public static Comparable getFinalResult(DataTable dataTable, ColumnDataType 
columnDataType, int rowId, int colId) {
+    switch (columnDataType.getStoredType()) {
+      case INT:
+        return dataTable.getInt(rowId, colId);
+      case LONG:
+        return dataTable.getLong(rowId, colId);
+      case FLOAT:
+        return dataTable.getFloat(rowId, colId);
+      case DOUBLE:
+        return dataTable.getDouble(rowId, colId);
+      case BIG_DECIMAL:
+        return dataTable.getBigDecimal(rowId, colId);
+      case STRING:
+        return dataTable.getString(rowId, colId);
+      case BYTES:
+        return dataTable.getBytes(rowId, colId);
+      case INT_ARRAY:
+        return IntArrayList.wrap(dataTable.getIntArray(rowId, colId));
+      case LONG_ARRAY:
+        return LongArrayList.wrap(dataTable.getLongArray(rowId, colId));
+      case FLOAT_ARRAY:
+        return FloatArrayList.wrap(dataTable.getFloatArray(rowId, colId));
+      case DOUBLE_ARRAY:
+        return DoubleArrayList.wrap(dataTable.getDoubleArray(rowId, colId));
+      case STRING_ARRAY:
+        return ObjectArrayList.wrap(dataTable.getStringArray(rowId, colId));
+      default:
+        throw new IllegalStateException("Illegal column data type in final 
result: " + columnDataType);
+    }
+  }
+
   /**
    * Reads the converted final result from the {@link DataTable}. It should be 
equivalent to running
-   * {@link AggregationFunction#extractFinalResult(Object)} and {@link 
ColumnDataType#convert(Object)}.
+   * {@link #getFinalResult} and {@link ColumnDataType#convert}.
    */
   public static Object getConvertedFinalResult(DataTable dataTable, 
ColumnDataType columnDataType, int rowId,
       int colId) {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java
index 4045e496f6..c6b1216ca9 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java
@@ -246,6 +246,11 @@ public abstract class BaseBooleanAggregationFunction 
extends BaseSingleInputAggr
     return intermediateResult;
   }
 
+  @Override
+  public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+    return merge(finalResult1, finalResult2);
+  }
+
   private int getInt(Integer val) {
     return val == null ? _merger.getDefaultValue() : val;
   }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java
index f1005799f1..357ebac212 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java
@@ -119,6 +119,11 @@ public abstract class ChildAggregationFunction implements 
AggregationFunction<Lo
     return 0L;
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return 0L;
+  }
+
   /**
    * The name of the column as follows:
    * CHILD_AGGREGATION_NAME_PREFIX + actual function type + operands + 
CHILD_AGGREGATION_SEPERATOR
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
index bc730adb05..b222803a44 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
@@ -204,6 +204,11 @@ public class CountAggregationFunction extends 
BaseSingleInputAggregationFunction
     return intermediateResult;
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   @Override
   public String toExplainString() {
     StringBuilder stringBuilder = new 
StringBuilder(getType().getName()).append('(');
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
index 61588bbecb..076bc2ccda 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
@@ -66,4 +66,9 @@ public class DistinctCountAggregationFunction extends 
BaseDistinctAggregateAggre
   public Integer extractFinalResult(Set intermediateResult) {
     return intermediateResult.size();
   }
+
+  @Override
+  public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+    return finalResult1 + finalResult2;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java
index d37851acf9..d3a7593335 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountBitmapAggregationFunction.java
@@ -329,6 +329,11 @@ public class DistinctCountBitmapAggregationFunction 
extends BaseSingleInputAggre
     return intermediateResult.getCardinality();
   }
 
+  @Override
+  public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   /**
    * Returns the dictionary id bitmap from the result holder or creates a new 
one if it does not exist.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountCPCSketchAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountCPCSketchAggregationFunction.java
index b42e36a091..8784ec7373 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountCPCSketchAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountCPCSketchAggregationFunction.java
@@ -414,6 +414,11 @@ public class DistinctCountCPCSketchAggregationFunction
     return Math.round(intermediateResult.getResult().getEstimate());
   }
 
+  @Override
+  public Comparable mergeFinalResult(Comparable finalResult1, Comparable 
finalResult2) {
+    return (Long) finalResult1 + (Long) finalResult2;
+  }
+
   /**
    * Returns the CpcSketch from the result holder or creates a new one if it 
does not exist.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
index a4386827bb..504c542f0e 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
@@ -363,6 +363,11 @@ public class DistinctCountHLLAggregationFunction extends 
BaseSingleInputAggregat
     return intermediateResult.cardinality();
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   /**
    * Returns the dictionary id bitmap from the result holder or creates a new 
one if it does not exist.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLPlusAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLPlusAggregationFunction.java
index 2ca7d4eec3..b27f4dd524 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLPlusAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLPlusAggregationFunction.java
@@ -376,6 +376,11 @@ public class DistinctCountHLLPlusAggregationFunction 
extends BaseSingleInputAggr
     return intermediateResult.cardinality();
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   /**
    * Returns the dictionary id bitmap from the result holder or creates a new 
one if it does not exist.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountIntegerTupleSketchAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountIntegerTupleSketchAggregationFunction.java
index 68ec18e401..b10797a58c 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountIntegerTupleSketchAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountIntegerTupleSketchAggregationFunction.java
@@ -51,4 +51,9 @@ public class 
DistinctCountIntegerTupleSketchAggregationFunction extends IntegerT
     accumulator.setThreshold(_accumulatorThreshold);
     return Double.valueOf(accumulator.getResult().getEstimate()).longValue();
   }
+
+  @Override
+  public Comparable mergeFinalResult(Comparable finalResult1, Comparable 
finalResult2) {
+    return (Long) finalResult1 + (Long) finalResult2;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
index b0ec975348..aa1cd6da66 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
@@ -65,4 +65,9 @@ public class DistinctCountMVAggregationFunction extends 
BaseDistinctAggregateAgg
   public Integer extractFinalResult(Set intermediateResult) {
     return intermediateResult.size();
   }
+
+  @Override
+  public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+    return finalResult1 + finalResult2;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountSmartHLLAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountSmartHLLAggregationFunction.java
index 0aedb3ae7e..800cc47989 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountSmartHLLAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountSmartHLLAggregationFunction.java
@@ -734,6 +734,11 @@ public class DistinctCountSmartHLLAggregationFunction 
extends BaseSingleInputAgg
     }
   }
 
+  @Override
+  public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   /**
    * Returns the dictionary id bitmap from the result holder or creates a new 
one if it does not exist.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
index a2dd23708b..aef397b821 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
@@ -1030,6 +1030,11 @@ public class DistinctCountThetaSketchAggregationFunction
     return 
Math.round(evaluatePostAggregationExpression(_postAggregationExpression, 
mergedSketches).getEstimate());
   }
 
+  @Override
+  public Comparable mergeFinalResult(Comparable finalResult1, Comparable 
finalResult2) {
+    return (Long) finalResult1 + (Long) finalResult2;
+  }
+
   // This ensures backward compatibility with servers that still return 
sketches directly.
   // The AggregationDataTableReducer casts intermediate results to Objects and 
although the code compiles,
   // types might still be incompatible at runtime due to type erasure.
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountULLAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountULLAggregationFunction.java
index 66f731c66b..9e69cc9b85 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountULLAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountULLAggregationFunction.java
@@ -359,6 +359,11 @@ public class DistinctCountULLAggregationFunction extends 
BaseSingleInputAggregat
     return Math.round(intermediateResult.getDistinctCountEstimate());
   }
 
+  @Override
+  public Comparable mergeFinalResult(Comparable finalResult1, Comparable 
finalResult2) {
+    return (Long) finalResult1 + (Long) finalResult2;
+  }
+
   /**
    * Returns the dictionary id bitmap from the result holder or creates a new 
one if it does not exist.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumAggregationFunction.java
index a7bb1894c3..602017da65 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumAggregationFunction.java
@@ -71,4 +71,9 @@ public class DistinctSumAggregationFunction extends 
BaseDistinctAggregateAggrega
 
     return distinctSum;
   }
+
+  @Override
+  public Double mergeFinalResult(Double finalResult1, Double finalResult2) {
+    return finalResult1 + finalResult2;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumMVAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumMVAggregationFunction.java
index acd20a5348..044f95db04 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumMVAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctSumMVAggregationFunction.java
@@ -71,4 +71,9 @@ public class DistinctSumMVAggregationFunction extends 
BaseDistinctAggregateAggre
 
     return distinctSum;
   }
+
+  @Override
+  public Double mergeFinalResult(Double finalResult1, Double finalResult2) {
+    return finalResult1 + finalResult2;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
index e1c7db767d..a9f764352c 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
@@ -179,6 +179,11 @@ public class FastHLLAggregationFunction extends 
BaseSingleInputAggregationFuncti
     return intermediateResult.cardinality();
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   private static HyperLogLog convertStringToHLL(String value) {
     char[] chars = value.toCharArray();
     int length = chars.length;
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
index 25654fac59..c2d37d35d9 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
@@ -310,4 +310,9 @@ public class MaxAggregationFunction extends 
BaseSingleInputAggregationFunction<D
   public Double extractFinalResult(Double intermediateResult) {
     return intermediateResult;
   }
+
+  @Override
+  public Double mergeFinalResult(Double finalResult1, Double finalResult2) {
+    return merge(finalResult1, finalResult2);
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
index aa2ca50bbc..a74b7a53ee 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
@@ -309,4 +309,9 @@ public class MinAggregationFunction extends 
BaseSingleInputAggregationFunction<D
   public Double extractFinalResult(Double intermediateResult) {
     return intermediateResult;
   }
+
+  @Override
+  public Double mergeFinalResult(Double finalResult1, Double finalResult2) {
+    return merge(finalResult1, finalResult2);
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
index 28299429c6..e6b5b0ad84 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
@@ -180,6 +180,14 @@ public class MinMaxRangeAggregationFunction extends 
NullableSingleInputAggregati
 
   @Override
   public MinMaxRangePair merge(MinMaxRangePair intermediateResult1, 
MinMaxRangePair intermediateResult2) {
+    if (_nullHandlingEnabled) {
+      if (intermediateResult1 == null) {
+        return intermediateResult2;
+      }
+      if (intermediateResult2 == null) {
+        return intermediateResult1;
+      }
+    }
     intermediateResult1.apply(intermediateResult2);
     return intermediateResult1;
   }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
index 996a077c0a..dcf05cc2ed 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
@@ -328,6 +328,11 @@ public class 
SegmentPartitionedDistinctCountAggregationFunction extends BaseSing
     return intermediateResult;
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
   /**
    * Helper method to set an INT value for the given group key into the result 
holder.
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
index 46e734349a..b90dcc2051 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
@@ -294,4 +294,9 @@ public class SumAggregationFunction extends 
BaseSingleInputAggregationFunction<D
   public Double extractFinalResult(Double intermediateResult) {
     return intermediateResult;
   }
+
+  @Override
+  public Double mergeFinalResult(Double finalResult1, Double finalResult2) {
+    return merge(finalResult1, finalResult2);
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
index 5734a49907..2bad736974 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java
@@ -489,6 +489,11 @@ public class SumPrecisionAggregationFunction extends 
BaseSingleInputAggregationF
     return _scale == null ? result : result.setScale(_scale, 
RoundingMode.HALF_EVEN);
   }
 
+  @Override
+  public BigDecimal mergeFinalResult(BigDecimal finalResult1, BigDecimal 
finalResult2) {
+    return merge(finalResult1, finalResult2);
+  }
+
   public BigDecimal getDefaultResult(AggregationResultHolder 
aggregationResultHolder) {
     BigDecimal result = aggregationResultHolder.getResult();
     return result != null ? result : BigDecimal.ZERO;
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumValuesIntegerTupleSketchAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumValuesIntegerTupleSketchAggregationFunction.java
index d37854b1b0..fa4ac2d68d 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumValuesIntegerTupleSketchAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumValuesIntegerTupleSketchAggregationFunction.java
@@ -59,4 +59,9 @@ public class SumValuesIntegerTupleSketchAggregationFunction 
extends IntegerTuple
     double estimate = retainedTotal / result.getTheta();
     return Math.round(estimate);
   }
+
+  @Override
+  public Comparable mergeFinalResult(Comparable finalResult1, Comparable 
finalResult2) {
+    return (Long) finalResult1 + (Long) finalResult2;
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunction.java
index 3c258277db..29b18078fc 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelCountAggregationFunction.java
@@ -155,6 +155,16 @@ public class FunnelCountAggregationFunction<A, I> 
implements AggregationFunction
     return _mergeStrategy.merge(a, b);
   }
 
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.LONG_ARRAY;
+  }
+
   @Override
   public LongArrayList extractFinalResult(I intermediateResult) {
     if (intermediateResult == null) {
@@ -164,13 +174,13 @@ public class FunnelCountAggregationFunction<A, I> 
implements AggregationFunction
   }
 
   @Override
-  public ColumnDataType getIntermediateResultColumnType() {
-    return ColumnDataType.OBJECT;
-  }
-
-  @Override
-  public ColumnDataType getFinalResultColumnType() {
-    return ColumnDataType.LONG_ARRAY;
+  public LongArrayList mergeFinalResult(LongArrayList finalResult1, 
LongArrayList finalResult2) {
+    long[] elements1 = finalResult1.elements();
+    long[] elements2 = finalResult2.elements();
+    for (int i = 0; i < _numSteps; i++) {
+      elements1[i] += elements2[i];
+    }
+    return finalResult1;
   }
 
   @Override
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
index e8f316e187..cb616649ea 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/funnel/FunnelMaxStepAggregationFunction.java
@@ -291,6 +291,11 @@ public class FunnelMaxStepAggregationFunction
     return maxStep;
   }
 
+  @Override
+  public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
+    return Math.max(finalResult1, finalResult2);
+  }
+
   @Override
   public String toExplainString() {
     return "WindowFunnelAggregationFunction{"
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index d3c2711e81..1c39b6971b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.query.reduce;
 
-import com.google.common.base.Preconditions;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
@@ -69,11 +68,15 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
       return;
     }
 
-    if (!_queryContext.isServerReturnFinalResult()) {
-      reduceWithIntermediateResult(dataSchema, dataTableMap.values(), 
brokerResponseNative);
+    Collection<DataTable> dataTables = dataTableMap.values();
+    if (_queryContext.isServerReturnFinalResult()) {
+      if (dataTables.size() == 1) {
+        processSingleFinalResult(dataSchema, dataTables.iterator().next(), 
brokerResponseNative);
+      } else {
+        reduceWithFinalResult(dataSchema, dataTables, brokerResponseNative);
+      }
     } else {
-      Preconditions.checkState(dataTableMap.size() == 1, "Cannot merge final 
results from multiple servers");
-      reduceWithFinalResult(dataSchema, 
dataTableMap.values().iterator().next(), brokerResponseNative);
+      reduceWithIntermediateResult(dataSchema, dataTables, 
brokerResponseNative);
     }
   }
 
@@ -82,6 +85,7 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
     int numAggregationFunctions = _aggregationFunctions.length;
     Object[] intermediateResults = new Object[numAggregationFunctions];
     for (DataTable dataTable : dataTables) {
+      Tracing.ThreadAccountantOps.sampleAndCheckInterruption();
       for (int i = 0; i < numAggregationFunctions; i++) {
         Object intermediateResultToMerge;
         ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
@@ -101,19 +105,18 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
         } else {
           intermediateResults[i] = 
_aggregationFunctions[i].merge(mergedIntermediateResult, 
intermediateResultToMerge);
         }
-        Tracing.ThreadAccountantOps.sampleAndCheckInterruptionPeriodically(i);
       }
     }
     Object[] finalResults = new Object[numAggregationFunctions];
     for (int i = 0; i < numAggregationFunctions; i++) {
       AggregationFunction aggregationFunction = _aggregationFunctions[i];
       Comparable result = 
aggregationFunction.extractFinalResult(intermediateResults[i]);
-      finalResults[i] = result == null ? null : 
aggregationFunction.getFinalResultColumnType().convert(result);
+      finalResults[i] = result != null ? 
aggregationFunction.getFinalResultColumnType().convert(result) : null;
     }
     
brokerResponseNative.setResultTable(reduceToResultTable(getPrePostAggregationDataSchema(dataSchema),
 finalResults));
   }
 
-  private void reduceWithFinalResult(DataSchema dataSchema, DataTable 
dataTable,
+  private void processSingleFinalResult(DataSchema dataSchema, DataTable 
dataTable,
       BrokerResponseNative brokerResponseNative) {
     int numAggregationFunctions = _aggregationFunctions.length;
     Object[] finalResults = new Object[numAggregationFunctions];
@@ -133,6 +136,43 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
     brokerResponseNative.setResultTable(reduceToResultTable(dataSchema, 
finalResults));
   }
 
+  private void reduceWithFinalResult(DataSchema dataSchema, 
Collection<DataTable> dataTables,
+      BrokerResponseNative brokerResponseNative) {
+    int numAggregationFunctions = _aggregationFunctions.length;
+    Comparable[] finalResults = new Comparable[numAggregationFunctions];
+    for (DataTable dataTable : dataTables) {
+      for (int i = 0; i < numAggregationFunctions; i++) {
+        Tracing.ThreadAccountantOps.sampleAndCheckInterruption();
+        Comparable finalResultToMerge;
+        ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
+        if (_queryContext.isNullHandlingEnabled()) {
+          RoaringBitmap nullBitmap = dataTable.getNullRowIds(i);
+          if (nullBitmap != null && nullBitmap.contains(0)) {
+            finalResultToMerge = null;
+          } else {
+            finalResultToMerge = 
AggregationFunctionUtils.getFinalResult(dataTable, columnDataType, 0, i);
+          }
+        } else {
+          finalResultToMerge = 
AggregationFunctionUtils.getFinalResult(dataTable, columnDataType, 0, i);
+        }
+        Comparable mergedFinalResult = finalResults[i];
+        if (mergedFinalResult == null) {
+          finalResults[i] = finalResultToMerge;
+        } else {
+          finalResults[i] = 
_aggregationFunctions[i].mergeFinalResult(mergedFinalResult, 
finalResultToMerge);
+        }
+      }
+    }
+    Object[] convertedFinalResults = new Object[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      AggregationFunction aggregationFunction = _aggregationFunctions[i];
+      Comparable result = finalResults[i];
+      convertedFinalResults[i] = result != null ? 
aggregationFunction.getFinalResultColumnType().convert(result) : null;
+    }
+    brokerResponseNative.setResultTable(
+        reduceToResultTable(getPrePostAggregationDataSchema(dataSchema), 
convertedFinalResults));
+  }
+
   /**
    * Sets aggregation results into ResultsTable
    */
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index c0a109f7e4..46d46d7391 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -18,7 +18,11 @@
  */
 package org.apache.pinot.core.query.reduce;
 
-import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import it.unimi.dsi.fastutil.objects.ObjectArrayList;
 import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -30,8 +34,10 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
 import org.apache.commons.lang.StringUtils;
 import org.apache.pinot.common.CustomObject;
+import org.apache.pinot.common.Utils;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.exception.QueryException;
 import org.apache.pinot.common.metrics.BrokerGauge;
@@ -69,14 +75,13 @@ import org.roaringbitmap.RoaringBitmap;
 /**
  * Helper class to reduce data tables and set group by results into the 
BrokerResponseNative
  */
-@SuppressWarnings({"rawtypes", "unchecked"})
+@SuppressWarnings("rawtypes")
 public class GroupByDataTableReducer implements DataTableReducer {
   private static final int MIN_DATA_TABLES_FOR_CONCURRENT_REDUCE = 2; // TBD, 
find a better value.
 
   private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
   private final int _numAggregationFunctions;
-  private final List<ExpressionContext> _groupByExpressions;
   private final int _numGroupByExpressions;
   private final int _numColumns;
 
@@ -85,9 +90,9 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     _aggregationFunctions = queryContext.getAggregationFunctions();
     assert _aggregationFunctions != null;
     _numAggregationFunctions = _aggregationFunctions.length;
-    _groupByExpressions = queryContext.getGroupByExpressions();
-    assert _groupByExpressions != null;
-    _numGroupByExpressions = _groupByExpressions.size();
+    List<ExpressionContext> groupByExpressions = 
queryContext.getGroupByExpressions();
+    assert groupByExpressions != null;
+    _numGroupByExpressions = groupByExpressions.size();
     _numColumns = _numAggregationFunctions + _numGroupByExpressions;
   }
 
@@ -109,18 +114,18 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
       return;
     }
 
-    if (!_queryContext.isServerReturnFinalResult()) {
+    Collection<DataTable> dataTables = dataTableMap.values();
+    // NOTE: Use regular reduce when group keys are not partitioned even if 
there are only one data table because the
+    //       records are not sorted yet.
+    if (_queryContext.isServerReturnFinalResult() && dataTables.size() == 1) {
+      processSingleFinalResult(dataSchema, dataTables.iterator().next(), 
brokerResponse);
+    } else {
       try {
-        reduceWithIntermediateResult(brokerResponse, dataSchema, 
dataTableMap.values(), reducerContext, tableName,
-            brokerMetrics);
+        reduceResult(brokerResponse, dataSchema, dataTables, reducerContext, 
tableName, brokerMetrics);
       } catch (TimeoutException e) {
         brokerResponse.getExceptions()
             .add(new 
QueryProcessingException(QueryException.BROKER_TIMEOUT_ERROR_CODE, 
e.getMessage()));
       }
-    } else {
-      // TODO: Support merging results from multiple servers when the data is 
partitioned on the group-by column
-      Preconditions.checkState(dataTableMap.size() == 1, "Cannot merge final 
results from multiple servers");
-      reduceWithFinalResult(dataSchema, 
dataTableMap.values().iterator().next(), brokerResponse);
     }
 
     if (brokerMetrics != null && brokerResponse.getResultTable() != null) {
@@ -139,10 +144,11 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
    * @param brokerMetrics broker metrics (meters)
    * @throws TimeoutException If unable complete within timeout.
    */
-  private void reduceWithIntermediateResult(BrokerResponseNative 
brokerResponseNative, DataSchema dataSchema,
+  private void reduceResult(BrokerResponseNative brokerResponseNative, 
DataSchema dataSchema,
       Collection<DataTable> dataTables, DataTableReducerContext 
reducerContext, String rawTableName,
       BrokerMetrics brokerMetrics)
       throws TimeoutException {
+    // NOTE: This step will modify the data schema and also return final 
aggregate results.
     IndexedTable indexedTable = getIndexedTable(dataSchema, dataTables, 
reducerContext);
     if (brokerMetrics != null) {
       brokerMetrics.addMeteredTableValue(rawTableName, 
BrokerMeter.NUM_RESIZES, indexedTable.getNumResizes());
@@ -151,9 +157,7 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     int numRecords = indexedTable.size();
     Iterator<Record> sortedIterator = indexedTable.iterator();
 
-    DataSchema prePostAggregationDataSchema = 
getPrePostAggregationDataSchema(dataSchema);
-    PostAggregationHandler postAggregationHandler =
-        new PostAggregationHandler(_queryContext, 
prePostAggregationDataSchema);
+    PostAggregationHandler postAggregationHandler = new 
PostAggregationHandler(_queryContext, dataSchema);
     DataSchema resultDataSchema = postAggregationHandler.getResultDataSchema();
 
     // Directly return when there is no record returned, or limit is 0
@@ -165,7 +169,7 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
 
     // Calculate rows before post-aggregation
     List<Object[]> rows;
-    ColumnDataType[] columnDataTypes = 
prePostAggregationDataSchema.getColumnDataTypes();
+    ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
     int numColumns = columnDataTypes.length;
     FilterContext havingFilter = _queryContext.getHavingFilter();
     if (havingFilter != null) {
@@ -175,7 +179,6 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
       int processedRows = 0;
       while (rows.size() < limit && sortedIterator.hasNext()) {
         Object[] row = sortedIterator.next().getValues();
-        extractFinalAggregationResults(row);
         for (int i = 0; i < numColumns; i++) {
           Object value = row[i];
           if (value != null) {
@@ -193,7 +196,6 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
       rows = new ArrayList<>(numRows);
       for (int i = 0; i < numRows; i++) {
         Object[] row = sortedIterator.next().getValues();
-        extractFinalAggregationResults(row);
         for (int j = 0; j < numColumns; j++) {
           Object value = row[j];
           if (value != null) {
@@ -208,22 +210,9 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     // Calculate final result rows after post aggregation
     List<Object[]> resultRows = 
calculateFinalResultRows(postAggregationHandler, rows);
 
-    RewriterResult resultRewriterResult =
-        ResultRewriteUtils.rewriteResult(resultDataSchema, resultRows);
-    resultRows = resultRewriterResult.getRows();
-    resultDataSchema = resultRewriterResult.getDataSchema();
-
-    brokerResponseNative.setResultTable(new ResultTable(resultDataSchema, 
resultRows));
-  }
-
-  /**
-   * Helper method to extract the final aggregation results for the given row 
(in-place).
-   */
-  private void extractFinalAggregationResults(Object[] row) {
-    for (int i = 0; i < _numAggregationFunctions; i++) {
-      int valueIndex = i + _numGroupByExpressions;
-      row[valueIndex] = 
_aggregationFunctions[i].extractFinalResult(row[valueIndex]);
-    }
+    // Rewrite and set result table
+    RewriterResult rewriterResult = 
ResultRewriteUtils.rewriteResult(resultDataSchema, resultRows);
+    brokerResponseNative.setResultTable(new 
ResultTable(rewriterResult.getDataSchema(), rewriterResult.getRows()));
   }
 
   /**
@@ -248,6 +237,8 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     // Get the number of threads to use for reducing.
     // In case of single reduce thread, fall back to SimpleIndexedTable to 
avoid redundant locking/unlocking calls.
     int numReduceThreadsToUse = getNumReduceThreadsToUse(numDataTables, 
reducerContext.getMaxReduceThreadsPerQuery());
+    boolean hasFinalInput =
+        _queryContext.isServerReturnFinalResult() || 
_queryContext.isServerReturnFinalResultKeyUnpartitioned();
     int limit = _queryContext.getLimit();
     int trimSize = GroupByUtils.getTableCapacity(limit, 
reducerContext.getMinGroupTrimSize());
     // NOTE: For query with HAVING clause, use trimSize as resultSize to 
ensure the result accuracy.
@@ -256,16 +247,18 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     int trimThreshold = reducerContext.getGroupByTrimThreshold();
     IndexedTable indexedTable;
     if (numReduceThreadsToUse == 1) {
-      indexedTable = new SimpleIndexedTable(dataSchema, _queryContext, 
resultSize, trimSize, trimThreshold);
+      indexedTable =
+          new SimpleIndexedTable(dataSchema, hasFinalInput, _queryContext, 
resultSize, trimSize, trimThreshold);
     } else {
       if (trimThreshold >= GroupByCombineOperator.MAX_TRIM_THRESHOLD) {
         // special case of trim threshold where it is set to max value.
         // there won't be any trimming during upsert in this case.
         // thus we can avoid the overhead of read-lock and write-lock
         // in the upsert method.
-        indexedTable = new UnboundedConcurrentIndexedTable(dataSchema, 
_queryContext, resultSize);
+        indexedTable = new UnboundedConcurrentIndexedTable(dataSchema, 
hasFinalInput, _queryContext, resultSize);
       } else {
-        indexedTable = new ConcurrentIndexedTable(dataSchema, _queryContext, 
resultSize, trimSize, trimThreshold);
+        indexedTable =
+            new ConcurrentIndexedTable(dataSchema, hasFinalInput, 
_queryContext, resultSize, trimSize, trimThreshold);
       }
     }
 
@@ -282,7 +275,8 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     }
 
     Future[] futures = new Future[numReduceThreadsToUse];
-    CountDownLatch countDownLatch = new CountDownLatch(numDataTables);
+    CountDownLatch countDownLatch = new CountDownLatch(numReduceThreadsToUse);
+    AtomicReference<Throwable> exception = new AtomicReference<>();
     ColumnDataType[] storedColumnDataTypes = 
dataSchema.getStoredColumnDataTypes();
     for (int i = 0; i < numReduceThreadsToUse; i++) {
       List<DataTable> reduceGroup = reduceGroups.get(i);
@@ -294,72 +288,87 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
           Tracing.ThreadAccountantOps.setupWorker(taskId, new 
ThreadResourceUsageProvider(), parentContext);
           try {
             for (DataTable dataTable : reduceGroup) {
-              try {
-                boolean nullHandlingEnabled = 
_queryContext.isNullHandlingEnabled();
-                RoaringBitmap[] nullBitmaps = null;
-                if (nullHandlingEnabled) {
-                  nullBitmaps = new RoaringBitmap[_numColumns];
-                  for (int i = 0; i < _numColumns; i++) {
-                    nullBitmaps[i] = dataTable.getNullRowIds(i);
-                  }
+              boolean nullHandlingEnabled = 
_queryContext.isNullHandlingEnabled();
+              RoaringBitmap[] nullBitmaps = null;
+              if (nullHandlingEnabled) {
+                nullBitmaps = new RoaringBitmap[_numColumns];
+                for (int i = 0; i < _numColumns; i++) {
+                  nullBitmaps[i] = dataTable.getNullRowIds(i);
                 }
+              }
 
-                int numRows = dataTable.getNumberOfRows();
-                for (int rowId = 0; rowId < numRows; rowId++) {
-                  // Terminate when thread is interrupted.
-                  // This is expected when the query already fails in the main 
thread.
-                  // The first check will always be performed when rowId = 0
-                  
Tracing.ThreadAccountantOps.sampleAndCheckInterruptionPeriodically(rowId);
-                  Object[] values = new Object[_numColumns];
-                  for (int colId = 0; colId < _numColumns; colId++) {
-                    switch (storedColumnDataTypes[colId]) {
-                      case INT:
-                        values[colId] = dataTable.getInt(rowId, colId);
-                        break;
-                      case LONG:
-                        values[colId] = dataTable.getLong(rowId, colId);
-                        break;
-                      case FLOAT:
-                        values[colId] = dataTable.getFloat(rowId, colId);
-                        break;
-                      case DOUBLE:
-                        values[colId] = dataTable.getDouble(rowId, colId);
-                        break;
-                      case BIG_DECIMAL:
-                        values[colId] = dataTable.getBigDecimal(rowId, colId);
-                        break;
-                      case STRING:
-                        values[colId] = dataTable.getString(rowId, colId);
-                        break;
-                      case BYTES:
-                        values[colId] = dataTable.getBytes(rowId, colId);
-                        break;
-                      case OBJECT:
-                        // TODO: Move ser/de into AggregationFunction interface
-                        CustomObject customObject = 
dataTable.getCustomObject(rowId, colId);
-                        if (customObject != null) {
-                          values[colId] = 
ObjectSerDeUtils.deserialize(customObject);
-                        }
-                        break;
-                      // Add other aggregation intermediate result / group-by 
column type supports here
-                      default:
-                        throw new IllegalStateException();
-                    }
-                  }
-                  if (nullHandlingEnabled) {
-                    for (int colId = 0; colId < _numColumns; colId++) {
-                      if (nullBitmaps[colId] != null && 
nullBitmaps[colId].contains(rowId)) {
-                        values[colId] = null;
+              int numRows = dataTable.getNumberOfRows();
+              for (int rowId = 0; rowId < numRows; rowId++) {
+                // Terminate when thread is interrupted.
+                // This is expected when the query already fails in the main 
thread.
+                // The first check will always be performed when rowId = 0
+                
Tracing.ThreadAccountantOps.sampleAndCheckInterruptionPeriodically(rowId);
+                Object[] values = new Object[_numColumns];
+                for (int colId = 0; colId < _numColumns; colId++) {
+                  // NOTE: We need to handle data types for group key, 
intermediate and final aggregate result.
+                  switch (storedColumnDataTypes[colId]) {
+                    case INT:
+                      values[colId] = dataTable.getInt(rowId, colId);
+                      break;
+                    case LONG:
+                      values[colId] = dataTable.getLong(rowId, colId);
+                      break;
+                    case FLOAT:
+                      values[colId] = dataTable.getFloat(rowId, colId);
+                      break;
+                    case DOUBLE:
+                      values[colId] = dataTable.getDouble(rowId, colId);
+                      break;
+                    case BIG_DECIMAL:
+                      values[colId] = dataTable.getBigDecimal(rowId, colId);
+                      break;
+                    case STRING:
+                      values[colId] = dataTable.getString(rowId, colId);
+                      break;
+                    case BYTES:
+                      values[colId] = dataTable.getBytes(rowId, colId);
+                      break;
+                    case INT_ARRAY:
+                      values[colId] = 
IntArrayList.wrap(dataTable.getIntArray(rowId, colId));
+                      break;
+                    case LONG_ARRAY:
+                      values[colId] = 
LongArrayList.wrap(dataTable.getLongArray(rowId, colId));
+                      break;
+                    case FLOAT_ARRAY:
+                      values[colId] = 
FloatArrayList.wrap(dataTable.getFloatArray(rowId, colId));
+                      break;
+                    case DOUBLE_ARRAY:
+                      values[colId] = 
DoubleArrayList.wrap(dataTable.getDoubleArray(rowId, colId));
+                      break;
+                    case STRING_ARRAY:
+                      values[colId] = 
ObjectArrayList.wrap(dataTable.getStringArray(rowId, colId));
+                      break;
+                    case OBJECT:
+                      // TODO: Move ser/de into AggregationFunction interface
+                      CustomObject customObject = 
dataTable.getCustomObject(rowId, colId);
+                      if (customObject != null) {
+                        values[colId] = 
ObjectSerDeUtils.deserialize(customObject);
                       }
+                      break;
+                    // Add other aggregation intermediate result / group-by 
column type supports here
+                    default:
+                      throw new IllegalStateException();
+                  }
+                }
+                if (nullHandlingEnabled) {
+                  for (int colId = 0; colId < _numColumns; colId++) {
+                    if (nullBitmaps[colId] != null && 
nullBitmaps[colId].contains(rowId)) {
+                      values[colId] = null;
                     }
                   }
-                  indexedTable.upsert(new Record(values));
                 }
-              } finally {
-                countDownLatch.countDown();
+                indexedTable.upsert(new Record(values));
               }
             }
+          } catch (Throwable t) {
+            exception.compareAndSet(null, t);
           } finally {
+            countDownLatch.countDown();
             Tracing.ThreadAccountantOps.clear();
           }
         }
@@ -371,10 +380,15 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
       if (!countDownLatch.await(timeOutMs, TimeUnit.MILLISECONDS)) {
         throw new TimeoutException("Timed out in broker reduce phase");
       }
+      Throwable t = exception.get();
+      if (t != null) {
+        Utils.rethrowException(t);
+      }
     } catch (InterruptedException e) {
       Exception killedErrorMsg = 
Tracing.getThreadAccountant().getErrorStatus();
-      throw new EarlyTerminationException("Interrupted in broker reduce phase"
-          + (killedErrorMsg == null ? StringUtils.EMPTY : " " + 
killedErrorMsg), e);
+      throw new EarlyTerminationException(
+          "Interrupted in broker reduce phase" + (killedErrorMsg == null ? 
StringUtils.EMPTY : " " + killedErrorMsg),
+          e);
     } finally {
       for (Future future : futures) {
         if (!future.isDone()) {
@@ -383,7 +397,7 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
       }
     }
 
-    indexedTable.finish(true);
+    indexedTable.finish(true, true);
     return indexedTable;
   }
 
@@ -408,7 +422,7 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     }
   }
 
-  private void reduceWithFinalResult(DataSchema dataSchema, DataTable 
dataTable,
+  private void processSingleFinalResult(DataSchema dataSchema, DataTable 
dataTable,
       BrokerResponseNative brokerResponseNative) {
     PostAggregationHandler postAggregationHandler = new 
PostAggregationHandler(_queryContext, dataSchema);
     DataSchema resultDataSchema = postAggregationHandler.getResultDataSchema();
@@ -448,12 +462,9 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
     // Calculate final result rows after post aggregation
     List<Object[]> resultRows = 
calculateFinalResultRows(postAggregationHandler, rows);
 
-    RewriterResult resultRewriterResult =
-        ResultRewriteUtils.rewriteResult(resultDataSchema, resultRows);
-    resultRows = resultRewriterResult.getRows();
-    resultDataSchema = resultRewriterResult.getDataSchema();
-
-    brokerResponseNative.setResultTable(new ResultTable(resultDataSchema, 
resultRows));
+    // Rewrite and set result table
+    RewriterResult rewriterResult = 
ResultRewriteUtils.rewriteResult(resultDataSchema, resultRows);
+    brokerResponseNative.setResultTable(new 
ResultTable(rewriterResult.getDataSchema(), rewriterResult.getRows()));
   }
 
   private List<Object[]> calculateFinalResultRows(PostAggregationHandler 
postAggregationHandler, List<Object[]> rows) {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
index cd0ea14790..6c4a3d75c3 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
@@ -125,6 +125,8 @@ public class QueryContext {
   private boolean _nullHandlingEnabled;
   // Whether server returns the final result
   private boolean _serverReturnFinalResult;
+  // Whether server returns the final result with unpartitioned group key
+  private boolean _serverReturnFinalResultKeyUnpartitioned;
   // Collection of index types to skip per column
   private Map<String, Set<FieldConfig.IndexType>> _skipIndexes;
 
@@ -406,6 +408,14 @@ public class QueryContext {
     _serverReturnFinalResult = serverReturnFinalResult;
   }
 
+  public boolean isServerReturnFinalResultKeyUnpartitioned() {
+    return _serverReturnFinalResultKeyUnpartitioned;
+  }
+
+  public void setServerReturnFinalResultKeyUnpartitioned(boolean 
serverReturnFinalResultKeyUnpartitioned) {
+    _serverReturnFinalResultKeyUnpartitioned = 
serverReturnFinalResultKeyUnpartitioned;
+  }
+
   /**
    * Gets or computes a value of type {@code V} associated with a key of type 
{@code K} so that it can be shared
    * within the scope of a query.
@@ -545,6 +555,8 @@ public class QueryContext {
               _expressionOverrideHints, _explain);
       
queryContext.setNullHandlingEnabled(QueryOptionsUtils.isNullHandlingEnabled(_queryOptions));
       
queryContext.setServerReturnFinalResult(QueryOptionsUtils.isServerReturnFinalResult(_queryOptions));
+      queryContext.setServerReturnFinalResultKeyUnpartitioned(
+          
QueryOptionsUtils.isServerReturnFinalResultKeyUnpartitioned(_queryOptions));
 
       // Pre-calculate the aggregation functions and columns for the query
       generateAggregationFunctions(queryContext);
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiNodesOfflineClusterIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiNodesOfflineClusterIntegrationTest.java
index 3cf9cd8485..200c022523 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiNodesOfflineClusterIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiNodesOfflineClusterIntegrationTest.java
@@ -191,6 +191,101 @@ public class MultiNodesOfflineClusterIntegrationTest 
extends OfflineClusterInteg
     }
   }
 
+  @Test
+  public void testServerReturnFinalResult()
+      throws Exception {
+    // Data is segment partitioned on DaysSinceEpoch.
+    JsonNode result = postQuery("SELECT DISTINCT_COUNT(DaysSinceEpoch) FROM 
mytable");
+    
assertEquals(result.get("resultTable").get("rows").get(0).get(0).intValue(), 
364);
+    result = postQuery("SELECT 
SEGMENT_PARTITIONED_DISTINCT_COUNT(DaysSinceEpoch) FROM mytable");
+    
assertEquals(result.get("resultTable").get("rows").get(0).get(0).intValue(), 
364);
+    result = postQuery("SET serverReturnFinalResult = true; SELECT 
DISTINCT_COUNT(DaysSinceEpoch) FROM mytable");
+    
assertEquals(result.get("resultTable").get("rows").get(0).get(0).intValue(), 
364);
+
+    // Data is not partitioned on DayOfWeek. Each segment contains all 7 
unique values.
+    result = postQuery("SELECT DISTINCT_COUNT(DayOfWeek) FROM mytable");
+    
assertEquals(result.get("resultTable").get("rows").get(0).get(0).intValue(), 7);
+    result = postQuery("SELECT SEGMENT_PARTITIONED_DISTINCT_COUNT(DayOfWeek) 
FROM mytable");
+    
assertEquals(result.get("resultTable").get("rows").get(0).get(0).intValue(), 
84);
+    result = postQuery("SET serverReturnFinalResult = true; SELECT 
DISTINCT_COUNT(DayOfWeek) FROM mytable");
+    
assertEquals(result.get("resultTable").get("rows").get(0).get(0).intValue(), 
21);
+
+    // Data is segment partitioned on DaysSinceEpoch.
+    result =
+        postQuery("SELECT DaysSinceEpoch, DISTINCT_COUNT(CRSArrTime) FROM 
mytable GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    JsonNode row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16138);
+    assertEquals(row.get(1).intValue(), 398);
+    result = postQuery("SELECT DaysSinceEpoch, 
SEGMENT_PARTITIONED_DISTINCT_COUNT(CRSArrTime) FROM mytable GROUP BY 1 "
+        + "ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16138);
+    assertEquals(row.get(1).intValue(), 398);
+    result = postQuery("SET serverReturnFinalResult = true; "
+        + "SELECT DaysSinceEpoch, DISTINCT_COUNT(CRSArrTime) FROM mytable 
GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16138);
+    assertEquals(row.get(1).intValue(), 398);
+    result = postQuery("SET serverReturnFinalResultKeyUnpartitioned = true; "
+        + "SELECT DaysSinceEpoch, DISTINCT_COUNT(CRSArrTime) FROM mytable 
GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16138);
+    assertEquals(row.get(1).intValue(), 398);
+
+    // Data is segment partitioned on DaysSinceEpoch.
+    result =
+        postQuery("SELECT CRSArrTime, DISTINCT_COUNT(DaysSinceEpoch) FROM 
mytable GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 2100);
+    assertEquals(row.get(1).intValue(), 253);
+    result = postQuery("SELECT CRSArrTime, 
SEGMENT_PARTITIONED_DISTINCT_COUNT(DaysSinceEpoch) FROM mytable GROUP BY 1 "
+        + "ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 2100);
+    assertEquals(row.get(1).intValue(), 253);
+    result = postQuery("SET serverReturnFinalResultKeyUnpartitioned = true; "
+        + "SELECT CRSArrTime, DISTINCT_COUNT(DaysSinceEpoch) FROM mytable 
GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 2100);
+    assertEquals(row.get(1).intValue(), 253);
+    // Data is not partitioned on CRSArrTime. Using serverReturnFinalResult 
will give wrong result.
+    result = postQuery("SET serverReturnFinalResult = true; "
+        + "SELECT CRSArrTime, DISTINCT_COUNT(DaysSinceEpoch) FROM mytable 
GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertTrue(row.get(1).intValue() < 253);
+
+    // Should fail when merging final results that cannot be merged.
+    try {
+      postQuery("SET serverReturnFinalResult = true; SELECT 
AVG(DaysSinceEpoch) FROM mytable");
+      fail();
+    } catch (Exception e) {
+      assertTrue(e.getMessage().contains("Cannot merge final results for 
function: AVG"));
+    }
+    try {
+      postQuery("SET serverReturnFinalResultKeyUnpartitioned = true; "
+          + "SELECT CRSArrTime, AVG(DaysSinceEpoch) FROM mytable GROUP BY 1 
ORDER BY 2 DESC LIMIT 1");
+      fail();
+    } catch (Exception e) {
+      assertTrue(e.getMessage().contains("Cannot merge final results for 
function: AVG"));
+    }
+
+    // Should not fail when group keys are partitioned because there is no 
need to merge final results.
+    result = postQuery("SELECT DaysSinceEpoch, AVG(CRSArrTime) FROM mytable 
GROUP BY 1 ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16257);
+    assertEquals(row.get(1).doubleValue(), 725560.0 / 444);
+    result = postQuery("SET serverReturnFinalResult = true; "
+        + "SELECT DaysSinceEpoch, AVG(CRSArrTime) FROM mytable GROUP BY 1 
ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16257);
+    assertEquals(row.get(1).doubleValue(), 725560.0 / 444);
+    result = postQuery("SET serverReturnFinalResultKeyUnpartitioned = true; "
+        + "SELECT DaysSinceEpoch, AVG(CRSArrTime) FROM mytable GROUP BY 1 
ORDER BY 2 DESC LIMIT 1");
+    row = result.get("resultTable").get("rows").get(0);
+    assertEquals(row.get(0).intValue(), 16257);
+    assertEquals(row.get(1).doubleValue(), 725560.0 / 444);
+  }
+
   // Disabled because with multiple replicas, there is no guarantee that all 
replicas are reloaded
   @Test(enabled = false)
   public void testStarTreeTriggering(boolean useMultiStageQueryEngine) {
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java
index 5ea826116b..6408fd8f31 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java
@@ -59,8 +59,8 @@ import static org.testng.Assert.*;
 
 public class OfflineGRPCServerIntegrationTest extends 
BaseClusterIntegrationTest {
   private static final ExecutorService EXECUTOR_SERVICE = 
Executors.newFixedThreadPool(2);
-  private static final DataTableReducerContext DATATABLE_REDUCER_CONTEXT = new 
DataTableReducerContext(
-      EXECUTOR_SERVICE, 2, 10000, 10000, 5000);
+  private static final DataTableReducerContext DATATABLE_REDUCER_CONTEXT =
+      new DataTableReducerContext(EXECUTOR_SERVICE, 2, 10000, 10000, 5000);
 
   @BeforeClass
   public void setUp()
@@ -106,7 +106,7 @@ public class OfflineGRPCServerIntegrationTest extends 
BaseClusterIntegrationTest
     GrpcQueryClient queryClient = getGrpcQueryClient();
     String sql = "SELECT * FROM mytable_OFFLINE LIMIT 1000000 
OPTION(timeoutMs=30000)";
     BrokerRequest brokerRequest = 
CalciteSqlCompiler.compileToBrokerRequest(sql);
-    List<String> segments = 
_helixResourceManager.getSegmentsFor("mytable_OFFLINE", false);
+    List<String> segments = 
_helixResourceManager.getSegmentsFor("mytable_OFFLINE", true);
 
     GrpcRequestBuilder requestBuilder = new 
GrpcRequestBuilder().setSegments(segments);
     
testNonStreamingRequest(queryClient.submit(requestBuilder.setSql(sql).build()));
@@ -121,15 +121,12 @@ public class OfflineGRPCServerIntegrationTest extends 
BaseClusterIntegrationTest
   @Test(dataProvider = "provideSqlTestCases")
   public void testQueryingGrpcServer(String sql)
       throws Exception {
-    GrpcQueryClient queryClient = getGrpcQueryClient();
-    List<String> segments = 
_helixResourceManager.getSegmentsFor("mytable_OFFLINE", false);
-
-    GrpcRequestBuilder requestBuilder = new 
GrpcRequestBuilder().setSegments(segments);
-    DataTable dataTable = 
collectNonStreamingRequestResult(queryClient.submit(requestBuilder.setSql(sql).build()));
-
-    requestBuilder.setEnableStreaming(true);
-    collectAndCompareResult(sql, 
queryClient.submit(requestBuilder.setSql(sql).build()), dataTable);
-    queryClient.close();
+    try (GrpcQueryClient queryClient = getGrpcQueryClient()) {
+      List<String> segments = 
_helixResourceManager.getSegmentsFor("mytable_OFFLINE", true);
+      GrpcRequestBuilder requestBuilder = new 
GrpcRequestBuilder().setSql(sql).setSegments(segments);
+      DataTable dataTable = 
collectNonStreamingRequestResult(queryClient.submit(requestBuilder.build()));
+      collectAndCompareResult(sql, 
queryClient.submit(requestBuilder.setEnableStreaming(true).build()), dataTable);
+    }
   }
 
   @DataProvider(name = "provideSqlTestCases")
@@ -157,12 +154,15 @@ public class OfflineGRPCServerIntegrationTest extends 
BaseClusterIntegrationTest
 
     // distinct
     entries.add(new Object[]{"SELECT DISTINCT(AirlineID) FROM mytable_OFFLINE 
LIMIT 1000000"});
-    entries.add(new Object[]{"SELECT AirlineID, ArrTime FROM mytable_OFFLINE "
-        + "GROUP BY AirlineID, ArrTime LIMIT 1000000"});
+    entries.add(new Object[]{
+        "SELECT AirlineID, ArrTime FROM mytable_OFFLINE GROUP BY AirlineID, 
ArrTime LIMIT 1000000"
+    });
 
     // order by
-    entries.add(new Object[]{"SELECT DaysSinceEpoch, 
timeConvert(DaysSinceEpoch,'DAYS','SECONDS') "
-        + "FROM mytable_OFFLINE ORDER BY DaysSinceEpoch limit 1000000"});
+    entries.add(new Object[]{
+        "SELECT DaysSinceEpoch, timeConvert(DaysSinceEpoch,'DAYS','SECONDS') 
FROM mytable_OFFLINE "
+            + "ORDER BY DaysSinceEpoch limit 1000000"
+    });
 
     return entries.toArray(new Object[entries.size()][]);
   }
@@ -205,10 +205,9 @@ public class OfflineGRPCServerIntegrationTest extends 
BaseClusterIntegrationTest
         BrokerResponseNative streamingBrokerResponse = new 
BrokerResponseNative();
         reducer.reduceAndSetResults("mytable_OFFLINE", cachedDataSchema, 
dataTableMap, streamingBrokerResponse,
             DATATABLE_REDUCER_CONTEXT, mock(BrokerMetrics.class));
-        dataTableMap.clear();
-        dataTableMap.put(mock(ServerRoutingInstance.class), 
nonStreamResultDataTable);
         BrokerResponseNative nonStreamBrokerResponse = new 
BrokerResponseNative();
-        reducer.reduceAndSetResults("mytable_OFFLINE", cachedDataSchema, 
dataTableMap, nonStreamBrokerResponse,
+        reducer.reduceAndSetResults("mytable_OFFLINE", 
nonStreamResultDataTable.getDataSchema(),
+            Map.of(mock(ServerRoutingInstance.class), 
nonStreamResultDataTable), nonStreamBrokerResponse,
             DATATABLE_REDUCER_CONTEXT, mock(BrokerMetrics.class));
         assertEquals(streamingBrokerResponse.getResultTable().getRows().size(),
             nonStreamBrokerResponse.getResultTable().getRows().size());
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java 
b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index c61101d336..7b4263f46e 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -366,7 +366,17 @@ public class CommonConstants {
         public static final String EXPLAIN_PLAN_VERBOSE = "explainPlanVerbose";
         public static final String USE_MULTISTAGE_ENGINE = 
"useMultistageEngine";
         public static final String ENABLE_NULL_HANDLING = "enableNullHandling";
+
+        // Can be applied to aggregation and group-by queries to ask servers 
to directly return final results instead of
+        // intermediate results for aggregations.
         public static final String SERVER_RETURN_FINAL_RESULT = 
"serverReturnFinalResult";
+        // Can be applied to group-by queries to ask servers to directly 
return final results instead of intermediate
+        // results for aggregations. Different from 
SERVER_RETURN_FINAL_RESULT, this option should be used when the
+        // group key is not server partitioned, but the aggregated values are 
server partitioned. When this option is
+        // used, server will return final results, but won't directly trim the 
result to the query limit.
+        public static final String 
SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED =
+            "serverReturnFinalResultKeyUnpartitioned";
+
         // Reorder scan based predicates based on cardinality and number of 
selected values
         public static final String AND_SCAN_REORDERING = "AndScanReordering";
         public static final String SKIP_INDEXES = "skipIndexes";


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to