walterddr commented on code in PR #10845:
URL: https://github.com/apache/pinot/pull/10845#discussion_r1248016907
##########
pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java:
##########
@@ -268,4 +269,381 @@ private static Object[] extractRowFromDataBlock(DataBlock dataBlock, int rowId,
     }
     return row;
   }
+
+  /**
+   * Given a datablock and the column index, extracts the integer values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   *
+   * @return int array of values in the column
+   */
+  public static int[] extractIntRowsForColumn(DataBlock dataBlock, int columnIndex) {

Review Comment:
   1. The name is confusing; suggest:
   ```suggestion
   public static int[] extractIntValuesForColumn(DataBlock dataBlock, int columnIndex) {
   ```
   2. (stretch/follow-up) `dataBlock` can be either ROW-based or COLUMNAR-based, so a check on the dataBlock type will be needed later.
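To make point 2 concrete, here is a minimal sketch of such a type guard. It assumes the `DataBlock` interface exposes its `Type` enum through a `getDataBlockType()` accessor (an assumed name, not confirmed by this thread), and that a COLUMNAR-specific extraction path would be added later:

```java
// Sketch only: getDataBlockType() is an assumed accessor name for the block's Type enum.
public static int[] extractIntValuesForColumn(DataBlock dataBlock, int columnIndex) {
  // Fail fast until a COLUMNAR-specific extraction path exists; the extraction
  // logic in this method assumes a ROW-based layout.
  Preconditions.checkState(dataBlock.getDataBlockType() == DataBlock.Type.ROW,
      "Unsupported data block type: %s", dataBlock.getDataBlockType());
  // ... row-based extraction as in the snippet above ...
}
```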
##########
pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java:
##########
@@ -268,4 +269,381 @@ private static Object[] extractRowFromDataBlock(DataBlock dataBlock, int rowId,
     }
     return row;
   }
+
+  /**
+   * Given a datablock and the column index, extracts the integer values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   *
+   * @return int array of values in the column
+   */
+  public static int[] extractIntRowsForColumn(DataBlock dataBlock, int columnIndex) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    int[] rows = new int[numRows];
+    switch (columnDataTypes[columnIndex]) {
+      case INT:
+      case BOOLEAN:
+        for (int rowId = 0; rowId < numRows; rowId++) {
+          if (nullBitmap != null && nullBitmap.contains(rowId)) {
+            continue;
+          }

Review Comment:
   nit: consider switching from a per-row nullbitmap check to a global switch-case (one possible shape is sketched below).
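A sketch of the hoisted version, branching once per column instead of once per row so the common no-nulls case pays no per-row check. The names mirror the snippet above; `dataBlock.getInt(rowId, columnIndex)` is assumed as the per-cell reader, since the loop body after the null check is not shown in the hunk:

```java
case INT:
case BOOLEAN:
  if (nullBitmap == null) {
    // Common case: no nulls in the column, so copy without any per-row branch.
    for (int rowId = 0; rowId < numRows; rowId++) {
      rows[rowId] = dataBlock.getInt(rowId, columnIndex);
    }
  } else {
    // Null-aware path: skip rows flagged in the bitmap, leaving the default 0.
    for (int rowId = 0; rowId < numRows; rowId++) {
      if (!nullBitmap.contains(rowId)) {
        rows[rowId] = dataBlock.getInt(rowId, columnIndex);
      }
    }
  }
  break;
```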
##########
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java:
##########
@@ -0,0 +1,270 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.IntermediateStageBlockValSet;
+import org.apache.pinot.core.data.table.Key;
+import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+
+
+/**
+ * Class that executes the group by aggregations for the multistage AggregateOperator.
+ */
+public class MultistageGroupByExecutor {
+  private final NewAggregateOperator.Mode _mode;
+  // The identifier operands for the aggregation function only store the column name. This map contains mapping
+  // between column name to their index which is used in v2 engine.
+  private final Map<String, Integer> _colNameToIndexMap;
+
+  private final List<ExpressionContext> _groupSet;
+  private final AggregationFunction[] _aggFunctions;
+
+  // Group By Result holders for each mode
+  private final GroupByResultHolder[] _aggregateResultHolders;
+  private final Map<Integer, Object[]> _mergeResultHolder;
+  private final List<Object[]> _finalResultHolder;
+
+  // Mapping from the row-key to a zero based integer index. This is used when we invoke the v1 aggregation functions
+  // because they use the zero based integer indexes to store results.
+  private int _groupId = 0;
+  private Map<Key, Integer> _groupKeyToIdMap;
+
+  // Mapping from the group by row-key to the values in the row which form the key. Used to fetch the actual row
+  // values when populating the result.
+  private final Map<Key, Object[]> _groupByKeyHolder;
+
+  public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
+      NewAggregateOperator.Mode mode, Map<String, Integer> colNameToIndexMap) {
+    _mode = mode;
+    _colNameToIndexMap = colNameToIndexMap;
+    _groupSet = groupByExpr;
+    _aggFunctions = aggFunctions;
+
+    _aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length];
+    _mergeResultHolder = new HashMap<>();
+    _finalResultHolder = new ArrayList<>();
+
+    _groupKeyToIdMap = new HashMap<>();
+    _groupByKeyHolder = new HashMap<>();
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      _aggregateResultHolders[i] =
+          _aggFunctions[i].createGroupByResultHolder(InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY,
+              InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT);
+    }
+  }
+
+  /**
+   * Performs group-by aggregation for the data in the block.
+   */
+  public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
+    if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+      processAggregate(block, inputDataSchema);
+    } else if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+      processMerge(block);
+    } else if (_mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT)) {
+      collectResult(block);
+    }
+  }
+
+  /**
+   * Fetches the result.
+   */
+  public List<Object[]> getResult() {
+    List<Object[]> rows = new ArrayList<>();
+
+    if (_mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT)) {
+      return extractFinalGroupByResult();
+    }
+
+    // If the mode is MERGE or AGGREGATE, the groupby keys are already collected in _groupByKeyHolder by virtue of
+    // generating the row keys.
+    for (Map.Entry<Key, Object[]> e : _groupByKeyHolder.entrySet()) {
+      int numCols = _groupSet.size() + _aggFunctions.length;
+      Object[] row = new Object[numCols];
+      Object[] keyElements = e.getValue();
+      System.arraycopy(keyElements, 0, row, 0, keyElements.length);
+
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        int index = i + _groupSet.size();
+        int groupId = _groupKeyToIdMap.get(e.getKey());
+        if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+          row[index] = _mergeResultHolder.get(groupId)[i];
+        } else if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+          row[index] = _aggFunctions[i].extractGroupByResult(_aggregateResultHolders[i], groupId);
+        }
+      }
+
+      rows.add(row);
+    }
+
+    return rows;
+  }
+
+  private List<Object[]> extractFinalGroupByResult() {
+    List<Object[]> rows = new ArrayList<>();
+    for (Object[] resultRow : _finalResultHolder) {
+      int numCols = _groupSet.size() + _aggFunctions.length;
+      Object[] row = new Object[numCols];
+      System.arraycopy(resultRow, 0, row, 0, _groupSet.size());
+
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        int aggIdx = i + _groupSet.size();
+        Comparable result = _aggFunctions[i].extractFinalResult(resultRow[aggIdx]);
+        row[aggIdx] = result == null ? null : _aggFunctions[i].getFinalResultColumnType().convert(result);
+      }
+
+      rows.add(row);
+    }
+    return rows;
+  }
+
+  private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
+    int[] intKeys = generateGroupByKeys(block.getContainer());
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggregationFunction = _aggFunctions[i];
+      Map<ExpressionContext, BlockValSet> blockValSetMap =
+          getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+      GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+      groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
+      aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+    }
+  }
+
+  private void processMerge(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+    int[] intKeys = generateGroupByKeys(container);
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      List<ExpressionContext> expressions = _aggFunctions[i].getInputExpressions();
+
+      for (int j = 0; j < container.size(); j++) {
+        Object[] row = container.get(j);
+        int rowKey = intKeys[j];
+        if (!_mergeResultHolder.containsKey(rowKey)) {
+          _mergeResultHolder.put(rowKey, new Object[_aggFunctions.length]);
+        }
+        Object intermediateResultToMerge = extractValueFromRow(row, expressions);
+        Object mergedIntermediateResult = _mergeResultHolder.get(rowKey)[i];
+
+        // Not all V1 aggregation functions have null-handling. So handle null values and call merge only if necessary.
+        if (intermediateResultToMerge == null) {
+          continue;
+        }
+        if (mergedIntermediateResult == null) {
+          _mergeResultHolder.get(rowKey)[i] = intermediateResultToMerge;
+          continue;
+        }
+
+        _mergeResultHolder.get(rowKey)[i] = _aggFunctions[i].merge(intermediateResultToMerge, mergedIntermediateResult);
+      }
+    }
+  }
+
+  private void collectResult(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+    for (Object[] row : container) {
+      assert row.length == _groupSet.size() + _aggFunctions.length;
+      _finalResultHolder.add(row);
+    }
+  }
+
+  /**
+   * Creates the group by key for each row. Converts the key into a 0-index based int value that can be used by
+   * GroupByAggregationResultHolders used in v1 aggregations.
+   * <p>
+   * Returns the int key for each row.
+   */
+  private int[] generateGroupByKeys(List<Object[]> rows) {
+    int[] rowKeys = new int[rows.size()];
+    int numGroups = _groupSet.size();
+
+    for (int i = 0; i < rows.size(); i++) {
+      Object[] row = rows.get(i);
+
+      Object[] keyElements = new Object[numGroups];
+      for (int j = 0; j < numGroups; j++) {
+        String colName = _groupSet.get(j).getIdentifier();
+        int colIndex = _colNameToIndexMap.get(colName);
+        keyElements[j] = row[colIndex];
+      }
+
+      Key rowKey = new Key(keyElements);
+      _groupByKeyHolder.put(rowKey, rowKey.getValues());
+      if (!_groupKeyToIdMap.containsKey(rowKey)) {
+        _groupKeyToIdMap.put(rowKey, _groupId);
+        ++_groupId;
+      }
+      rowKeys[i] = _groupKeyToIdMap.get(rowKey);
+    }
+
+    return rowKeys;
+  }
+
+  private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
+      TransferableBlock block, DataSchema inputDataSchema) {
+    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
+    int numExpressions = expressions.size();
+    if (numExpressions == 0) {
+      return Collections.emptyMap();
+    }
+
+    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
+    ExpressionContext expression = expressions.get(0);
+    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
+    int index = _colNameToIndexMap.get(expression.getIdentifier());
+
+    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
+    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
+    return Collections.singletonMap(expression,
+        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));

Review Comment:
   Same here. Consider which format we are going to use: the serialized or the unserialized. (See the sketch after the final comment below.)

##########
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java:
##########
@@ -0,0 +1,192 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.IntermediateStageBlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+
+/**
+ * Class that executes all aggregation functions (without group-bys) for the multistage AggregateOperator.
+ */
+public class MultistageAggregationExecutor {
+  private final NewAggregateOperator.Mode _mode;
+  // The identifier operands for the aggregation function only store the column name. This map contains mapping
+  // from column name to their index.
+  private final Map<String, Integer> _colNameToIndexMap;
+
+  private final AggregationFunction[] _aggFunctions;
+
+  // Result holders for each mode.
+  private final AggregationResultHolder[] _aggregateResultHolder;
+  private final Object[] _mergeResultHolder;
+  private final Object[] _finalResultHolder;
+
+  public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, NewAggregateOperator.Mode mode,
+      Map<String, Integer> colNameToIndexMap) {
+    _aggFunctions = aggFunctions;
+    _mode = mode;
+    _colNameToIndexMap = colNameToIndexMap;
+
+    _aggregateResultHolder = new AggregationResultHolder[aggFunctions.length];
+    _mergeResultHolder = new Object[aggFunctions.length];
+    _finalResultHolder = new Object[aggFunctions.length];
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      _aggregateResultHolder[i] = _aggFunctions[i].createAggregationResultHolder();
+    }
+  }
+
+  /**
+   * Performs aggregation for the data in the block.
+   */
+  public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
+    if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+      processAggregate(block, inputDataSchema);
+    } else if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+      processMerge(block);
+    } else if (_mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT)) {
+      collectResult(block);
+    }
+  }
+
+  /**
+   * Fetches the result.
+   */
+  public List<Object[]> getResult() {
+    List<Object[]> rows = new ArrayList<>();
+    Object[] row = new Object[_aggFunctions.length];
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      if (_mode.equals(NewAggregateOperator.Mode.MERGE)) {
+        row[i] = _mergeResultHolder[i];
+      } else if (_mode.equals(NewAggregateOperator.Mode.AGGREGATE)) {
+        row[i] = aggFunction.extractAggregationResult(_aggregateResultHolder[i]);
+      } else {
+        assert _mode.equals(NewAggregateOperator.Mode.EXTRACT_RESULT);
+        Comparable result = aggFunction.extractFinalResult(_finalResultHolder[i]);
+        row[i] = result == null ? null : aggFunction.getFinalResultColumnType().convert(result);
+      }
+    }
+    rows.add(row);
+    return rows;
+  }
+
+  /**
+   * @return an empty agg result block for non-group-by aggregation.
+   */
+  public Object[] constructEmptyAggResultRow() {
+    Object[] row = new Object[_aggFunctions.length];
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      row[i] = aggFunction.extractAggregationResult(aggFunction.createAggregationResultHolder());
+    }
+    return row;
+  }
+
+  private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggregationFunction = _aggFunctions[i];
+      Map<ExpressionContext, BlockValSet> blockValSetMap =
+          getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+      aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
+    }
+  }
+
+  private void processMerge(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      List<ExpressionContext> expressions = _aggFunctions[i].getInputExpressions();
+      for (Object[] row : container) {
+        Object intermediateResultToMerge = extractValueFromRow(row, expressions);
+        Object mergedIntermediateResult = _mergeResultHolder[i];
+
+        // Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge.
+        if (intermediateResultToMerge == null) {
+          continue;
+        }
+        if (mergedIntermediateResult == null) {
+          _mergeResultHolder[i] = intermediateResultToMerge;
+          continue;
+        }
+
+        _mergeResultHolder[i] = _aggFunctions[i].merge(intermediateResultToMerge, mergedIntermediateResult);
+      }
+    }
+  }
+
+  private void collectResult(TransferableBlock block) {
+    List<Object[]> container = block.getContainer();
+    assert container.size() == 1;
+    Object[] row = container.get(0);
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      List<ExpressionContext> expressions = _aggFunctions[i].getInputExpressions();
+      _finalResultHolder[i] = extractValueFromRow(row, expressions);
+    }
+  }
+
+  private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
+      TransferableBlock block, DataSchema inputDataSchema) {
+    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
+    int numExpressions = expressions.size();
+    if (numExpressions == 0) {
+      return Collections.emptyMap();
+    }
+
+    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
+    ExpressionContext expression = expressions.get(0);
+    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
+    int index = _colNameToIndexMap.get(expression.getIdentifier());
+
+    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
+    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
+    return Collections.singletonMap(expression,
+        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));

Review Comment:
   `block.getDataBlock()` is lazy. E.g. if the previous block was not received over a mailbox, getDataBlock() will actually convert the unserialized List<Object[]> format into the serialized BaseDataBlock format, then convert it back to the column-value primitive array. This is very inefficient.
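A hedged sketch of one way to avoid that round trip: read straight from the row container when the block was produced locally, and only fall back to the DataBlock-backed val set when the block actually arrived serialized. `isContainerBlock()` and `RowContainerBlockValSet` are hypothetical names introduced here for illustration, not existing Pinot APIs:

```java
// Hypothetical: isContainerBlock() and RowContainerBlockValSet do not exist in
// Pinot today; they stand in for "the block still holds List<Object[]> rows"
// and "a BlockValSet reading directly from those rows", respectively.
private BlockValSet toBlockValSet(TransferableBlock block, DataSchema.ColumnDataType dataType, int index) {
  if (block.isContainerBlock()) {
    // Locally produced block: wrap the unserialized rows directly, skipping the
    // List<Object[]> -> BaseDataBlock -> primitive-array round trip.
    return new RowContainerBlockValSet(dataType, block.getContainer(), index);
  }
  // Mailbox-received block: the serialized DataBlock already exists, so reading
  // from it adds no extra serialization cost.
  return new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index);
}
```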