This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new df6cef135b Make ser/de part of the AggregationFunction (#15158)
df6cef135b is described below

commit df6cef135b3d0ada64ca2616232dfbc2fc913202
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Mon Mar 3 19:06:01 2025 -0700

    Make ser/de part of the AggregationFunction (#15158)
---
 .../apache/pinot/common/datablock/DataBlock.java   |   1 +
 .../apache/pinot/core/common/ObjectSerDeUtils.java |  10 +-
 .../core/common/datablock/DataBlockBuilder.java    | 151 ++++++++++++++-------
 .../common/datatable/BaseDataTableBuilder.java     |  40 +++---
 .../core/common/datatable/DataTableBuilder.java    |  13 +-
 .../blocks/results/AggregationResultsBlock.java    |  62 ++++++---
 .../blocks/results/GroupByResultsBlock.java        |  41 ++++--
 .../aggregation/function/AggregationFunction.java  |  66 ++++++++-
 .../function/AggregationFunctionUtils.java         |   9 +-
 .../function/AvgAggregationFunction.java           |  12 ++
 .../query/reduce/AggregationDataTableReducer.java  |   9 +-
 .../core/query/reduce/GroupByDataTableReducer.java |   7 +-
 .../query/selection/SelectionOperatorUtils.java    |   2 +-
 .../pinot/core/util/DataBlockExtractUtils.java     |  43 +++++-
 .../common/datablock/DataBlockBuilderTest.java     |  98 ++++++-------
 .../core/common/datatable/DataTableSerDeTest.java  |  23 +---
 .../query/runtime/blocks/TransferableBlock.java    |  23 +++-
 .../runtime/blocks/TransferableBlockUtils.java     |   3 +-
 .../query/runtime/operator/AggregateOperator.java  |  24 ++--
 .../LeafStageTransferableBlockOperator.java        |  30 ++--
 .../operator/MultistageGroupByExecutor.java        |   4 +-
 .../runtime/operator/exchange/HashExchange.java    |   6 +-
 .../LeafStageTransferableBlockOperatorTest.java    |   4 +-
 .../pinot/segment/spi/AggregationFunctionType.java |   3 +-
 24 files changed, 439 insertions(+), 245 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlock.java 
b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlock.java
index c198edc945..745ff137bd 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlock.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlock.java
@@ -84,6 +84,7 @@ public interface DataBlock {
 
   Map<String, Object> getMap(int rowId, int colId);
 
+  @Nullable
   CustomObject getCustomObject(int rowId, int colId);
 
   @Nullable
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java 
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
index 379c697f76..7ffe313a34 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
@@ -1805,17 +1805,25 @@ public class ObjectSerDeUtils {
   };
   //@formatter:on
 
+  /**
+   * @deprecated Use each individual SER_DE class instead.
+   */
+  @Deprecated
   public static byte[] serialize(Object value, int objectTypeValue) {
     return SER_DES[objectTypeValue].serialize(value);
   }
 
+  /**
+   * @deprecated Use each individual SER_DE class instead.
+   */
+  @Deprecated
   public static <T> T deserialize(CustomObject customObject) {
     return (T) 
SER_DES[customObject.getType()].deserialize(customObject.getBuffer());
   }
 
   @VisibleForTesting
   public static byte[] serialize(Object value) {
-    return serialize(value, ObjectType.getObjectType(value)._value);
+    return SER_DES[ObjectType.getObjectType(value)._value].serialize(value);
   }
 
   @VisibleForTesting
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/common/datablock/DataBlockBuilder.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/common/datablock/DataBlockBuilder.java
index c795f5af25..3ade560b0b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/common/datablock/DataBlockBuilder.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/common/datablock/DataBlockBuilder.java
@@ -36,7 +36,7 @@ import org.apache.pinot.common.datablock.RowDataBlock;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.RoaringBitmapUtils;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.segment.spi.memory.CompoundDataBuffer;
 import org.apache.pinot.segment.spi.memory.PagedPinotOutputStream;
 import org.apache.pinot.segment.spi.memory.PinotByteBuffer;
@@ -46,19 +46,31 @@ import org.apache.pinot.spi.utils.MapUtils;
 import org.roaringbitmap.RoaringBitmap;
 
 
+@SuppressWarnings({"rawtypes", "unchecked"})
 public class DataBlockBuilder {
-
   private DataBlockBuilder() {
   }
 
   public static RowDataBlock buildFromRows(List<Object[]> rows, DataSchema 
dataSchema)
       throws IOException {
-    return buildFromRows(rows, dataSchema, 
PagedPinotOutputStream.HeapPageAllocator.createSmall());
+    return buildFromRows(rows, dataSchema, null, 
PagedPinotOutputStream.HeapPageAllocator.createSmall());
+  }
+
+  public static RowDataBlock buildFromRows(List<Object[]> rows, DataSchema 
dataSchema,
+      @Nullable AggregationFunction[] aggFunctions)
+      throws IOException {
+    return buildFromRows(rows, dataSchema, aggFunctions, 
PagedPinotOutputStream.HeapPageAllocator.createSmall());
   }
 
   public static RowDataBlock buildFromRows(List<Object[]> rows, DataSchema 
dataSchema,
       PagedPinotOutputStream.PageAllocator allocator)
       throws IOException {
+    return buildFromRows(rows, dataSchema, null, allocator);
+  }
+
+  public static RowDataBlock buildFromRows(List<Object[]> rows, DataSchema 
dataSchema,
+      @Nullable AggregationFunction[] aggFunctions, 
PagedPinotOutputStream.PageAllocator allocator)
+      throws IOException {
     int numRows = rows.size();
 
     // TODO: consolidate these null utils into data table utils.
@@ -74,8 +86,7 @@ public class DataBlockBuilder {
     int nullFixedBytes = numColumns * Integer.BYTES * 2;
     int rowSizeInBytes = calculateBytesPerRow(dataSchema);
     int fixedBytesRequired = rowSizeInBytes * numRows + nullFixedBytes;
-    ByteBuffer fixedSize = ByteBuffer.allocate(fixedBytesRequired)
-        .order(ByteOrder.BIG_ENDIAN);
+    ByteBuffer fixedSize = 
ByteBuffer.allocate(fixedBytesRequired).order(ByteOrder.BIG_ENDIAN);
 
     PagedPinotOutputStream varSize = new PagedPinotOutputStream(allocator);
     Object2IntOpenHashMap<String> dictionary = new Object2IntOpenHashMap<>();
@@ -84,15 +95,36 @@ public class DataBlockBuilder {
       Object[] row = rows.get(rowId);
       for (int colId = 0; colId < numColumns; colId++) {
         Object value = row[colId];
+        ColumnDataType storedType = storedTypes[colId];
+
+        if (storedType == ColumnDataType.OBJECT) {
+          // Custom intermediate result for aggregation function
+          assert aggFunctions != null;
+          if (value == null) {
+            setNull(fixedSize, varSize);
+          } else {
+            // NOTE: The first (numColumns - numAggFunctions) columns are key 
columns
+            int numAggFunctions = aggFunctions.length;
+            AggregationFunction aggFunction = aggFunctions[colId + 
numAggFunctions - numColumns];
+            setColumn(fixedSize, varSize, 
aggFunction.serializeIntermediateResult(value));
+          }
+          continue;
+        }
+
         if (value == null) {
-          nullBitmaps[colId].add(rowId);
-          value = nullPlaceholders[colId];
+          if (storedType == ColumnDataType.UNKNOWN) {
+            setNull(fixedSize, varSize);
+            continue;
+          } else {
+            nullBitmaps[colId].add(rowId);
+            value = nullPlaceholders[colId];
+          }
         }
 
         // NOTE:
         // We intentionally make the type casting very strict here (e.g. only 
accepting Integer for INT) to ensure the
         // rows conform to the data schema. This can help catch the unexpected 
data type issues early.
-        switch (storedTypes[colId]) {
+        switch (storedType) {
           // Single-value column
           case INT:
             fixedSize.putInt((int) value);
@@ -135,26 +167,20 @@ public class DataBlockBuilder {
           case STRING_ARRAY:
             setColumn(fixedSize, varSize, (String[]) value, dictionary);
             break;
-
-          // Special intermediate result for aggregation function
-          case OBJECT:
-            setColumn(fixedSize, varSize, value);
-            break;
-
           // Null
           case UNKNOWN:
-            setColumn(fixedSize, varSize, (Object) null);
+            setNull(fixedSize, varSize);
             break;
 
           default:
-            throw new IllegalStateException("Unsupported stored type: " + 
storedTypes[colId] + " for column: "
-                + dataSchema.getColumnName(colId));
+            throw new IllegalStateException(
+                "Unsupported stored type: " + storedType + " for column: " + 
dataSchema.getColumnName(colId));
         }
       }
     }
 
-    CompoundDataBuffer.Builder varBufferBuilder = new 
CompoundDataBuffer.Builder(ByteOrder.BIG_ENDIAN, true)
-        .addPagedOutputStream(varSize);
+    CompoundDataBuffer.Builder varBufferBuilder =
+        new CompoundDataBuffer.Builder(ByteOrder.BIG_ENDIAN, 
true).addPagedOutputStream(varSize);
 
     // Write null bitmaps after writing data.
     setNullRowIds(nullBitmaps, fixedSize, varBufferBuilder);
@@ -163,12 +189,24 @@ public class DataBlockBuilder {
 
   public static ColumnarDataBlock buildFromColumns(List<Object[]> columns, 
DataSchema dataSchema)
       throws IOException {
-    return buildFromColumns(columns, dataSchema, 
PagedPinotOutputStream.HeapPageAllocator.createSmall());
+    return buildFromColumns(columns, dataSchema, null, 
PagedPinotOutputStream.HeapPageAllocator.createSmall());
+  }
+
+  public static ColumnarDataBlock buildFromColumns(List<Object[]> columns, 
DataSchema dataSchema,
+      @Nullable AggregationFunction[] aggFunctions)
+      throws IOException {
+    return buildFromColumns(columns, dataSchema, aggFunctions, 
PagedPinotOutputStream.HeapPageAllocator.createSmall());
   }
 
   public static ColumnarDataBlock buildFromColumns(List<Object[]> columns, 
DataSchema dataSchema,
       PagedPinotOutputStream.PageAllocator allocator)
       throws IOException {
+    return buildFromColumns(columns, dataSchema, null, allocator);
+  }
+
+  public static ColumnarDataBlock buildFromColumns(List<Object[]> columns, 
DataSchema dataSchema,
+      @Nullable AggregationFunction[] aggFunctions, 
PagedPinotOutputStream.PageAllocator allocator)
+      throws IOException {
     int numRows = columns.isEmpty() ? 0 : columns.get(0).length;
 
     int fixedBytesPerRow = calculateBytesPerRow(dataSchema);
@@ -189,7 +227,13 @@ public class DataBlockBuilder {
       for (int colId = 0; colId < numColumns; colId++) {
         RoaringBitmap nullBitmap = new RoaringBitmap();
         nullBitmaps[colId] = nullBitmap;
-        serializeColumnData(columns, dataSchema, colId, fixedSize, varSize, 
nullBitmap, dictionary);
+        AggregationFunction aggFunction = null;
+        if (aggFunctions != null) {
+          // NOTE: The first (numColumns - numAggFunctions) columns are key 
columns
+          int numAggFunctions = aggFunctions.length;
+          aggFunction = aggFunctions[colId + numAggFunctions - numColumns];
+        }
+        serializeColumnData(columns, dataSchema, colId, fixedSize, varSize, 
nullBitmap, dictionary, aggFunction);
       }
       varBufferBuilder.addPagedOutputStream(varSize);
     }
@@ -200,7 +244,7 @@ public class DataBlockBuilder {
 
   private static void serializeColumnData(List<Object[]> columns, DataSchema 
dataSchema, int colId,
       ByteBuffer fixedSize, PagedPinotOutputStream varSize, RoaringBitmap 
nullBitmap,
-      Object2IntOpenHashMap<String> dictionary)
+      Object2IntOpenHashMap<String> dictionary, @Nullable AggregationFunction 
aggFunction)
       throws IOException {
     ColumnDataType storedType = 
dataSchema.getColumnDataType(colId).getStoredType();
     int numRows = columns.get(colId).length;
@@ -385,24 +429,29 @@ public class DataBlockBuilder {
         }
         break;
       }
-
-      // Special intermediate result for aggregation function
+      // Custom intermediate result for aggregation function
       case OBJECT: {
+        assert aggFunction != null;
         for (int rowId = 0; rowId < numRows; rowId++) {
-          setColumn(fixedSize, varSize, column[rowId]);
+          Object value = column[rowId];
+          if (value == null) {
+            setNull(fixedSize, varSize);
+          } else {
+            setColumn(fixedSize, varSize, 
aggFunction.serializeIntermediateResult(value));
+          }
         }
         break;
       }
       // Null
       case UNKNOWN:
         for (int rowId = 0; rowId < numRows; rowId++) {
-          setColumn(fixedSize, varSize, (Object) null);
+          setNull(fixedSize, varSize);
         }
         break;
 
       default:
-        throw new IllegalStateException("Unsupported stored type: " + 
storedType + " for column: "
-            + dataSchema.getColumnName(colId));
+        throw new IllegalStateException(
+            "Unsupported stored type: " + storedType + " for column: " + 
dataSchema.getColumnName(colId));
     }
   }
 
@@ -444,11 +493,9 @@ public class DataBlockBuilder {
   private static void setNullRowIds(RoaringBitmap[] nullVectors, ByteBuffer 
fixedSize,
       CompoundDataBuffer.Builder varBufferBuilder)
       throws IOException {
-    int varBufSize = Arrays.stream(nullVectors)
-        .mapToInt(bitmap -> bitmap == null ? 0 : 
bitmap.serializedSizeInBytes())
-        .sum();
-    ByteBuffer variableSize = ByteBuffer.allocate(varBufSize)
-        .order(ByteOrder.BIG_ENDIAN);
+    int varBufSize =
+        Arrays.stream(nullVectors).mapToInt(bitmap -> bitmap == null ? 0 : 
bitmap.serializedSizeInBytes()).sum();
+    ByteBuffer variableSize = 
ByteBuffer.allocate(varBufSize).order(ByteOrder.BIG_ENDIAN);
 
     long varWrittenBytes = varBufferBuilder.getWrittenBytes();
     Preconditions.checkArgument(varWrittenBytes < Integer.MAX_VALUE,
@@ -474,8 +521,8 @@ public class DataBlockBuilder {
 
   private static ColumnarDataBlock buildColumnarBlock(int numRows, DataSchema 
dataSchema, String[] dictionary,
       ByteBuffer fixedSize, CompoundDataBuffer.Builder varBufferBuilder) {
-    return new ColumnarDataBlock(numRows, dataSchema, dictionary,
-        PinotByteBuffer.wrap(fixedSize), varBufferBuilder.build());
+    return new ColumnarDataBlock(numRows, dataSchema, dictionary, 
PinotByteBuffer.wrap(fixedSize),
+        varBufferBuilder.build());
   }
 
   private static String[] getReverseDictionary(Object2IntOpenHashMap<String> 
dictionary) {
@@ -511,22 +558,6 @@ public class DataBlockBuilder {
     varSize.write(bytes);
   }
 
-  // TODO: Move ser/de into AggregationFunction interface
-  private static void setColumn(ByteBuffer fixedSize, PagedPinotOutputStream 
varSize, @Nullable Object value)
-      throws IOException {
-    writeVarOffsetInFixed(fixedSize, varSize);
-    if (value == null) {
-      fixedSize.putInt(0);
-      varSize.writeInt(CustomObject.NULL_TYPE_VALUE);
-    } else {
-      int objectTypeValue = 
ObjectSerDeUtils.ObjectType.getObjectType(value).getValue();
-      byte[] bytes = ObjectSerDeUtils.serialize(value, objectTypeValue);
-      fixedSize.putInt(bytes.length);
-      varSize.writeInt(objectTypeValue);
-      varSize.write(bytes);
-    }
-  }
-
   private static void setColumn(ByteBuffer fixedSize, PagedPinotOutputStream 
varSize, int[] values)
       throws IOException {
     writeVarOffsetInFixed(fixedSize, varSize);
@@ -573,4 +604,22 @@ public class DataBlockBuilder {
       varSize.writeInt(dictId);
     }
   }
+
+  private static void setColumn(ByteBuffer fixedSize, PagedPinotOutputStream 
varSize,
+      AggregationFunction.SerializedIntermediateResult value)
+      throws IOException {
+    writeVarOffsetInFixed(fixedSize, varSize);
+    int type = value.getType();
+    byte[] bytes = value.getBytes();
+    fixedSize.putInt(bytes.length);
+    varSize.writeInt(type);
+    varSize.write(bytes);
+  }
+
+  private static void setNull(ByteBuffer fixedSize, PagedPinotOutputStream 
varSize)
+      throws IOException {
+    writeVarOffsetInFixed(fixedSize, varSize);
+    fixedSize.putInt(0);
+    varSize.writeInt(CustomObject.NULL_TYPE_VALUE);
+  }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/BaseDataTableBuilder.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/BaseDataTableBuilder.java
index 1862766af7..a351937d9b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/BaseDataTableBuilder.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/BaseDataTableBuilder.java
@@ -28,7 +28,7 @@ import javax.annotation.Nullable;
 import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.datatable.DataTableUtils;
 import org.apache.pinot.common.utils.DataSchema;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.spi.utils.BigDecimalUtils;
 import org.apache.pinot.spi.utils.MapUtils;
 
@@ -112,23 +112,6 @@ public abstract class BaseDataTableBuilder implements 
DataTableBuilder {
     }
   }
 
-  @Override
-  public void setColumn(int colId, @Nullable Object value)
-      throws IOException {
-    _currentRowDataByteBuffer.position(_columnOffsets[colId]);
-    
_currentRowDataByteBuffer.putInt(_variableSizeDataByteArrayOutputStream.size());
-    if (value == null) {
-      _currentRowDataByteBuffer.putInt(0);
-      _variableSizeDataOutputStream.writeInt(CustomObject.NULL_TYPE_VALUE);
-    } else {
-      int objectTypeValue = 
ObjectSerDeUtils.ObjectType.getObjectType(value).getValue();
-      byte[] bytes = ObjectSerDeUtils.serialize(value, objectTypeValue);
-      _currentRowDataByteBuffer.putInt(bytes.length);
-      _variableSizeDataOutputStream.writeInt(objectTypeValue);
-      _variableSizeDataByteArrayOutputStream.write(bytes);
-    }
-  }
-
   @Override
   public void setColumn(int colId, int[] values)
       throws IOException {
@@ -173,6 +156,27 @@ public abstract class BaseDataTableBuilder implements 
DataTableBuilder {
     }
   }
 
+  @Override
+  public void setColumn(int colId, 
AggregationFunction.SerializedIntermediateResult value)
+      throws IOException {
+    _currentRowDataByteBuffer.position(_columnOffsets[colId]);
+    
_currentRowDataByteBuffer.putInt(_variableSizeDataByteArrayOutputStream.size());
+    int type = value.getType();
+    byte[] bytes = value.getBytes();
+    _currentRowDataByteBuffer.putInt(bytes.length);
+    _variableSizeDataOutputStream.writeInt(type);
+    _variableSizeDataByteArrayOutputStream.write(bytes);
+  }
+
+  @Override
+  public void setNull(int colId)
+      throws IOException {
+    _currentRowDataByteBuffer.position(_columnOffsets[colId]);
+    
_currentRowDataByteBuffer.putInt(_variableSizeDataByteArrayOutputStream.size());
+    _currentRowDataByteBuffer.putInt(0);
+    _variableSizeDataOutputStream.writeInt(CustomObject.NULL_TYPE_VALUE);
+  }
+
   @Override
   public void finishRow()
       throws IOException {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableBuilder.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableBuilder.java
index 2e9d04ca9a..bed93863b5 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableBuilder.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableBuilder.java
@@ -23,6 +23,7 @@ import java.math.BigDecimal;
 import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.datatable.DataTable;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.spi.annotations.InterfaceAudience;
 import org.apache.pinot.spi.annotations.InterfaceStability;
 import org.apache.pinot.spi.utils.ByteArray;
@@ -40,6 +41,8 @@ import org.roaringbitmap.RoaringBitmap;
  * into objects like Integer etc. This will waste cpu resource and increase 
the payload size. We optimize the data
  * format for Pinot use case. We can also support lazy construction of 
objects. In fact we retain the bytes as it is and
  * will be able to look up a field directly within a byte buffer.
+ *
+ * TODO: Consider skipping seeking for the column offsets and directly write 
to the byte buffer
  */
 @InterfaceAudience.Private
 @InterfaceStability.Evolving
@@ -66,10 +69,6 @@ public interface DataTableBuilder {
   void setColumn(int colId, @Nullable Map<String, Object> value)
       throws IOException;
 
-  // TODO: Move ser/de into AggregationFunction interface
-  void setColumn(int colId, @Nullable Object value)
-      throws IOException;
-
   void setColumn(int colId, int[] values)
       throws IOException;
 
@@ -87,6 +86,12 @@ public interface DataTableBuilder {
 
   // TODO: Support MV BYTES
 
+  void setColumn(int colId, AggregationFunction.SerializedIntermediateResult 
value)
+      throws IOException;
+
+  void setNull(int colId)
+      throws IOException;
+
   void finishRow()
       throws IOException;
 
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
index 0fd2e29a25..48570ff1d0 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java
@@ -121,18 +121,34 @@ public class AggregationResultsBlock extends 
BaseResultsBlock {
         nullBitmaps[i] = new RoaringBitmap();
       }
       dataTableBuilder.startRow();
-      for (int i = 0; i < numColumns; i++) {
-        Object result =
-            returnFinalResult ? 
_aggregationFunctions[i].extractFinalResult(_results.get(i)) : _results.get(i);
-        if (result == null) {
-          result = columnDataTypes[i].getNullPlaceholder();
-          nullBitmaps[i].add(0);
-        }
-        if (!returnFinalResult) {
-          setIntermediateResult(dataTableBuilder, columnDataTypes, i, result);
-        } else {
+      if (returnFinalResult) {
+        for (int i = 0; i < numColumns; i++) {
+          Object result = 
_aggregationFunctions[i].extractFinalResult(_results.get(i));
+          if (result == null) {
+            result = columnDataTypes[i].getNullPlaceholder();
+            nullBitmaps[i].add(0);
+          }
+          assert result != null;
           setFinalResult(dataTableBuilder, columnDataTypes, i, result);
         }
+      } else {
+        for (int i = 0; i < numColumns; i++) {
+          Object result = _results.get(i);
+          if (columnDataTypes[i] == ColumnDataType.OBJECT) {
+            if (result == null) {
+              dataTableBuilder.setNull(i);
+            } else {
+              dataTableBuilder.setColumn(i, 
_aggregationFunctions[i].serializeIntermediateResult(result));
+            }
+          } else {
+            if (result == null) {
+              result = columnDataTypes[i].getNullPlaceholder();
+              nullBitmaps[i].add(0);
+            }
+            assert result != null;
+            setIntermediateResult(dataTableBuilder, columnDataTypes, i, 
result);
+          }
+        }
       }
       dataTableBuilder.finishRow();
       for (RoaringBitmap nullBitmap : nullBitmaps) {
@@ -140,14 +156,22 @@ public class AggregationResultsBlock extends 
BaseResultsBlock {
       }
     } else {
       dataTableBuilder.startRow();
-      for (int i = 0; i < numColumns; i++) {
-        Object result = _results.get(i);
-        if (!returnFinalResult) {
-          setIntermediateResult(dataTableBuilder, columnDataTypes, i, result);
-        } else {
-          result = _aggregationFunctions[i].extractFinalResult(result);
+      if (returnFinalResult) {
+        for (int i = 0; i < numColumns; i++) {
+          Object result = 
_aggregationFunctions[i].extractFinalResult(_results.get(i));
+          assert result != null;
           setFinalResult(dataTableBuilder, columnDataTypes, i, result);
         }
+      } else {
+        for (int i = 0; i < numColumns; i++) {
+          Object result = _results.get(i);
+          assert result != null;
+          if (columnDataTypes[i] == ColumnDataType.OBJECT) {
+            dataTableBuilder.setColumn(i, 
_aggregationFunctions[i].serializeIntermediateResult(result));
+          } else {
+            setIntermediateResult(dataTableBuilder, columnDataTypes, i, 
result);
+          }
+        }
       }
       dataTableBuilder.finishRow();
     }
@@ -155,8 +179,7 @@ public class AggregationResultsBlock extends 
BaseResultsBlock {
   }
 
   private void setIntermediateResult(DataTableBuilder dataTableBuilder, 
ColumnDataType[] columnDataTypes, int index,
-      Object result)
-      throws IOException {
+      Object result) {
     ColumnDataType columnDataType = columnDataTypes[index];
     switch (columnDataType) {
       case INT:
@@ -168,9 +191,6 @@ public class AggregationResultsBlock extends 
BaseResultsBlock {
       case DOUBLE:
         dataTableBuilder.setColumn(index, (double) result);
         break;
-      case OBJECT:
-        dataTableBuilder.setColumn(index, result);
-        break;
       default:
         throw new IllegalStateException("Illegal column data type in 
intermediate result: " + columnDataType);
     }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java
index b1bf65b7eb..21eeea73ca 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java
@@ -33,6 +33,7 @@ import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.datatable.DataTable.MetadataKey;
+import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.ArrayListUtils;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -41,6 +42,7 @@ import 
org.apache.pinot.core.common.datatable.DataTableBuilderFactory;
 import org.apache.pinot.core.data.table.IntermediateRecord;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.data.table.Table;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import 
org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.spi.trace.Tracing;
@@ -51,6 +53,7 @@ import org.roaringbitmap.RoaringBitmap;
 /**
  * Results block for group-by queries.
  */
+@SuppressWarnings({"rawtypes", "unchecked"})
 public class GroupByResultsBlock extends BaseResultsBlock {
   private final DataSchema _dataSchema;
   private final AggregationGroupByResult _aggregationGroupByResult;
@@ -191,6 +194,10 @@ public class GroupByResultsBlock extends BaseResultsBlock {
     }
     ColumnDataType[] storedColumnDataTypes = 
_dataSchema.getStoredColumnDataTypes();
     int numColumns = _dataSchema.size();
+    AggregationFunction[] aggregationFunctions = 
_queryContext.getAggregationFunctions();
+    List<ExpressionContext> groupByExpressions = 
_queryContext.getGroupByExpressions();
+    assert aggregationFunctions != null && groupByExpressions != null;
+    int numKeyColumns = groupByExpressions.size();
     Iterator<Record> iterator = _table.iterator();
     int numRowsAdded = 0;
     if (_queryContext.isNullHandlingEnabled()) {
@@ -205,13 +212,22 @@ public class GroupByResultsBlock extends BaseResultsBlock 
{
         
Tracing.ThreadAccountantOps.sampleAndCheckInterruptionPeriodically(numRowsAdded);
         dataTableBuilder.startRow();
         Object[] values = iterator.next().getValues();
-        for (int colId = 0; colId < numColumns; colId++) {
-          Object value = values[colId];
-          if (value == null && storedColumnDataTypes[colId] != 
ColumnDataType.OBJECT) {
-            value = nullPlaceholders[colId];
-            nullBitmaps[colId].add(rowId);
+        for (int i = 0; i < numColumns; i++) {
+          Object value = values[i];
+          if (storedColumnDataTypes[i] == ColumnDataType.OBJECT) {
+            if (value == null) {
+              dataTableBuilder.setNull(i);
+            } else {
+              dataTableBuilder.setColumn(i, aggregationFunctions[i - 
numKeyColumns].serializeIntermediateResult(value));
+            }
+          } else {
+            if (value == null) {
+              value = nullPlaceholders[i];
+              nullBitmaps[i].add(rowId);
+            }
+            assert value != null;
+            setDataTableColumn(storedColumnDataTypes[i], dataTableBuilder, i, 
value);
           }
-          setDataTableColumn(storedColumnDataTypes[colId], dataTableBuilder, 
colId, value);
         }
         dataTableBuilder.finishRow();
         numRowsAdded++;
@@ -225,8 +241,14 @@ public class GroupByResultsBlock extends BaseResultsBlock {
         
Tracing.ThreadAccountantOps.sampleAndCheckInterruptionPeriodically(numRowsAdded);
         dataTableBuilder.startRow();
         Object[] values = iterator.next().getValues();
-        for (int colId = 0; colId < numColumns; colId++) {
-          setDataTableColumn(storedColumnDataTypes[colId], dataTableBuilder, 
colId, values[colId]);
+        for (int i = 0; i < numColumns; i++) {
+          Object value = values[i];
+          assert value != null;
+          if (storedColumnDataTypes[i] == ColumnDataType.OBJECT) {
+            dataTableBuilder.setColumn(i, aggregationFunctions[i - 
numKeyColumns].serializeIntermediateResult(value));
+          } else {
+            setDataTableColumn(storedColumnDataTypes[i], dataTableBuilder, i, 
value);
+          }
         }
         dataTableBuilder.finishRow();
         numRowsAdded++;
@@ -296,9 +318,6 @@ public class GroupByResultsBlock extends BaseResultsBlock {
           dataTableBuilder.setColumn(columnIndex, (String[]) value);
         }
         break;
-      case OBJECT:
-        dataTableBuilder.setColumn(columnIndex, value);
-        break;
       default:
         throw new IllegalStateException("Unsupported stored type: " + 
storedColumnDataType);
     }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
index 4bac389197..acc0686cf1 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
@@ -20,10 +20,13 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.List;
 import java.util.Map;
+import javax.annotation.Nullable;
 import javax.annotation.concurrent.ThreadSafe;
+import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.ObjectSerDeUtils;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
@@ -90,21 +93,22 @@ public interface AggregationFunction<IntermediateResult, 
FinalResult extends Com
 
   /**
    * Extracts the intermediate result from the aggregation result holder 
(aggregation only).
-   * TODO: Support serializing/deserializing null values in DataTable and use 
null as the empty intermediate result
    */
+  @Nullable
   IntermediateResult extractAggregationResult(AggregationResultHolder 
aggregationResultHolder);
 
   /**
    * Extracts the intermediate result from the group-by result holder for the 
given group key (aggregation group-by).
-   * TODO: Support serializing/deserializing null values in DataTable and use 
null as the empty intermediate result
    */
+  @Nullable
   IntermediateResult extractGroupByResult(GroupByResultHolder 
groupByResultHolder, int groupKey);
 
   /**
    * Merges two intermediate results.
-   * TODO: Support serializing/deserializing null values in DataTable and use 
null as the empty intermediate result
    */
-  IntermediateResult merge(IntermediateResult intermediateResult1, 
IntermediateResult intermediateResult2);
+  @Nullable
+  IntermediateResult merge(@Nullable IntermediateResult intermediateResult1,
+      @Nullable IntermediateResult intermediateResult2);
 
   /**
    * Returns the {@link ColumnDataType} of the intermediate result.
@@ -112,6 +116,53 @@ public interface AggregationFunction<IntermediateResult, 
FinalResult extends Com
    */
   ColumnDataType getIntermediateResultColumnType();
 
+  /**
+   * Serializes the intermediate result into a custom object. This method should be implemented if the intermediate
+   * result type is OBJECT.
+   * <p>The default implementation asserts that {@link #getIntermediateResultColumnType()} is OBJECT, resolves the
+   * {@code ObjectSerDeUtils.ObjectType} of the given value, serializes the value with
+   * {@link ObjectSerDeUtils#serialize} using that type, and returns the type id together with the serialized bytes.
+   * <p>NOTE(review): handling of a {@code null} argument is not defined here; callers appear to pass only non-null
+   * intermediate results — confirm before relying on it.
+   *
+   * TODO: Override this method in the aggregation functions that return OBJECT type intermediate results to reduce the
+   *       overhead of instanceof checks in the default implementation.
+   */
+  default SerializedIntermediateResult 
serializeIntermediateResult(IntermediateResult intermediateResult) {
+    assert getIntermediateResultColumnType() == ColumnDataType.OBJECT;
+    // The same type id selects the serializer here and is returned alongside the bytes so it can be used to identify
+    // the intermediate result type when deserializing.
+    int type = 
ObjectSerDeUtils.ObjectType.getObjectType(intermediateResult).getValue();
+    byte[] bytes = ObjectSerDeUtils.serialize(intermediateResult, type);
+    return new SerializedIntermediateResult(type, bytes);
+  }
+
+  /**
+   * Serialized intermediate result. Type can be used to identify the intermediate result type when deserializing it.
+   * <p>This is a plain holder: the byte array is stored and returned by reference (never copied), so callers should
+   * not mutate it after handing it over.
+   */
+  class SerializedIntermediateResult {
+    private final int _type;
+    private final byte[] _bytes;
+
+    /**
+     * @param type type id of the serialized value (the default {@link AggregationFunction#serializeIntermediateResult}
+     *             uses the {@code ObjectSerDeUtils.ObjectType} value)
+     * @param buffer serialized bytes; stored by reference, not copied
+     */
+    public SerializedIntermediateResult(int type, byte[] buffer) {
+      _type = type;
+      _bytes = buffer;
+    }
+
+    /** Returns the type id supplied at construction. */
+    public int getType() {
+      return _type;
+    }
+
+    /** Returns the serialized bytes supplied at construction (the same array instance, not a copy). */
+    public byte[] getBytes() {
+      return _bytes;
+    }
+  }
+
+  /**
+   * Deserializes the intermediate result from the custom object. This method should be implemented if the intermediate
+   * result type is OBJECT.
+   * <p>The default implementation asserts that {@link #getIntermediateResultColumnType()} is OBJECT and delegates to
+   * {@link ObjectSerDeUtils#deserialize(CustomObject)}, which relies on the type stored in the custom object.
+   * <p>NOTE(review): null handling is left to callers; the readers that call this check the custom object for
+   * {@code null} first and map it to a {@code null} intermediate result — confirm for any new caller.
+   *
+   * TODO: Override this method in the aggregation functions that return OBJECT type intermediate results to not rely
+   *       on the type to decouple this from ObjectSerDeUtils.
+   */
+  default IntermediateResult deserializeIntermediateResult(CustomObject 
customObject) {
+    assert getIntermediateResultColumnType() == ColumnDataType.OBJECT;
+    return ObjectSerDeUtils.deserialize(customObject);
+  }
+
   /**
    * Returns the {@link ColumnDataType} of the final result.
    * <p>This column data type is used for constructing the result table</p>
@@ -120,15 +171,16 @@ public interface AggregationFunction<IntermediateResult, 
FinalResult extends Com
 
   /**
    * Extracts the final result used in the broker response from the given 
intermediate result.
-   * TODO: Support serializing/deserializing null values in DataTable and use 
null as the empty intermediate result
    */
-  FinalResult extractFinalResult(IntermediateResult intermediateResult);
+  @Nullable
+  FinalResult extractFinalResult(@Nullable IntermediateResult 
intermediateResult);
 
   /**
    * Merges two final results. This can be used to optimized certain functions 
(e.g. DISTINCT_COUNT) when data is
    * partitioned on each server, where we may directly request servers to 
return final result and merge them on broker.
    */
-  default FinalResult mergeFinalResult(FinalResult finalResult1, FinalResult 
finalResult2) {
+  @Nullable
+  default FinalResult mergeFinalResult(@Nullable FinalResult finalResult1, 
@Nullable FinalResult finalResult2) {
     throw new UnsupportedOperationException("Cannot merge final results for 
function: " + getType());
   }
 
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index a88cd7623b..5a74894116 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -41,7 +41,6 @@ import 
org.apache.pinot.common.request.context.predicate.Predicate;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.config.QueryOptionsUtils;
 import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
 import org.apache.pinot.core.operator.BaseProjectOperator;
 import org.apache.pinot.core.operator.blocks.ValueBlock;
 import org.apache.pinot.core.operator.filter.BaseFilterOperator;
@@ -142,10 +141,10 @@ public class AggregationFunctionUtils {
 
   /**
    * Reads the intermediate result from the {@link DataTable}.
-   *
-   * TODO: Move ser/de into AggregationFunction interface
    */
-  public static Object getIntermediateResult(DataTable dataTable, 
ColumnDataType columnDataType, int rowId, int colId) {
+  @Nullable
+  public static Object getIntermediateResult(AggregationFunction 
aggregationFunction, DataTable dataTable,
+      ColumnDataType columnDataType, int rowId, int colId) {
     switch (columnDataType.getStoredType()) {
       case INT:
         return dataTable.getInt(rowId, colId);
@@ -155,7 +154,7 @@ public class AggregationFunctionUtils {
         return dataTable.getDouble(rowId, colId);
       case OBJECT:
         CustomObject customObject = dataTable.getCustomObject(rowId, colId);
-        return customObject != null ? 
ObjectSerDeUtils.deserialize(customObject) : null;
+        return customObject != null ? 
aggregationFunction.deserializeIntermediateResult(customObject) : null;
       default:
         throw new IllegalStateException("Illegal column data type in 
intermediate result: " + columnDataType);
     }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
index b6fae6d340..9988b1926d 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.List;
 import java.util.Map;
+import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.common.BlockValSet;
@@ -200,6 +201,17 @@ public class AvgAggregationFunction extends 
NullableSingleInputAggregationFuncti
     return ColumnDataType.OBJECT;
   }
 
+  @Override
+  public SerializedIntermediateResult serializeIntermediateResult(AvgPair 
avgPair) {
+    // Serializes the AvgPair directly under a hard-coded type id, skipping the per-call type lookup done by the
+    // default implementation.
+    // ObjectSerDeUtils.ObjectType.AvgPair.getValue() == 4
+    // NOTE: keep this id in sync with ObjectSerDeUtils.ObjectType.AvgPair.
+    return new SerializedIntermediateResult(4, avgPair.toBytes());
+  }
+
+  @Override
+  public AvgPair deserializeIntermediateResult(CustomObject customObject) {
+    // Decodes the buffer directly as an AvgPair; the type id stored in the custom object is not consulted.
+    return AvgPair.fromByteBuffer(customObject.getBuffer());
+  }
+
   @Override
   public ColumnDataType getFinalResultColumnType() {
     return ColumnDataType.DOUBLE;
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index 1c39b6971b..a690203d4e 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -87,6 +87,7 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
     for (DataTable dataTable : dataTables) {
       Tracing.ThreadAccountantOps.sampleAndCheckInterruption();
       for (int i = 0; i < numAggregationFunctions; i++) {
+        AggregationFunction aggregationFunction = _aggregationFunctions[i];
         Object intermediateResultToMerge;
         ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
         if (_queryContext.isNullHandlingEnabled()) {
@@ -94,16 +95,18 @@ public class AggregationDataTableReducer implements 
DataTableReducer {
           if (nullBitmap != null && nullBitmap.contains(0)) {
             intermediateResultToMerge = null;
           } else {
-            intermediateResultToMerge = 
AggregationFunctionUtils.getIntermediateResult(dataTable, columnDataType, 0, i);
+            intermediateResultToMerge =
+                
AggregationFunctionUtils.getIntermediateResult(aggregationFunction, dataTable, 
columnDataType, 0, i);
           }
         } else {
-          intermediateResultToMerge = 
AggregationFunctionUtils.getIntermediateResult(dataTable, columnDataType, 0, i);
+          intermediateResultToMerge =
+              
AggregationFunctionUtils.getIntermediateResult(aggregationFunction, dataTable, 
columnDataType, 0, i);
         }
         Object mergedIntermediateResult = intermediateResults[i];
         if (mergedIntermediateResult == null) {
           intermediateResults[i] = intermediateResultToMerge;
         } else {
-          intermediateResults[i] = 
_aggregationFunctions[i].merge(mergedIntermediateResult, 
intermediateResultToMerge);
+          intermediateResults[i] = 
aggregationFunction.merge(mergedIntermediateResult, intermediateResultToMerge);
         }
       }
     }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index e1db966f1b..3023d9b020 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -50,7 +50,6 @@ import 
org.apache.pinot.common.response.broker.QueryProcessingException;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
 import org.apache.pinot.core.data.table.IndexedTable;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
@@ -322,10 +321,12 @@ public class GroupByDataTableReducer implements 
DataTableReducer {
                       values[colId] = 
ObjectArrayList.wrap(dataTable.getStringArray(rowId, colId));
                       break;
                     case OBJECT:
-                      // TODO: Move ser/de into AggregationFunction interface
                       CustomObject customObject = 
dataTable.getCustomObject(rowId, colId);
                       if (customObject != null) {
-                        values[colId] = 
ObjectSerDeUtils.deserialize(customObject);
+                        assert _aggregationFunctions != null;
+                        values[colId] =
+                            _aggregationFunctions[colId - 
_numGroupByExpressions].deserializeIntermediateResult(
+                                customObject);
                       }
                       break;
                     // Add other aggregation intermediate result / group-by 
column type supports here
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/selection/SelectionOperatorUtils.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/query/selection/SelectionOperatorUtils.java
index 20b52876ea..f6b9ec8f4a 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/selection/SelectionOperatorUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/selection/SelectionOperatorUtils.java
@@ -419,7 +419,7 @@ public class SelectionOperatorUtils {
             dataTableBuilder.setColumn(i, (Map) columnValue);
             break;
           case UNKNOWN:
-            dataTableBuilder.setColumn(i, (Object) null);
+            dataTableBuilder.setNull(i);
             break;
 
           // Multi-value column
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java
index 76b0a62322..e8eed72c96 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java
@@ -24,10 +24,11 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import javax.annotation.Nullable;
+import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
 import org.apache.pinot.spi.utils.MapUtils;
@@ -97,9 +98,9 @@ public final class DataBlockExtractUtils {
       case STRING_ARRAY:
         return dataBlock.getStringArray(rowId, colId);
 
-      // Special intermediate result for aggregation function
-      case OBJECT:
-        return ObjectSerDeUtils.deserialize(dataBlock.getCustomObject(rowId, 
colId));
+      // Null
+      case UNKNOWN:
+        return null;
 
       default:
         throw new IllegalStateException("Unsupported stored type: " + 
storedType + " for column: "
@@ -107,6 +108,36 @@ public final class DataBlockExtractUtils {
     }
   }
 
+  /**
+   * Extracts every row of an aggregation-result column from the given data block.
+   * <p>If the column's stored type is OBJECT, each non-null custom object is deserialized with
+   * {@code aggFunction.deserializeIntermediateResult}, and a null custom object yields a {@code null} entry. For any
+   * other stored type, values are read with {@code extractValue}; rows set in the column's null bitmap are left as
+   * {@code null}.
+   *
+   * @param dataBlock data block to read from
+   * @param colId index of the column to extract
+   * @param aggFunction aggregation function used to deserialize OBJECT intermediate results
+   * @return array with one entry per row of the data block ({@code null} for null rows)
+   */
+  public static Object[] extractAggResult(DataBlock dataBlock, int colId, 
AggregationFunction aggFunction) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    ColumnDataType storedType = 
dataSchema.getColumnDataType(colId).getStoredType();
+    int numRows = dataBlock.getNumberOfRows();
+    Object[] values = new Object[numRows];
+    if (storedType == ColumnDataType.OBJECT) {
+      // Ignore null bitmap for custom object because null is supported in custom object
+      for (int rowId = 0; rowId < numRows; rowId++) {
+        CustomObject customObject = dataBlock.getCustomObject(rowId, colId);
+        if (customObject != null) {
+          values[rowId] = 
aggFunction.deserializeIntermediateResult(customObject);
+        }
+      }
+    } else {
+      RoaringBitmap nullBitmap = dataBlock.getNullRowIds(colId);
+      if (nullBitmap == null) {
+        for (int rowId = 0; rowId < numRows; rowId++) {
+          values[rowId] = extractValue(dataBlock, storedType, rowId, colId);
+        }
+      } else {
+        // Null rows are skipped and keep the array's default null value.
+        for (int rowId = 0; rowId < numRows; rowId++) {
+          if (!nullBitmap.contains(rowId)) {
+            values[rowId] = extractValue(dataBlock, storedType, rowId, colId);
+          }
+        }
+      }
+    }
+    return values;
+  }
+
   public static Object[][] extractKeys(DataBlock dataBlock, int[] keyIds) {
     DataSchema dataSchema = dataBlock.getDataSchema();
     int numKeys = keyIds.length;
@@ -157,7 +188,7 @@ public final class DataBlockExtractUtils {
     return keys;
   }
 
-  public static Object[] extractColumn(DataBlock dataBlock, int colId) {
+  public static Object[] extractKey(DataBlock dataBlock, int colId) {
     DataSchema dataSchema = dataBlock.getDataSchema();
     ColumnDataType storedType = 
dataSchema.getColumnDataType(colId).getStoredType();
     RoaringBitmap nullBitmap = dataBlock.getNullRowIds(colId);
@@ -177,7 +208,7 @@ public final class DataBlockExtractUtils {
     return values;
   }
 
-  public static Object[] extractColumn(DataBlock dataBlock, int colId, int 
numMatchedRows,
+  public static Object[] extractKey(DataBlock dataBlock, int colId, int 
numMatchedRows,
       RoaringBitmap matchedBitmap) {
     DataSchema dataSchema = dataBlock.getDataSchema();
     ColumnDataType storedType = 
dataSchema.getColumnDataType(colId).getStoredType();
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockBuilderTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockBuilderTest.java
index 8306d0f8b6..0a0326bdcc 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockBuilderTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockBuilderTest.java
@@ -21,56 +21,56 @@ package org.apache.pinot.core.common.datablock;
 import java.io.IOException;
 import java.math.BigDecimal;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.BitSet;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.function.IntFunction;
-import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.utils.DataSchema;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.roaringbitmap.RoaringBitmap;
 import org.testng.Assert;
-import org.testng.SkipException;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
-import static org.testng.Assert.*;
+import static org.mockito.Mockito.mock;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNull;
 
 
+@SuppressWarnings("rawtypes")
 public class DataBlockBuilderTest {
 
   @DataProvider(name = "columnDataTypes")
-  DataSchema.ColumnDataType[] columnDataTypes() {
-    return Arrays.stream(DataSchema.ColumnDataType.values())
-        .map(DataSchema.ColumnDataType::getStoredType)
-        .distinct()
-        .toArray(DataSchema.ColumnDataType[]::new);
+  ColumnDataType[] columnDataTypes() {
+    return 
EnumSet.complementOf(EnumSet.of(ColumnDataType.BYTES_ARRAY)).toArray(new 
ColumnDataType[0]);
   }
 
   @Test(dataProvider = "columnDataTypes")
-  void testRowBlock(DataSchema.ColumnDataType type)
+  void testRowBlock(ColumnDataType type)
       throws IOException {
     int numRows = 100;
     List<Object[]> rows = generateRows(type, numRows);
-
-    DataSchema dataSchema = new DataSchema(new String[]{"column"}, new 
DataSchema.ColumnDataType[]{type});
-
-    DataBlock rowDataBlock = DataBlockBuilder.buildFromRows(rows, dataSchema);
-
+    DataSchema dataSchema = new DataSchema(new String[]{"column"}, new 
ColumnDataType[]{type});
+    AggregationFunction[] aggFunctions = null;
+    if (type == ColumnDataType.OBJECT) {
+      aggFunctions = new 
AggregationFunction[]{mock(AggregationFunction.class)};
+    }
+    DataBlock rowDataBlock = DataBlockBuilder.buildFromRows(rows, dataSchema, 
aggFunctions);
     assertEquals(rowDataBlock.getNumberOfRows(), numRows);
     checkEquals(type, rowDataBlock, i -> rows.get(i)[0]);
   }
 
-  private List<Object[]> generateRows(DataSchema.ColumnDataType type, int 
numRows) {
+  private List<Object[]> generateRows(ColumnDataType type, int numRows) {
     List<Object[]> result = new ArrayList<>();
     Random r = new Random(42);
-    switch (type) {
+    switch (type.getStoredType()) {
       case INT:
         for (int i = 0; i < numRows; i++) {
           result.add(new Object[]{r.nextInt()});
@@ -106,11 +106,6 @@ public class DataBlockBuilderTest {
           result.add(new Object[]{BigDecimal.valueOf(r.nextInt())});
         }
         break;
-      case OBJECT:
-        for (int i = 0; i < numRows; i++) {
-          result.add(new Object[]{r.nextLong()}); // longs are valid object 
types
-        }
-        break;
       case MAP:
         for (int i = 0; i < numRows; i++) {
           Map<String, String> map = new HashMap<>();
@@ -145,9 +140,12 @@ public class DataBlockBuilderTest {
           result.add(new Object[]{new String[]{String.valueOf(r.nextInt()), 
String.valueOf(r.nextInt())}});
         }
         break;
-      case BYTES_ARRAY:
+      case OBJECT:
       case UNKNOWN:
-        throw new SkipException(type + " not supported yet");
+        for (int i = 0; i < numRows; i++) {
+          result.add(new Object[1]);
+        }
+        break;
       default:
         throw new IllegalStateException("Unsupported data type: " + type);
     }
@@ -158,23 +156,25 @@ public class DataBlockBuilderTest {
   }
 
   @Test(dataProvider = "columnDataTypes")
-  void testColumnBlock(DataSchema.ColumnDataType type)
+  void testColumnBlock(ColumnDataType type)
       throws IOException {
     int numRows = 100;
     Object[] column = generateColumns(type, numRows);
-
-    DataSchema dataSchema = new DataSchema(new String[]{"column"}, new 
DataSchema.ColumnDataType[]{type});
-
-    DataBlock rowDataBlock = 
DataBlockBuilder.buildFromColumns(Collections.singletonList(column), 
dataSchema);
-
+    DataSchema dataSchema = new DataSchema(new String[]{"column"}, new 
ColumnDataType[]{type});
+    AggregationFunction[] aggFunctions = null;
+    if (type == ColumnDataType.OBJECT) {
+      aggFunctions = new 
AggregationFunction[]{mock(AggregationFunction.class)};
+    }
+    DataBlock rowDataBlock =
+        DataBlockBuilder.buildFromColumns(Collections.singletonList(column), 
dataSchema, aggFunctions);
     assertEquals(rowDataBlock.getNumberOfRows(), numRows);
     checkEquals(type, rowDataBlock, i -> column[i]);
   }
 
-  Object[] generateColumns(DataSchema.ColumnDataType type, int numRows) {
+  Object[] generateColumns(ColumnDataType type, int numRows) {
     Object[] result = new Object[numRows];
     Random r = new Random(42);
-    switch (type) {
+    switch (type.getStoredType()) {
       case INT:
         for (int i = 0; i < numRows; i++) {
           result[i] = r.nextInt();
@@ -218,11 +218,6 @@ public class DataBlockBuilderTest {
           result[i] = BigDecimal.valueOf(r.nextInt());
         }
         break;
-      case OBJECT:
-        for (int i = 0; i < numRows; i++) {
-          result[i] = r.nextLong(); // longs are valid object types
-        }
-        break;
       case INT_ARRAY:
         for (int i = 0; i < numRows; i++) {
           result[i] = new int[]{r.nextInt(), r.nextInt()};
@@ -248,9 +243,9 @@ public class DataBlockBuilderTest {
           result[i] = new String[]{String.valueOf(r.nextInt()), 
String.valueOf(r.nextInt())};
         }
         break;
-      case BYTES_ARRAY:
+      case OBJECT:
       case UNKNOWN:
-        throw new SkipException(type + " not supported yet");
+        break;
       default:
         throw new IllegalStateException("Unsupported data type: " + type);
     }
@@ -260,9 +255,9 @@ public class DataBlockBuilderTest {
     return result;
   }
 
-  private void checkEquals(DataSchema.ColumnDataType type, DataBlock block, 
IntFunction<Object> rowToData) {
+  private void checkEquals(ColumnDataType type, DataBlock block, 
IntFunction<Object> rowToData) {
     int numRows = block.getNumberOfRows();
-    switch (type) {
+    switch (type.getStoredType()) {
       case INT:
         for (int i = 0; i < numRows; i++) {
           Object expected = rowToData.apply(i);
@@ -319,16 +314,6 @@ public class DataBlockBuilderTest {
           }
         }
         break;
-      case OBJECT:
-        for (int i = 0; i < numRows; i++) {
-          Object expected = rowToData.apply(i);
-          if (expected != null) {
-            CustomObject customObject = block.getCustomObject(i, 0);
-            Long l = ObjectSerDeUtils.deserialize(customObject);
-            assertEquals(l, expected, "Failure on row " + i);
-          }
-        }
-        break;
       case MAP:
         for (int i = 0; i < numRows; i++) {
           Object expected = rowToData.apply(i);
@@ -377,13 +362,16 @@ public class DataBlockBuilderTest {
           }
         }
         break;
-      case BYTES_ARRAY:
+      case OBJECT:
       case UNKNOWN:
-        throw new SkipException(type + " not supported yet");
+        for (int i = 0; i < numRows; i++) {
+          assertNull(block.getCustomObject(i, 0));
+        }
+        break;
       default:
         throw new IllegalStateException("Unsupported data type: " + type);
     }
-    if (type != DataSchema.ColumnDataType.OBJECT) {
+    if (type != ColumnDataType.OBJECT && type != ColumnDataType.UNKNOWN) {
       RoaringBitmap nullRowIds = block.getNullRowIds(0);
 
       BitSet actualBitSet = new BitSet(numRows);
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/common/datatable/DataTableSerDeTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/common/datatable/DataTableSerDeTest.java
index b9cf5c5cb4..af3fa8f48b 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/common/datatable/DataTableSerDeTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/common/datatable/DataTableSerDeTest.java
@@ -26,14 +26,12 @@ import java.util.Map;
 import java.util.Random;
 import org.apache.commons.lang3.RandomStringUtils;
 import org.apache.commons.lang3.StringUtils;
-import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.datatable.DataTable.MetadataKey;
 import org.apache.pinot.common.datatable.DataTableFactory;
 import org.apache.pinot.common.exception.QueryException;
 import org.apache.pinot.common.response.ProcessingException;
 import org.apache.pinot.common.utils.DataSchema;
-import org.apache.pinot.core.common.ObjectSerDeUtils;
 import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.roaringbitmap.RoaringBitmap;
@@ -157,7 +155,6 @@ public class DataTableSerDeTest {
 
   private void testEmptyValues(DataSchema dataSchema, int numRows, Object[] 
emptyValues)
       throws IOException {
-
     DataTableBuilder dataTableBuilder = 
DataTableBuilderFactory.getDataTableBuilder(dataSchema);
     for (int rowId = 0; rowId < numRows; rowId++) {
       dataTableBuilder.startRow();
@@ -178,7 +175,7 @@ public class DataTableSerDeTest {
         } else if (emptyValue instanceof ByteArray) {
           dataTableBuilder.setColumn(columnId, (ByteArray) emptyValue);
         } else {
-          dataTableBuilder.setColumn(columnId, emptyValue);
+          Assert.fail();
         }
       }
       dataTableBuilder.finishRow();
@@ -366,11 +363,6 @@ public class DataTableSerDeTest {
             BYTES[rowId] = isNull ? new byte[0] : 
RandomStringUtils.random(RANDOM.nextInt(20)).getBytes();
             dataTableBuilder.setColumn(colId, new ByteArray(BYTES[rowId]));
             break;
-          // Just test Double here, all object types will be covered in 
ObjectCustomSerDeTest.
-          case OBJECT:
-            OBJECTS[rowId] = isNull ? null : RANDOM.nextDouble();
-            dataTableBuilder.setColumn(colId, OBJECTS[rowId]);
-            break;
           case INT_ARRAY:
             int length = RANDOM.nextInt(20);
             int[] intArray = new int[length];
@@ -445,8 +437,9 @@ public class DataTableSerDeTest {
             MAPS[rowId] = map;
             dataTableBuilder.setColumn(colId, map);
             break;
+          case OBJECT:
           case UNKNOWN:
-            dataTableBuilder.setColumn(colId, (Object) null);
+            dataTableBuilder.setNull(colId);
             break;
           default:
             throw new UnsupportedOperationException("Unable to generate random 
data for: " + columnDataTypes[colId]);
@@ -507,15 +500,6 @@ public class DataTableSerDeTest {
             Assert.assertEquals(newDataTable.getBytes(rowId, 
colId).getBytes(), isNull ? new byte[0] : BYTES[rowId],
                 ERROR_MESSAGE);
             break;
-          case OBJECT:
-            CustomObject customObject = newDataTable.getCustomObject(rowId, 
colId);
-            if (isNull) {
-              Assert.assertNull(customObject, ERROR_MESSAGE);
-            } else {
-              Assert.assertNotNull(customObject);
-              Assert.assertEquals(ObjectSerDeUtils.deserialize(customObject), 
OBJECTS[rowId], ERROR_MESSAGE);
-            }
-            break;
           case INT_ARRAY:
             Assert.assertTrue(Arrays.equals(newDataTable.getIntArray(rowId, 
colId), INT_ARRAYS[rowId]), ERROR_MESSAGE);
             break;
@@ -549,6 +533,7 @@ public class DataTableSerDeTest {
           case MAP:
             Assert.assertEquals(newDataTable.getMap(rowId, colId), 
MAPS[rowId], ERROR_MESSAGE);
             break;
+          case OBJECT:
           case UNKNOWN:
             Object nulValue = newDataTable.getCustomObject(rowId, colId);
             Assert.assertNull(nulValue, ERROR_MESSAGE);
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
index 9cf70dc72c..7a60b5e370 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlock.java
@@ -34,16 +34,17 @@ import org.apache.pinot.common.datablock.RowDataBlock;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.Block;
 import org.apache.pinot.core.common.datablock.DataBlockBuilder;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.util.DataBlockExtractUtils;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.segment.spi.memory.DataBuffer;
 
 
-
 /**
  * A {@code TransferableBlock} is a wrapper around {@link DataBlock} for 
transferring data using
  * {@link org.apache.pinot.common.proto.Mailbox}.
  */
+@SuppressWarnings("rawtypes")
 public class TransferableBlock implements Block {
   private final DataBlock.Type _type;
   @Nullable
@@ -56,12 +57,21 @@ public class TransferableBlock implements Block {
   @Nullable
   private final MultiStageQueryStats _queryStats;
 
+  @Nullable
+  private AggregationFunction[] _aggFunctions;
+
   public TransferableBlock(List<Object[]> container, DataSchema dataSchema, 
DataBlock.Type type) {
+    this(container, dataSchema, type, null);
+  }
+
+  public TransferableBlock(List<Object[]> container, DataSchema dataSchema, 
DataBlock.Type type,
+      @Nullable AggregationFunction[] aggFunctions) {
     _container = container;
     _dataSchema = dataSchema;
     Preconditions.checkArgument(type == DataBlock.Type.ROW || type == 
DataBlock.Type.COLUMNAR,
         "Container cannot be used to construct block of type: %s", type);
     _type = type;
+    _aggFunctions = aggFunctions;
     _numRows = _container.size();
     // NOTE: Use assert to avoid breaking production code.
     assert _numRows > 0 : "Container should not be empty";
@@ -91,7 +101,7 @@ public class TransferableBlock implements Block {
     if (isSuccessfulEndOfStreamBlock()) {
       List<DataBuffer> statsByStage;
       if (_dataBlock instanceof MetadataBlock) {
-        statsByStage = ((MetadataBlock) _dataBlock).getStatsByStage();
+        statsByStage = _dataBlock.getStatsByStage();
         if (statsByStage == null) {
           return new ArrayList<>();
         }
@@ -172,10 +182,10 @@ public class TransferableBlock implements Block {
       try {
         switch (_type) {
           case ROW:
-            _dataBlock = DataBlockBuilder.buildFromRows(_container, 
_dataSchema);
+            _dataBlock = DataBlockBuilder.buildFromRows(_container, 
_dataSchema, _aggFunctions);
             break;
           case COLUMNAR:
-            _dataBlock = DataBlockBuilder.buildFromColumns(_container, 
_dataSchema);
+            _dataBlock = DataBlockBuilder.buildFromColumns(_container, 
_dataSchema, _aggFunctions);
             break;
           case METADATA:
             _dataBlock = new MetadataBlock(getSerializedStatsByStage());
@@ -207,6 +217,11 @@ public class TransferableBlock implements Block {
     return _type;
   }
 
+  @Nullable
+  public AggregationFunction[] getAggFunctions() {
+    return _aggFunctions;
+  }
+
   /**
    * Return whether a transferable block is at the end of a stream.
    *
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlockUtils.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlockUtils.java
index 04b0a3ded4..c90d61b2e1 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlockUtils.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/blocks/TransferableBlockUtils.java
@@ -94,7 +94,8 @@ public final class TransferableBlockUtils {
       List<TransferableBlock> blockChunks = new ArrayList<>(numChunks);
       for (int fromIndex = 0; fromIndex < numRows; fromIndex += 
numRowsPerChunk) {
         int toIndex = Math.min(fromIndex + numRowsPerChunk, numRows);
-        blockChunks.add(new TransferableBlock(rows.subList(fromIndex, 
toIndex), dataSchema, DataBlock.Type.ROW));
+        blockChunks.add(new TransferableBlock(rows.subList(fromIndex, 
toIndex), dataSchema, DataBlock.Type.ROW,
+            block.getAggFunctions()));
       }
       return blockChunks.iterator();
     } else {
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 071f9198a2..7d1c7a7e31 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -64,7 +64,6 @@ import org.slf4j.LoggerFactory;
  * When the list of aggregation calls is empty, this class is used to 
calculate distinct result based on group by keys.
  */
 public class AggregateOperator extends MultiStageOperator {
-
   private static final Logger LOGGER = 
LoggerFactory.getLogger(AggregateOperator.class);
   private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
   private static final CountAggregationFunction COUNT_STAR_AGG_FUNCTION =
@@ -72,6 +71,7 @@ public class AggregateOperator extends MultiStageOperator {
 
   private final MultiStageOperator _input;
   private final DataSchema _resultSchema;
+  private final AggregationFunction<?, ?>[] _aggFunctions;
   private final MultistageAggregationExecutor _aggregationExecutor;
   private final MultistageGroupByExecutor _groupByExecutor;
 
@@ -96,10 +96,8 @@ public class AggregateOperator extends MultiStageOperator {
     super(context);
     _input = input;
     _resultSchema = node.getDataSchema();
-
-    // Initialize the aggregation functions
-    AggregationFunction<?, ?>[] aggFunctions = 
getAggFunctions(node.getAggCalls());
-    int numFunctions = aggFunctions.length;
+    _aggFunctions = getAggFunctions(node.getAggCalls());
+    int numFunctions = _aggFunctions.length;
 
     // Process the filter argument indices
     List<Integer> filterArgs = node.getFilterArgs();
@@ -141,11 +139,11 @@ public class AggregateOperator extends MultiStageOperator 
{
     boolean leafReturnFinalResult = node.isLeafReturnFinalResult();
     if (groupKeys.isEmpty()) {
       _aggregationExecutor =
-          new MultistageAggregationExecutor(aggFunctions, filterArgIds, 
maxFilterArgId, aggType, _resultSchema);
+          new MultistageAggregationExecutor(_aggFunctions, filterArgIds, 
maxFilterArgId, aggType, _resultSchema);
       _groupByExecutor = null;
     } else {
       _groupByExecutor =
-          new MultistageGroupByExecutor(getGroupKeyIds(groupKeys), 
aggFunctions, filterArgIds, maxFilterArgId, aggType,
+          new MultistageGroupByExecutor(getGroupKeyIds(groupKeys), 
_aggFunctions, filterArgIds, maxFilterArgId, aggType,
               leafReturnFinalResult, _resultSchema, 
context.getOpChainMetadata(), node.getNodeHint());
       _aggregationExecutor = null;
     }
@@ -216,7 +214,7 @@ public class AggregateOperator extends MultiStageOperator {
   private TransferableBlock produceAggregatedBlock() {
     _hasConstructedAggregateBlock = true;
     if (_aggregationExecutor != null) {
-      return new TransferableBlock(_aggregationExecutor.getResult(), 
_resultSchema, DataBlock.Type.ROW);
+      return new TransferableBlock(_aggregationExecutor.getResult(), 
_resultSchema, DataBlock.Type.ROW, _aggFunctions);
     } else {
       List<Object[]> rows;
       if (_comparator != null) {
@@ -228,7 +226,7 @@ public class AggregateOperator extends MultiStageOperator {
       if (rows.isEmpty()) {
         return _eosBlock;
       } else {
-        TransferableBlock dataBlock = new TransferableBlock(rows, 
_resultSchema, DataBlock.Type.ROW);
+        TransferableBlock dataBlock = new TransferableBlock(rows, 
_resultSchema, DataBlock.Type.ROW, _aggFunctions);
         if (_groupByExecutor.isNumGroupsLimitReached()) {
           if (_errorOnNumGroupsLimit) {
             _input.earlyTerminate();
@@ -356,6 +354,7 @@ public class AggregateOperator extends MultiStageOperator {
       return Collections.emptyMap();
     }
     DataSchema dataSchema = block.getDataSchema();
+    assert dataSchema != null;
     Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
     if (block.isContainerConstructed()) {
       List<Object[]> rows = block.getContainer();
@@ -388,6 +387,7 @@ public class AggregateOperator extends MultiStageOperator {
       return Collections.emptyMap();
     }
     DataSchema dataSchema = block.getDataSchema();
+    assert dataSchema != null;
     Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
     if (block.isContainerConstructed()) {
       List<Object[]> rows = block.getContainer();
@@ -415,8 +415,8 @@ public class AggregateOperator extends MultiStageOperator {
     return blockValSetMap;
   }
 
-  static Object[] getIntermediateResults(AggregationFunction<?, ?> 
aggFunctions, TransferableBlock block) {
-    ExpressionContext firstArgument = 
aggFunctions.getInputExpressions().get(0);
+  static Object[] getIntermediateResults(AggregationFunction<?, ?> 
aggFunction, TransferableBlock block) {
+    ExpressionContext firstArgument = aggFunction.getInputExpressions().get(0);
     Preconditions.checkState(firstArgument.getType() == 
ExpressionContext.Type.IDENTIFIER,
         "Expected the first argument to be IDENTIFIER, got: %s", 
firstArgument.getType());
     int colId = fromIdentifierToColId(firstArgument.getIdentifier());
@@ -429,7 +429,7 @@ public class AggregateOperator extends MultiStageOperator {
       }
       return values;
     } else {
-      return DataBlockExtractUtils.extractColumn(block.getDataBlock(), colId);
+      return DataBlockExtractUtils.extractAggResult(block.getDataBlock(), 
colId, aggFunction);
     }
   }
 
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
index e6fe338fb7..ef26ef6c57 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
@@ -182,7 +182,7 @@ public class LeafStageTransferableBlockOperator extends 
MultiStageOperator {
       return constructMetadataBlock();
     } else {
       // Regular data block
-      return composeTransferableBlock(resultsBlock, _dataSchema);
+      return composeTransferableBlock(resultsBlock);
     }
   }
 
@@ -505,25 +505,23 @@ public class LeafStageTransferableBlockOperator extends 
MultiStageOperator {
    * Composes the {@link TransferableBlock} from the {@link BaseResultsBlock} 
returned from single-stage engine. It
    * converts the data types of the results to conform with the desired data 
schema asked by the multi-stage engine.
    */
-  private static TransferableBlock composeTransferableBlock(BaseResultsBlock 
resultsBlock,
-      DataSchema desiredDataSchema) {
+  private TransferableBlock composeTransferableBlock(BaseResultsBlock 
resultsBlock) {
     if (resultsBlock instanceof SelectionResultsBlock) {
-      return composeSelectTransferableBlock((SelectionResultsBlock) 
resultsBlock, desiredDataSchema);
+      return composeSelectTransferableBlock((SelectionResultsBlock) 
resultsBlock);
     } else {
-      return composeDirectTransferableBlock(resultsBlock, desiredDataSchema);
+      return composeDirectTransferableBlock(resultsBlock);
     }
   }
 
   /**
    * For selection, we need to check if the columns are in order. If not, we 
need to re-arrange the columns.
    */
-  private static TransferableBlock 
composeSelectTransferableBlock(SelectionResultsBlock resultsBlock,
-      DataSchema desiredDataSchema) {
+  private TransferableBlock 
composeSelectTransferableBlock(SelectionResultsBlock resultsBlock) {
     int[] columnIndices = getColumnIndices(resultsBlock);
     if (!inOrder(columnIndices)) {
-      return composeColumnIndexedTransferableBlock(resultsBlock, 
desiredDataSchema, columnIndices);
+      return composeColumnIndexedTransferableBlock(resultsBlock, 
columnIndices);
     } else {
-      return composeDirectTransferableBlock(resultsBlock, desiredDataSchema);
+      return composeDirectTransferableBlock(resultsBlock);
     }
   }
 
@@ -555,13 +553,12 @@ public class LeafStageTransferableBlockOperator extends 
MultiStageOperator {
     return true;
   }
 
-  private static TransferableBlock 
composeColumnIndexedTransferableBlock(BaseResultsBlock block,
-      DataSchema outputDataSchema, int[] columnIndices) {
+  private TransferableBlock 
composeColumnIndexedTransferableBlock(SelectionResultsBlock block, int[] 
columnIndices) {
     List<Object[]> resultRows = block.getRows();
     DataSchema inputDataSchema = block.getDataSchema();
     assert resultRows != null && inputDataSchema != null;
     ColumnDataType[] inputStoredTypes = 
inputDataSchema.getStoredColumnDataTypes();
-    ColumnDataType[] outputStoredTypes = 
outputDataSchema.getStoredColumnDataTypes();
+    ColumnDataType[] outputStoredTypes = 
_dataSchema.getStoredColumnDataTypes();
     List<Object[]> convertedRows = new ArrayList<>(resultRows.size());
     boolean needConvert = false;
     int numColumns = columnIndices.length;
@@ -580,7 +577,7 @@ public class LeafStageTransferableBlockOperator extends 
MultiStageOperator {
         convertedRows.add(reorderRow(row, columnIndices));
       }
     }
-    return new TransferableBlock(convertedRows, outputDataSchema, 
DataBlock.Type.ROW);
+    return new TransferableBlock(convertedRows, _dataSchema, 
DataBlock.Type.ROW);
   }
 
   private static Object[] reorderAndConvertRow(Object[] row, ColumnDataType[] 
inputStoredTypes,
@@ -610,18 +607,19 @@ public class LeafStageTransferableBlockOperator extends 
MultiStageOperator {
     return resultRow;
   }
 
-  private static TransferableBlock 
composeDirectTransferableBlock(BaseResultsBlock block, DataSchema 
outputDataSchema) {
+  private TransferableBlock composeDirectTransferableBlock(BaseResultsBlock 
block) {
     List<Object[]> resultRows = block.getRows();
     DataSchema inputDataSchema = block.getDataSchema();
     assert resultRows != null && inputDataSchema != null;
     ColumnDataType[] inputStoredTypes = 
inputDataSchema.getStoredColumnDataTypes();
-    ColumnDataType[] outputStoredTypes = 
outputDataSchema.getStoredColumnDataTypes();
+    ColumnDataType[] outputStoredTypes = 
_dataSchema.getStoredColumnDataTypes();
     if (!Arrays.equals(inputStoredTypes, outputStoredTypes)) {
       for (Object[] row : resultRows) {
         convertRow(row, inputStoredTypes, outputStoredTypes);
       }
     }
-    return new TransferableBlock(resultRows, outputDataSchema, 
DataBlock.Type.ROW);
+    return new TransferableBlock(resultRows, _dataSchema, DataBlock.Type.ROW,
+        _requests.get(0).getQueryContext().getAggregationFunctions());
   }
 
   public static void convertRow(Object[] row, ColumnDataType[] 
inputStoredTypes, ColumnDataType[] outputStoredTypes) {
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index 3253904095..99d0d8310a 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -449,7 +449,7 @@ public class MultistageGroupByExecutor {
   private int[] generateGroupByKeys(DataBlock dataBlock) {
     Object[] keys;
     if (_groupKeyIds.length == 1) {
-      keys = DataBlockExtractUtils.extractColumn(dataBlock, _groupKeyIds[0]);
+      keys = DataBlockExtractUtils.extractKey(dataBlock, _groupKeyIds[0]);
     } else {
       keys = DataBlockExtractUtils.extractKeys(dataBlock, _groupKeyIds);
     }
@@ -496,7 +496,7 @@ public class MultistageGroupByExecutor {
   private int[] generateGroupByKeys(DataBlock dataBlock, int numMatchedRows, 
RoaringBitmap matchedBitmap) {
     Object[] keys;
     if (_groupKeyIds.length == 1) {
-      keys = DataBlockExtractUtils.extractColumn(dataBlock, _groupKeyIds[0], 
numMatchedRows, matchedBitmap);
+      keys = DataBlockExtractUtils.extractKey(dataBlock, _groupKeyIds[0], 
numMatchedRows, matchedBitmap);
     } else {
       keys = DataBlockExtractUtils.extractKeys(dataBlock, _groupKeyIds, 
numMatchedRows, matchedBitmap);
     }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/HashExchange.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/HashExchange.java
index 44b1b66815..31b60d2ef4 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/HashExchange.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/exchange/HashExchange.java
@@ -24,6 +24,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.TimeoutException;
 import java.util.function.Function;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.query.mailbox.SendingMailbox;
 import org.apache.pinot.query.planner.partitioning.EmptyKeySelector;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
@@ -50,6 +51,7 @@ class HashExchange extends BlockExchange {
     this(sendingMailboxes, keySelector, splitter, RANDOM_INDEX_CHOOSER);
   }
 
+  @SuppressWarnings({"rawtypes", "unchecked"})
   @Override
   protected void route(List<SendingMailbox> destinations, TransferableBlock 
block)
       throws IOException, TimeoutException {
@@ -59,7 +61,6 @@ class HashExchange extends BlockExchange {
       return;
     }
 
-    //noinspection unchecked
     List<Object[]>[] mailboxIdToRowsMap = new List[numMailboxes];
     for (int i = 0; i < numMailboxes; i++) {
       mailboxIdToRowsMap[i] = new ArrayList<>();
@@ -69,10 +70,11 @@ class HashExchange extends BlockExchange {
       int mailboxId = _keySelector.computeHash(row) % numMailboxes;
       mailboxIdToRowsMap[mailboxId].add(row);
     }
+    AggregationFunction[] aggFunctions = block.getAggFunctions();
     for (int i = 0; i < numMailboxes; i++) {
       if (!mailboxIdToRowsMap[i].isEmpty()) {
         sendBlock(destinations.get(i),
-            new TransferableBlock(mailboxIdToRowsMap[i], 
block.getDataSchema(), block.getType()));
+            new TransferableBlock(mailboxIdToRowsMap[i], 
block.getDataSchema(), block.getType(), aggFunctions));
       }
     }
   }
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperatorTest.java
index 248ab00ce6..830f3bef93 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperatorTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperatorTest.java
@@ -93,9 +93,11 @@ public class LeafStageTransferableBlockOperatorTest {
   }
 
   private List<ServerQueryRequest> mockQueryRequests(int numRequests) {
+    ServerQueryRequest queryRequest = mock(ServerQueryRequest.class);
+    when(queryRequest.getQueryContext()).thenReturn(mock(QueryContext.class));
     List<ServerQueryRequest> queryRequests = new ArrayList<>(numRequests);
     for (int i = 0; i < numRequests; i++) {
-      queryRequests.add(mock(ServerQueryRequest.class));
+      queryRequests.add(queryRequest);
     }
     return queryRequests;
   }
diff --git 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 4f0f0c48f9..10bdfb463b 100644
--- 
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ 
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -74,8 +74,7 @@ public enum AggregationFunctionType {
   DISTINCTAVG("distinctAvg", ReturnTypes.DOUBLE, OperandTypes.NUMERIC, 
SqlTypeName.OTHER),
   DISTINCTCOUNTBITMAP("distinctCountBitmap", ReturnTypes.BIGINT, 
OperandTypes.ANY, SqlTypeName.OTHER,
       SqlTypeName.INTEGER),
-  SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount", 
ReturnTypes.BIGINT, OperandTypes.ANY,
-      SqlTypeName.OTHER),
+  SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount", 
ReturnTypes.BIGINT, OperandTypes.ANY),
   DISTINCTCOUNTHLL("distinctCountHLL", ReturnTypes.BIGINT,
       OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i 
-> i == 1), SqlTypeName.OTHER),
   DISTINCTCOUNTRAWHLL("distinctCountRawHLL", ReturnTypes.VARCHAR,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to