walterddr commented on code in PR #10845:
URL: https://github.com/apache/pinot/pull/10845#discussion_r1248016907


##########
pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java:
##########
@@ -268,4 +269,381 @@ private static Object[] extractRowFromDataBlock(DataBlock dataBlock, int rowId,
     }
     return row;
   }
+
+  /**
+   * Given a datablock and the column index, extracts the integer values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   *
+   * @return int array of values in the column
+   */
+  public static int[] extractIntRowsForColumn(DataBlock dataBlock, int columnIndex) {

Review Comment:
   1. The name is confusing; suggest:
   ```suggestion
     public static int[] extractIntValuesForColumn(DataBlock dataBlock, int columnIndex) {
   ```
   2. (stretch/follow-up) dataBlock can be either ROW-based or COLUMNAR-based, so a check on the dataBlock type will be needed later.
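   For the follow-up in (2), a minimal sketch of such a type guard, assuming the `DataBlock` interface exposes its type through a `getDataBlockType()` accessor (the accessor name and exception message here are illustrative):
   ```java
   // Hypothetical guard at the top of extractIntValuesForColumn: fail fast on
   // COLUMNAR blocks until a columnar fast path is implemented.
   if (dataBlock.getDataBlockType() != DataBlock.Type.ROW) {
     throw new IllegalStateException(
         "extractIntValuesForColumn only supports ROW data blocks, got: " + dataBlock.getDataBlockType());
   }
   ```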



##########
pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java:
##########
@@ -268,4 +269,381 @@ private static Object[] extractRowFromDataBlock(DataBlock dataBlock, int rowId,
     }
     return row;
   }
+
+  /**
+   * Given a datablock and the column index, extracts the integer values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   *
+   * @return int array of values in the column
+   */
+  public static int[] extractIntRowsForColumn(DataBlock dataBlock, int columnIndex) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    int[] rows = new int[numRows];
+    switch (columnDataTypes[columnIndex]) {
+      case INT:
+      case BOOLEAN:
+        for (int rowId = 0; rowId < numRows; rowId++) {
+          if (nullBitmap != null && nullBitmap.contains(rowId)) {
+            continue;
+          }

Review Comment:
   nit: consider switching from the per-row null-bitmap check to a single top-level branch on the null bitmap (one loop for the null-bitmap case, one for the case without nulls), so the check is not repeated on every row.
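   A minimal sketch of the hoisted version for the `INT`/`BOOLEAN` branch; the per-cell read is not visible in this hunk, so `DataBlock#getInt(rowId, colId)` is an assumption here. Null rows keep the `int[]` default of 0, matching the `continue` above:
   ```java
   // Branch on the null bitmap once, then run a tight per-row loop in each arm,
   // instead of re-checking nullBitmap on every iteration.
   if (nullBitmap == null) {
     for (int rowId = 0; rowId < numRows; rowId++) {
       rows[rowId] = dataBlock.getInt(rowId, columnIndex);  // assumed accessor
     }
   } else {
     for (int rowId = 0; rowId < numRows; rowId++) {
       if (!nullBitmap.contains(rowId)) {
         rows[rowId] = dataBlock.getInt(rowId, columnIndex);
       }
     }
   }
   ```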



##########
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java:
##########
@@ -0,0 +1,270 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.IntermediateStageBlockValSet;
+import org.apache.pinot.core.data.table.Key;
+import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+
+
+/**
+ * Class that executes the group by aggregations for the multistage AggregateOperator.
+ */
+public class MultistageGroupByExecutor {
+  private final NewAggregateOperator.Mode _mode;
+  // The identifier operands for the aggregation function only store the column name. This map contains the mapping
+  // from column name to its index, which is used in the v2 engine.
+  private final Map<String, Integer> _colNameToIndexMap;
+
+  private final List<ExpressionContext> _groupSet;
+  private final AggregationFunction[] _aggFunctions;
+
+  // Group By Result holders for each mode
+  private final GroupByResultHolder[] _aggregateResultHolders;
+  private final Map<Integer, Object[]> _mergeResultHolder;
+  private final List<Object[]> _finalResultHolder;
+
+  // Mapping from the row-key to a zero-based integer index. This is used when we invoke the v1 aggregation functions
+  // because they use the zero-based integer indexes to store results.
+  private int _groupId = 0;
+  private Map<Key, Integer> _groupKeyToIdMap;
+
+  // Mapping from the group by row-key to the values in the row which form the key. Used to fetch the actual row
+  // values when populating the result.
+  private final Map<Key, Object[]> _groupByKeyHolder;
+
+  public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
+      NewAggregateOperator.Mode mode, Map<String, Integer> colNameToIndexMap) {
+    _mode = mode;
+    _colNameToIndexMap = colNameToIndexMap;
+    _groupSet = groupByExpr;
+    _aggFunctions = aggFunctions;
+
+    _aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length];
+    _mergeResultHolder = new HashMap<>();
+    _finalResultHolder = new ArrayList<>();
+
+    _groupKeyToIdMap = new HashMap<>();
+    _groupByKeyHolder = new HashMap<>();
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      _aggregateResultHolders[i] =
+          _aggFunctions[i].createGroupByResultHolder(InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY,
+              InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT);
+    }
+  }
+
+  /**
+   * Performs group-by aggregation for the data in the block.
+   */
+  public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
+    if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+      processAggregate(block, inputDataSchema);
+    } else if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+      processMerge(block);
+    } else if (_mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT)) {
+      collectResult(block);
+    }
+  }
+
+  /**
+   * Fetches the result.
+   */
+  public List<Object[]> getResult() {
+    List<Object[]> rows = new ArrayList<>();
+
+    if (_mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT)) {
+      return extractFinalGroupByResult();
+    }
+
+    // If the mode is MERGE or AGGREGATE, the groupby keys are already collected in _groupByKeyHolder by virtue of
+    // generating the row keys.
+    for (Map.Entry<Key, Object[]> e : _groupByKeyHolder.entrySet()) {
+      int numCols = _groupSet.size() + _aggFunctions.length;
+      Object[] row = new Object[numCols];
+      Object[] keyElements = e.getValue();
+      System.arraycopy(keyElements, 0, row, 0, keyElements.length);
+
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        int index = i + _groupSet.size();
+        int groupId = _groupKeyToIdMap.get(e.getKey());
+        if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+          row[index] = _mergeResultHolder.get(groupId)[i];
+        } else if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+          row[index] = _aggFunctions[i].extractGroupByResult(_aggregateResultHolders[i], groupId);
+        }
+      }
+
+      rows.add(row);
+    }
+
+    return rows;
+  }
+
+  private List<Object[]> extractFinalGroupByResult() {
+    List<Object[]> rows = new ArrayList<>();
+    for (Object[] resultRow : _finalResultHolder) {
+      int numCols = _groupSet.size() + _aggFunctions.length;
+      Object[] row = new Object[numCols];
+      System.arraycopy(resultRow, 0, row, 0, _groupSet.size());
+
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        int aggIdx = i + _groupSet.size();
+        Comparable result = _aggFunctions[i].extractFinalResult(resultRow[aggIdx]);
+        row[aggIdx] = result == null ? null : _aggFunctions[i].getFinalResultColumnType().convert(result);
+      }
+
+      rows.add(row);
+    }
+    return rows;
+  }
+
+  private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
+    int[] intKeys = generateGroupByKeys(block.getContainer());
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggregationFunction = _aggFunctions[i];
+      Map<ExpressionContext, BlockValSet> blockValSetMap =
+          getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+      GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+      groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
+      aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+    }
+  }
+
+  private void processMerge(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+    int[] intKeys = generateGroupByKeys(container);
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      List<ExpressionContext> expressions = _aggFunctions[i].getInputExpressions();
+
+      for (int j = 0; j < container.size(); j++) {
+        Object[] row = container.get(j);
+        int rowKey = intKeys[j];
+        if (!_mergeResultHolder.containsKey(rowKey)) {
+          _mergeResultHolder.put(rowKey, new Object[_aggFunctions.length]);
+        }
+        Object intermediateResultToMerge = extractValueFromRow(row, expressions);
+        Object mergedIntermediateResult = _mergeResultHolder.get(rowKey)[i];
+
+        // Not all V1 aggregation functions have null-handling. So handle null values and call merge only if necessary.
+        if (intermediateResultToMerge == null) {
+          continue;
+        }
+        if (mergedIntermediateResult == null) {
+          _mergeResultHolder.get(rowKey)[i] = intermediateResultToMerge;
+          continue;
+        }
+
+        _mergeResultHolder.get(rowKey)[i] = _aggFunctions[i].merge(intermediateResultToMerge, mergedIntermediateResult);
+      }
+    }
+  }
+
+  private void collectResult(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+    for (Object[] row : container) {
+      assert row.length == _groupSet.size() + _aggFunctions.length;
+      _finalResultHolder.add(row);
+    }
+  }
+
+  /**
+   * Creates the group by key for each row. Converts the key into a 0-index based int value that can be used by
+   * GroupByAggregationResultHolders used in v1 aggregations.
+   * <p>
+   * Returns the int key for each row.
+   */
+  private int[] generateGroupByKeys(List<Object[]> rows) {
+    int[] rowKeys = new int[rows.size()];
+    int numGroups = _groupSet.size();
+
+    for (int i = 0; i < rows.size(); i++) {
+      Object[] row = rows.get(i);
+
+      Object[] keyElements = new Object[numGroups];
+      for (int j = 0; j < numGroups; j++) {
+        String colName = _groupSet.get(j).getIdentifier();
+        int colIndex = _colNameToIndexMap.get(colName);
+        keyElements[j] = row[colIndex];
+      }
+
+      Key rowKey = new Key(keyElements);
+      _groupByKeyHolder.put(rowKey, rowKey.getValues());
+      if (!_groupKeyToIdMap.containsKey(rowKey)) {
+        _groupKeyToIdMap.put(rowKey, _groupId);
+        ++_groupId;
+      }
+      rowKeys[i] = _groupKeyToIdMap.get(rowKey);
+    }
+
+    return rowKeys;
+  }
+
+  private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
+      TransferableBlock block, DataSchema inputDataSchema) {
+    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
+    int numExpressions = expressions.size();
+    if (numExpressions == 0) {
+      return Collections.emptyMap();
+    }
+
+    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
+    ExpressionContext expression = expressions.get(0);
+    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
+    int index = _colNameToIndexMap.get(expression.getIdentifier());
+
+    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
+    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
+    return Collections.singletonMap(expression,
+        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));

Review Comment:
   Same here: consider which format we are going to use, the serialized one or the unserialized one.
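   For reference, the two formats in play, using only the `TransferableBlock` calls that already appear in this PR:
   ```java
   // Unserialized format: the raw row container, cheap to iterate row by row.
   List<Object[]> rows = block.getContainer();

   // Serialized format: the wire-format data block; materializing it from a
   // container-backed TransferableBlock costs a full serialization pass.
   DataBlock dataBlock = block.getDataBlock();
   ```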



##########
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java:
##########
@@ -0,0 +1,192 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.IntermediateStageBlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+
+/**
+ * Class that executes all aggregation functions (without group-bys) for the multistage AggregateOperator.
+ */
+public class MultistageAggregationExecutor {
+  private final NewAggregateOperator.Mode _mode;
+  // The identifier operands for the aggregation function only store the column name. This map contains the mapping
+  // from column name to its index.
+  private final Map<String, Integer> _colNameToIndexMap;
+
+  private final AggregationFunction[] _aggFunctions;
+
+  // Result holders for each mode.
+  private final AggregationResultHolder[] _aggregateResultHolder;
+  private final Object[] _mergeResultHolder;
+  private final Object[] _finalResultHolder;
+
+  public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, NewAggregateOperator.Mode mode,
+      Map<String, Integer> colNameToIndexMap) {
+    _aggFunctions = aggFunctions;
+    _mode = mode;
+    _colNameToIndexMap = colNameToIndexMap;
+
+    _aggregateResultHolder = new AggregationResultHolder[aggFunctions.length];
+    _mergeResultHolder = new Object[aggFunctions.length];
+    _finalResultHolder = new Object[aggFunctions.length];
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      _aggregateResultHolder[i] = _aggFunctions[i].createAggregationResultHolder();
+    }
+  }
+
+  /**
+   * Performs aggregation for the data in the block.
+   */
+  public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
+    if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+      processAggregate(block, inputDataSchema);
+    } else if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+      processMerge(block);
+    } else if (_mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT)) {
+      collectResult(block);
+    }
+  }
+
+  /**
+   * Fetches the result.
+   */
+  public List<Object[]> getResult() {
+    List<Object[]> rows = new ArrayList<>();
+    Object[] row = new Object[_aggFunctions.length];
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+        row[i] = _mergeResultHolder[i];
+      } else if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+        row[i] = aggFunction.extractAggregationResult(_aggregateResultHolder[i]);
+      } else {
+        assert _mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT);
+        Comparable result = aggFunction.extractFinalResult(_finalResultHolder[i]);
+        row[i] = result == null ? null : aggFunction.getFinalResultColumnType().convert(result);
+      }
+    }
+    rows.add(row);
+    return rows;
+  }
+
+  /**
+   * @return an empty agg result block for non-group-by aggregation.
+   */
+  public Object[] constructEmptyAggResultRow() {
+    Object[] row = new Object[_aggFunctions.length];
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      row[i] = aggFunction.extractAggregationResult(aggFunction.createAggregationResultHolder());
+    }
+    return row;
+  }
+
+  private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggregationFunction = _aggFunctions[i];
+      Map<ExpressionContext, BlockValSet> blockValSetMap =
+          getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+      aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
+    }
+  }
+
+  private void processMerge(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      List<ExpressionContext> expressions = _aggFunctions[i].getInputExpressions();
+      for (Object[] row : container) {
+        Object intermediateResultToMerge = extractValueFromRow(row, expressions);
+        Object mergedIntermediateResult = _mergeResultHolder[i];
+
+        // Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge.
+        if (intermediateResultToMerge == null) {
+          continue;
+        }
+        if (mergedIntermediateResult == null) {
+          _mergeResultHolder[i] = intermediateResultToMerge;
+          continue;
+        }
+
+        _mergeResultHolder[i] = _aggFunctions[i].merge(intermediateResultToMerge, mergedIntermediateResult);
+      }
+    }
+  }
+
+  private void collectResult(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+    assert container.size() == 1;
+    Object[] row = container.get(0);
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      List<ExpressionContext> expressions = _aggFunctions[i].getInputExpressions();
+      _finalResultHolder[i] = extractValueFromRow(row, expressions);
+    }
+  }
+
+  private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
+      TransferableBlock block, DataSchema inputDataSchema) {
+    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
+    int numExpressions = expressions.size();
+    if (numExpressions == 0) {
+      return Collections.emptyMap();
+    }
+
+    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
+    ExpressionContext expression = expressions.get(0);
+    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
+    int index = _colNameToIndexMap.get(expression.getIdentifier());
+
+    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
+    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
+    return Collections.singletonMap(expression,
+        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));

Review Comment:
   `block.getDataBlock()` is lazy. For example, if the previous block was not received over a mailbox, `getDataBlock()` will first convert the unserialized `List<Object[]>` format into the serialized `BaseDataBlock` format, and then convert that back into the column-value primitive arrays. This is very inefficient.
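   A minimal sketch of one way around the round-trip, assuming `TransferableBlock` can report whether it is backed by an unserialized container; both the `isContainerBlock()` check and the `RowBasedBlockValSet` wrapper are hypothetical, not part of this PR:
   ```java
   // Sketch only: read straight from the unserialized List<Object[]> when the
   // block was produced locally, and take the lazy getDataBlock() path only for
   // blocks that actually arrived serialized over a mailbox.
   BlockValSet valSet;
   if (block.isContainerBlock()) {  // hypothetical accessor
     // Hypothetical BlockValSet implementation over raw rows.
     valSet = new RowBasedBlockValSet(dataType, block.getContainer(), index);
   } else {
     valSet = new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index);
   }
   return Collections.singletonMap(expression, valSet);
   ```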


