This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
     new 05989559e0  [multistage][agg] support agg with filter (#11144)
05989559e0 is described below

commit 05989559e06d448d23d55147672345ede624381d
Author: Rong Rong <ro...@apache.org>
AuthorDate: Fri Jul 28 13:35:59 2023 -0700

    [multistage][agg] support agg with filter (#11144)

    * [init][agg] support agg filter where clause

    * limitation still applies for nullable vs non-nullable results and agg
      filter merging with select filter; these will be addressed in follow-ups.

    ---------

    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../pinot/common/datablock/DataBlockUtils.java     | 311 +++++++++++++++++++++
 .../PinotAggregateExchangeNodeInsertRule.java      |   2 +-
 .../query/parser/CalciteRexExpressionParser.java   |  30 +-
 .../query/planner/plannode/AggregateNode.java      |   7 +
 .../query/runtime/operator/AggregateOperator.java  |  75 +++--
 .../operator/MultistageAggregationExecutor.java    |  27 +-
 .../operator/MultistageGroupByExecutor.java        |  83 +++++-
 .../runtime/operator/block/DataBlockValSet.java    |  15 +-
 .../operator/block/FilteredDataBlockValSet.java    |  33 +--
 .../query/runtime/plan/PhysicalPlanVisitor.java    |   2 +-
 .../plan/server/ServerPlanRequestVisitor.java      |   6 +-
 .../runtime/operator/AggregateOperatorTest.java    |  16 +-
 .../test/resources/queries/FilterAggregates.json   | 166 +++++++++++
 13 files changed, 679 insertions(+), 94 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
index 99c3e4df46..b5d8321280 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
@@ -23,6 +23,7 @@ import java.math.BigDecimal;
 import java.nio.ByteBuffer;
 import java.sql.Timestamp;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -560,4 +561,314 @@ public final class DataBlockUtils {
     return rows;
   }
+
+  /**
+   * Given a datablock and the column index, extracts the integer values for the column. Prefer using this function
+   * over extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing
+   * to the desired type.
+   * This only works on ROW format.
+   * TODO: Add support for COLUMNAR format.
+   * @return int array of values in the column
+   */
+  public static int[] extractIntValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+ RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + int[] rows = new int[numRows]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + switch (columnDataTypes[columnIndex]) { + case INT: + case BOOLEAN: + rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex); + break; + case LONG: + rows[outRowId++] = (int) dataBlock.getLong(inRowId, columnIndex); + break; + case FLOAT: + rows[outRowId++] = (int) dataBlock.getFloat(inRowId, columnIndex); + break; + case DOUBLE: + rows[outRowId++] = (int) dataBlock.getDouble(inRowId, columnIndex); + break; + case BIG_DECIMAL: + rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).intValue(); + break; + default: + throw new IllegalStateException( + String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex)); + } + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } + + /** + * Given a datablock and the column index, extracts the long values for the column. Prefer using this function over + * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the + * desired type. + * This only works on ROW format. + * TODO: Add support for COLUMNAR format. + * @return long array of values in the column + */ + public static long[] extractLongValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) { + DataSchema dataSchema = dataBlock.getDataSchema(); + DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes(); + + // Get null bitmap for the column. + RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + long[] rows = new long[numRows]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + switch (columnDataTypes[columnIndex]) { + case INT: + case BOOLEAN: + rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex); + break; + case LONG: + rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex); + break; + case FLOAT: + rows[outRowId++] = (long) dataBlock.getFloat(inRowId, columnIndex); + break; + case DOUBLE: + rows[outRowId++] = (long) dataBlock.getDouble(inRowId, columnIndex); + break; + case BIG_DECIMAL: + rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).longValue(); + break; + default: + throw new IllegalStateException( + String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex)); + } + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } + + /** + * Given a datablock and the column index, extracts the float values for the column. Prefer using this function over + * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the + * desired type. + * This only works on ROW format. + * TODO: Add support for COLUMNAR format. 
+ * @return float array of values in the column + */ + public static float[] extractFloatValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) { + DataSchema dataSchema = dataBlock.getDataSchema(); + DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes(); + + // Get null bitmap for the column. + RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + float[] rows = new float[numRows]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + switch (columnDataTypes[columnIndex]) { + case INT: + case BOOLEAN: + rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex); + break; + case LONG: + rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex); + break; + case FLOAT: + rows[outRowId++] = dataBlock.getFloat(inRowId, columnIndex); + break; + case DOUBLE: + rows[outRowId++] = (float) dataBlock.getDouble(inRowId, columnIndex); + break; + case BIG_DECIMAL: + rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).floatValue(); + break; + default: + throw new IllegalStateException( + String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex)); + } + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } + + /** + * Given a datablock and the column index, extracts the double values for the column. Prefer using this function over + * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the + * desired type. + * This only works on ROW format. + * TODO: Add support for COLUMNAR format. + * @return double array of values in the column + */ + public static double[] extractDoubleValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) { + DataSchema dataSchema = dataBlock.getDataSchema(); + DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes(); + + // Get null bitmap for the column. + RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + double[] rows = new double[numRows]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + switch (columnDataTypes[columnIndex]) { + case INT: + case BOOLEAN: + rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex); + break; + case LONG: + rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex); + break; + case FLOAT: + rows[outRowId++] = dataBlock.getFloat(inRowId, columnIndex); + break; + case DOUBLE: + rows[outRowId++] = dataBlock.getDouble(inRowId, columnIndex); + break; + case BIG_DECIMAL: + rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).doubleValue(); + break; + default: + throw new IllegalStateException( + String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex)); + } + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } + + /** + * Given a datablock and the column index, extracts the BigDecimal values for the column. Prefer using this function + * over extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to + * the desired type. + * This only works on ROW format. 
+ * TODO: Add support for COLUMNAR format. + * @return BigDecimal array of values in the column + */ + public static BigDecimal[] extractBigDecimalValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) { + DataSchema dataSchema = dataBlock.getDataSchema(); + DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes(); + + // Get null bitmap for the column. + RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + BigDecimal[] rows = new BigDecimal[numRows]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + switch (columnDataTypes[columnIndex]) { + case INT: + case BOOLEAN: + rows[outRowId++] = BigDecimal.valueOf(dataBlock.getInt(inRowId, columnIndex)); + break; + case LONG: + rows[outRowId++] = BigDecimal.valueOf(dataBlock.getLong(inRowId, columnIndex)); + break; + case FLOAT: + rows[outRowId++] = BigDecimal.valueOf(dataBlock.getFloat(inRowId, columnIndex)); + break; + case DOUBLE: + rows[outRowId++] = BigDecimal.valueOf(dataBlock.getDouble(inRowId, columnIndex)); + break; + case BIG_DECIMAL: + rows[outRowId++] = BigDecimal.valueOf(dataBlock.getBigDecimal(inRowId, columnIndex).doubleValue()); + break; + default: + throw new IllegalStateException( + String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex)); + } + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } + + /** + * Given a datablock and the column index, extracts the String values for the column. Prefer using this function over + * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the + * desired type. + * This only works on ROW format. + * TODO: Add support for COLUMNAR format. + * @return String array of values in the column + */ + public static String[] extractStringValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) { + DataSchema dataSchema = dataBlock.getDataSchema(); + DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes(); + + // Get null bitmap for the column. + RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + String[] rows = new String[numRows]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + rows[outRowId++] = dataBlock.getString(inRowId, columnIndex); + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } + + /** + * Given a datablock and the column index, extracts the byte values for the column. Prefer using this function over + * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the + * desired type. + * This only works on ROW format. + * TODO: Add support for COLUMNAR format. + * @return byte array of values in the column + */ + public static byte[][] extractBytesValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) { + DataSchema dataSchema = dataBlock.getDataSchema(); + DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes(); + + // Get null bitmap for the column. 
+ RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex]; + int numRows = dataBlock.getNumberOfRows(); + + byte[][] rows = new byte[numRows][]; + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { + if (nullBitmap != null && nullBitmap.contains(inRowId)) { + outRowId++; + continue; + } + rows[outRowId++] = dataBlock.getBytes(inRowId, columnIndex).getBytes(); + } + } + return Arrays.copyOfRange(rows, 0, outRowId); + } } diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index 60efa69fde..df904123d2 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -324,7 +324,7 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule { orgAggCall.isApproximate(), orgAggCall.ignoreNulls(), argList, - orgAggCall.filterArg, + aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg, orgAggCall.distinctKeys, orgAggCall.collation, numberGroups, diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index 632471f6b0..4c31fc86f4 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -19,6 +19,7 @@ package org.apache.pinot.query.parser; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -67,19 +68,30 @@ public class CalciteRexExpressionParser { // Relational conversion Utils // -------------------------------------------------------------------------- - public static List<Expression> overwriteSelectList(List<RexExpression> rexNodeList, PinotQuery pinotQuery) { - return addSelectList(new ArrayList<>(), rexNodeList, pinotQuery); - } - - public static List<Expression> addSelectList(List<Expression> existingList, List<RexExpression> rexNodeList, - PinotQuery pinotQuery) { - List<Expression> selectExpr = new ArrayList<>(existingList); - - final Iterator<RexExpression> iterator = rexNodeList.iterator(); + public static List<Expression> convertProjectList(List<RexExpression> projectList, PinotQuery pinotQuery) { + List<Expression> selectExpr = new ArrayList<>(); + final Iterator<RexExpression> iterator = projectList.iterator(); while (iterator.hasNext()) { final RexExpression next = iterator.next(); selectExpr.add(toExpression(next, pinotQuery)); } + return selectExpr; + } + + public static List<Expression> convertAggregateList(List<Expression> groupSetList, List<RexExpression> aggCallList, + List<Integer> filterArgIndices, PinotQuery pinotQuery) { + List<Expression> selectExpr = new ArrayList<>(groupSetList); + + for (int idx = 0; idx < aggCallList.size(); idx++) { + final RexExpression aggCall = aggCallList.get(idx); + int filterArgIdx = filterArgIndices.get(idx); + if (filterArgIdx == -1) { + selectExpr.add(toExpression(aggCall, pinotQuery)); + } else { + selectExpr.add(toExpression(new RexExpression.FunctionCall(SqlKind.FILTER, aggCall.getDataType(), "FILTER", + 
Arrays.asList(aggCall, new RexExpression.InputRef(filterArgIdx))), pinotQuery)); + } + } return selectExpr; } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java index c465fe93f6..5c8a0999c0 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java @@ -36,6 +36,8 @@ public class AggregateNode extends AbstractPlanNode { @ProtoProperties private List<RexExpression> _aggCalls; @ProtoProperties + private List<Integer> _filterArgIndices; + @ProtoProperties private List<RexExpression> _groupSet; @ProtoProperties private AggType _aggType; @@ -49,6 +51,7 @@ public class AggregateNode extends AbstractPlanNode { super(planFragmentId, dataSchema); Preconditions.checkState(areHintsValid(relHints), "invalid sql hint for agg node: {}", relHints); _aggCalls = aggCalls.stream().map(RexExpression::toRexExpression).collect(Collectors.toList()); + _filterArgIndices = aggCalls.stream().map(c -> c.filterArg).collect(Collectors.toList()); _groupSet = groupSet; _nodeHint = new NodeHint(relHints); _aggType = AggType.valueOf(PinotHintStrategyTable.getHintOption(relHints, PinotHintOptions.INTERNAL_AGG_OPTIONS, @@ -63,6 +66,10 @@ public class AggregateNode extends AbstractPlanNode { return _aggCalls; } + public List<Integer> getFilterArgIndices() { + return _filterArgIndices; + } + public List<RexExpression> getGroupSet() { return _groupSet; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java index 494ad1d737..22814b4571 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java @@ -35,7 +35,6 @@ import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.request.context.FunctionContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.common.BlockValSet; -import org.apache.pinot.core.common.IntermediateStageBlockValSet; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory; import org.apache.pinot.query.planner.logical.LiteralHintUtils; @@ -44,7 +43,10 @@ import org.apache.pinot.query.planner.plannode.AbstractPlanNode; import org.apache.pinot.query.planner.plannode.AggregateNode.AggType; import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils; +import org.apache.pinot.query.runtime.operator.block.DataBlockValSet; +import org.apache.pinot.query.runtime.operator.block.FilteredDataBlockValSet; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; +import org.apache.pinot.spi.data.FieldSpec; /** @@ -90,31 +92,28 @@ public class AggregateOperator extends MultiStageOperator { private MultistageAggregationExecutor _aggregationExecutor; private MultistageGroupByExecutor _groupByExecutor; - // TODO: refactor Pinot Reducer code to support the intermediate stage agg operator. 
- // aggCalls has to be a list of FunctionCall and cannot be null - // groupSet has to be a list of InputRef and cannot be null - // TODO: Add these two checks when we confirm we can handle error in upstream ctor call. - public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, - DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet, - AggType aggType) { - this(context, inputOperator, resultSchema, inputSchema, aggCalls, groupSet, aggType, - new AbstractPlanNode.NodeHint()); - } - @VisibleForTesting - public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, - DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet, - AggType aggType, AbstractPlanNode.NodeHint nodeHint) { + public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, DataSchema resultSchema, + DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType aggType, + @Nullable List<Integer> filterArgIndices, @Nullable AbstractPlanNode.NodeHint nodeHint) { super(context); _inputOperator = inputOperator; _resultSchema = resultSchema; _inputSchema = inputSchema; _aggType = aggType; + // filter arg index array + int[] filterArgIndexArray; + if (filterArgIndices == null || filterArgIndices.size() == 0) { + filterArgIndexArray = null; + } else { + filterArgIndexArray = filterArgIndices.stream().mapToInt(Integer::intValue).toArray(); + } + // filter operand and literal hints if (nodeHint != null && nodeHint._hintOptions != null && nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS) != null) { - _aggCallSignatureMap = LiteralHintUtils.hintStringToLiteralMap(nodeHint._hintOptions - .get(PinotHintOptions.INTERNAL_AGG_OPTIONS) - .get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE)); + _aggCallSignatureMap = LiteralHintUtils.hintStringToLiteralMap( + nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS) + .get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE)); } else { _aggCallSignatureMap = Collections.emptyMap(); } @@ -126,6 +125,7 @@ public class AggregateOperator extends MultiStageOperator { // Convert groupSet to ExpressionContext that our aggregation functions understand. List<ExpressionContext> groupByExpr = getGroupSet(groupSet); + List<FunctionContext> functionContexts = getFunctionContexts(aggCalls); AggregationFunction[] aggFunctions = new AggregationFunction[functionContexts.size()]; @@ -136,12 +136,14 @@ public class AggregateOperator extends MultiStageOperator { // Initialize the appropriate executor. if (!groupSet.isEmpty()) { _isGroupByAggregation = true; - _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap, - _resultSchema); + _groupByExecutor = + new MultistageGroupByExecutor(groupByExpr, aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap, + _resultSchema); } else { _isGroupByAggregation = false; - _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap, - _resultSchema); + _aggregationExecutor = + new MultistageAggregationExecutor(aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap, + _resultSchema); } } @@ -253,7 +255,7 @@ public class AggregateOperator extends MultiStageOperator { // The literal value here does not matter. We create a dummy literal here just so that the count aggregation // has some column to process. 
if (aggArguments.isEmpty()) { - aggArguments.add(ExpressionContext.forIdentifier("__PLACEHOLDER__")); + aggArguments.add(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "__PLACEHOLDER__")); } return new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, aggArguments); @@ -320,8 +322,8 @@ public class AggregateOperator extends MultiStageOperator { // TODO: If the previous block is not mailbox received, this method is not efficient. Then getDataBlock() will // convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column // value primitive type. - static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction, - TransferableBlock block, DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap) { + static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction, TransferableBlock block, + DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap, int filterArgIdx) { List<ExpressionContext> expressions = aggFunction.getInputExpressions(); int numExpressions = expressions.size(); if (numExpressions == 0) { @@ -330,17 +332,34 @@ public class AggregateOperator extends MultiStageOperator { Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>(); for (ExpressionContext expression : expressions) { - if (expression.getType().equals(ExpressionContext.Type.IDENTIFIER) - && !"__PLACEHOLDER__".equals(expression.getIdentifier())) { + if (expression.getType().equals(ExpressionContext.Type.IDENTIFIER) && !"__PLACEHOLDER__".equals( + expression.getIdentifier())) { int index = colNameToIndexMap.get(expression.getIdentifier()); DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index); Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW"); - blockValSetMap.put(expression, new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index)); + if (filterArgIdx == -1) { + blockValSetMap.put(expression, new DataBlockValSet(dataType, block.getDataBlock(), index)); + } else { + blockValSetMap.put(expression, + new FilteredDataBlockValSet(dataType, block.getDataBlock(), index, filterArgIdx)); + } } } return blockValSetMap; } + static int computeBlockNumRows(TransferableBlock block, int filterArgIdx) { + if (filterArgIdx == -1) { + return block.getNumRows(); + } else { + int rowCount = 0; + for (int rowId = 0; rowId < block.getNumRows(); rowId++) { + rowCount += block.getDataBlock().getInt(rowId, filterArgIdx) == 1 ? 
1 : 0; + } + return rowCount; + } + } + static Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row, Map<String, Integer> colNameToIndexMap) { List<ExpressionContext> expressions = aggregationFunction.getInputExpressions(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java index 19e4f66cc6..b352997ac1 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java @@ -21,6 +21,7 @@ package org.apache.pinot.query.runtime.operator; import java.util.Collections; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.common.BlockValSet; @@ -42,13 +43,15 @@ public class MultistageAggregationExecutor { private final DataSchema _resultSchema; private final AggregationFunction[] _aggFunctions; + private final int[] _filterArgIndices; // Result holders for each mode. private final AggregationResultHolder[] _aggregateResultHolder; private final Object[] _mergeResultHolder; - public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, + public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, @Nullable int[] filterArgIndices, AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) { + _filterArgIndices = filterArgIndices; _aggFunctions = aggFunctions; _aggType = aggType; _colNameToIndexMap = colNameToIndexMap; @@ -116,11 +119,23 @@ public class MultistageAggregationExecutor { } private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) { - for (int i = 0; i < _aggFunctions.length; i++) { - AggregationFunction aggregationFunction = _aggFunctions[i]; - Map<ExpressionContext, BlockValSet> blockValSetMap = - AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap); - aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap); + if (_filterArgIndices == null) { + for (int i = 0; i < _aggFunctions.length; i++) { + AggregationFunction aggregationFunction = _aggFunctions[i]; + Map<ExpressionContext, BlockValSet> blockValSetMap = + AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, -1); + aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap); + } + } else { + for (int i = 0; i < _aggFunctions.length; i++) { + AggregationFunction aggregationFunction = _aggFunctions[i]; + int filterArgIdx = _filterArgIndices[i]; + Map<ExpressionContext, BlockValSet> blockValSetMap = + AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, + filterArgIdx); + int numRows = AggregateOperator.computeBlockNumRows(block, filterArgIdx); + aggregationFunction.aggregate(numRows, _aggregateResultHolder[i], blockValSetMap); + } } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java index 5eacba025b..e33cc491cf 100644 --- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java @@ -19,9 +19,11 @@ package org.apache.pinot.query.runtime.operator; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.common.BlockValSet; @@ -47,6 +49,7 @@ public class MultistageGroupByExecutor { private final List<ExpressionContext> _groupSet; private final AggregationFunction[] _aggFunctions; + private final int[] _filterArgIndices; // Group By Result holders for each mode private final GroupByResultHolder[] _aggregateResultHolders; @@ -57,11 +60,13 @@ public class MultistageGroupByExecutor { private final Map<Key, Integer> _groupKeyToIdMap; public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions, - AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) { + @Nullable int[] filterArgIndices, AggType aggType, Map<String, Integer> colNameToIndexMap, + DataSchema resultSchema) { _aggType = aggType; _colNameToIndexMap = colNameToIndexMap; _groupSet = groupByExpr; _aggFunctions = aggFunctions; + _filterArgIndices = filterArgIndices; _resultSchema = resultSchema; _aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length]; @@ -70,9 +75,9 @@ public class MultistageGroupByExecutor { _groupKeyToIdMap = new HashMap<>(); for (int i = 0; i < _aggFunctions.length; i++) { - _aggregateResultHolders[i] = - _aggFunctions[i].createGroupByResultHolder(InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY, - InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT); + _aggregateResultHolders[i] = _aggFunctions[i].createGroupByResultHolder( + InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY, + InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT); } } @@ -129,15 +134,29 @@ public class MultistageGroupByExecutor { } private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) { - int[] intKeys = generateGroupByKeys(block.getContainer()); - - for (int i = 0; i < _aggFunctions.length; i++) { - AggregationFunction aggregationFunction = _aggFunctions[i]; - Map<ExpressionContext, BlockValSet> blockValSetMap = - AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap); - GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i]; - groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size()); - aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap); + if (_filterArgIndices == null) { + int[] intKeys = generateGroupByKeys(block.getContainer()); + for (int i = 0; i < _aggFunctions.length; i++) { + AggregationFunction aggregationFunction = _aggFunctions[i]; + Map<ExpressionContext, BlockValSet> blockValSetMap = + AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, -1); + GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i]; + groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size()); + aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap); + } + } else { + for (int i = 0; i < _aggFunctions.length; i++) { + 
AggregationFunction aggregationFunction = _aggFunctions[i]; + int filterArgIdx = _filterArgIndices[i]; + int[] intKeys = generateGroupByKeys(block.getContainer(), filterArgIdx); + Map<ExpressionContext, BlockValSet> blockValSetMap = + AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, + filterArgIdx); + int numRows = AggregateOperator.computeBlockNumRows(block, filterArgIdx); + GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i]; + groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size()); + aggregationFunction.aggregateGroupBySV(numRows, intKeys, groupByResultHolder, blockValSetMap); + } } } @@ -191,4 +210,42 @@ public class MultistageGroupByExecutor { } return rowIntKeys; } + + /** + * Creates the group by key for each row. Converts the key into a 0-index based int value that can be used by + * GroupByAggregationResultHolders used in v1 aggregations. + * <p> + * Returns the int key for each row. + */ + private int[] generateGroupByKeys(List<Object[]> rows, int filterArgIndex) { + int numRows = rows.size(); + int[] rowIntKeys = new int[numRows]; + int numKeys = _groupSet.size(); + if (filterArgIndex == -1) { + for (int rowId = 0; rowId < numRows; rowId++) { + Object[] row = rows.get(rowId); + Object[] keyValues = new Object[numKeys]; + for (int j = 0; j < numKeys; j++) { + keyValues[j] = row[_colNameToIndexMap.get(_groupSet.get(j).getIdentifier())]; + } + Key rowKey = new Key(keyValues); + rowIntKeys[rowId] = _groupKeyToIdMap.computeIfAbsent(rowKey, k -> _groupKeyToIdMap.size()); + } + return rowIntKeys; + } else { + int outRowId = 0; + for (int inRowId = 0; inRowId < numRows; inRowId++) { + Object[] row = rows.get(inRowId); + if ((Boolean) row[filterArgIndex]) { + Object[] keyValues = new Object[numKeys]; + for (int j = 0; j < numKeys; j++) { + keyValues[j] = row[_colNameToIndexMap.get(_groupSet.get(j).getIdentifier())]; + } + Key rowKey = new Key(keyValues); + rowIntKeys[outRowId++] = _groupKeyToIdMap.computeIfAbsent(rowKey, k -> _groupKeyToIdMap.size()); + } + } + return Arrays.copyOfRange(rowIntKeys, 0, outRowId); + } + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java similarity index 90% copy from pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java copy to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java index 7ddaaf04c9..e1bbc077ef 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java @@ -16,13 +16,14 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.core.common; +package org.apache.pinot.query.runtime.operator.block; import java.math.BigDecimal; import javax.annotation.Nullable; import org.apache.pinot.common.datablock.DataBlock; import org.apache.pinot.common.datablock.DataBlockUtils; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.segment.spi.index.reader.Dictionary; import org.apache.pinot.spi.data.FieldSpec; import org.roaringbitmap.RoaringBitmap; @@ -34,13 +35,13 @@ import org.roaringbitmap.RoaringBitmap; * aggregations using v1 aggregation functions. 
* TODO: Support MV */ -public class IntermediateStageBlockValSet implements BlockValSet { - private final FieldSpec.DataType _dataType; - private final DataBlock _dataBlock; - private final int _index; - private final RoaringBitmap _nullBitMap; +public class DataBlockValSet implements BlockValSet { + protected final FieldSpec.DataType _dataType; + protected final DataBlock _dataBlock; + protected final int _index; + protected final RoaringBitmap _nullBitMap; - public IntermediateStageBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) { + public DataBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) { _dataType = columnDataType.toDataType(); _dataBlock = dataBlock; _index = colIndex; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java similarity index 86% rename from pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java index 7ddaaf04c9..e4231fbd71 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.core.common; +package org.apache.pinot.query.runtime.operator.block; import java.math.BigDecimal; import javax.annotation.Nullable; @@ -27,6 +27,7 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary; import org.apache.pinot.spi.data.FieldSpec; import org.roaringbitmap.RoaringBitmap; + /** * In the multistage engine, the leaf stage servers process the data in columnar fashion. By the time the * intermediate stage receives the projected column, they are converted to a row based format. This class provides @@ -34,17 +35,13 @@ import org.roaringbitmap.RoaringBitmap; * aggregations using v1 aggregation functions. 
* TODO: Support MV */ -public class IntermediateStageBlockValSet implements BlockValSet { - private final FieldSpec.DataType _dataType; - private final DataBlock _dataBlock; - private final int _index; - private final RoaringBitmap _nullBitMap; +public class FilteredDataBlockValSet extends DataBlockValSet { + private final int _filterIdx; - public IntermediateStageBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) { - _dataType = columnDataType.toDataType(); - _dataBlock = dataBlock; - _index = colIndex; - _nullBitMap = dataBlock.getNullRowIds(colIndex); + public FilteredDataBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex, + int filterIdx) { + super(columnDataType, dataBlock, colIndex); + _filterIdx = filterIdx; } /** @@ -80,37 +77,37 @@ public class IntermediateStageBlockValSet implements BlockValSet { @Override public int[] getIntValuesSV() { - return DataBlockUtils.extractIntValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractIntValuesForColumn(_dataBlock, _index, _filterIdx); } @Override public long[] getLongValuesSV() { - return DataBlockUtils.extractLongValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractLongValuesForColumn(_dataBlock, _index, _filterIdx); } @Override public float[] getFloatValuesSV() { - return DataBlockUtils.extractFloatValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractFloatValuesForColumn(_dataBlock, _index, _filterIdx); } @Override public double[] getDoubleValuesSV() { - return DataBlockUtils.extractDoubleValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractDoubleValuesForColumn(_dataBlock, _index, _filterIdx); } @Override public BigDecimal[] getBigDecimalValuesSV() { - return DataBlockUtils.extractBigDecimalValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractBigDecimalValuesForColumn(_dataBlock, _index, _filterIdx); } @Override public String[] getStringValuesSV() { - return DataBlockUtils.extractStringValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractStringValuesForColumn(_dataBlock, _index, _filterIdx); } @Override public byte[][] getBytesValuesSV() { - return DataBlockUtils.extractBytesValuesForColumn(_dataBlock, _index); + return DataBlockUtils.extractBytesValuesForColumn(_dataBlock, _index, _filterIdx); } @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java index 79f2275274..2340545437 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java @@ -101,7 +101,7 @@ public class PhysicalPlanVisitor implements PlanNodeVisitor<MultiStageOperator, DataSchema resultSchema = node.getDataSchema(); return new AggregateOperator(context.getOpChainExecutionContext(), nextOperator, resultSchema, inputSchema, - node.getAggCalls(), node.getGroupSet(), node.getAggType(), node.getNodeHint()); + node.getAggCalls(), node.getGroupSet(), node.getAggType(), node.getFilterArgIndices(), node.getNodeHint()); } @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java index 8fa6df2756..5e3873fbcd 100644 --- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java @@ -71,8 +71,8 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla .setGroupByList(CalciteRexExpressionParser.convertGroupByList(node.getGroupSet(), context.getPinotQuery())); // set agg list context.getPinotQuery().setSelectList( - CalciteRexExpressionParser.addSelectList(context.getPinotQuery().getGroupByList(), node.getAggCalls(), - context.getPinotQuery())); + CalciteRexExpressionParser.convertAggregateList(context.getPinotQuery().getGroupByList(), node.getAggCalls(), + node.getFilterArgIndices(), context.getPinotQuery())); return null; } @@ -149,7 +149,7 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla public Void visitProject(ProjectNode node, ServerPlanRequestContext context) { visitChildren(node, context); context.getPinotQuery() - .setSelectList(CalciteRexExpressionParser.overwriteSelectList(node.getProjects(), context.getPinotQuery())); + .setSelectList(CalciteRexExpressionParser.convertProjectList(node.getProjects(), context.getPinotQuery())); return null; } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index 8048bad5f4..d32d38b2b0 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -79,7 +79,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.INTERMEDIATE); + AggType.INTERMEDIATE, null, null); // When: TransferableBlock block1 = operator.nextBlock(); // build @@ -101,7 +101,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.LEAF); + AggType.LEAF, null, null); // When: TransferableBlock block = operator.nextBlock(); @@ -125,7 +125,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.LEAF); + AggType.LEAF, null, null); // When: TransferableBlock block1 = operator.nextBlock(); // build when reading NoOp block @@ -150,7 +150,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.INTERMEDIATE); + AggType.INTERMEDIATE, null, null); // When: TransferableBlock block1 = operator.nextBlock(); @@ -177,7 +177,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, 
LONG}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.FINAL); + AggType.FINAL, null, null); // When: TransferableBlock block1 = operator.nextBlock(); @@ -201,7 +201,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{STRING, DOUBLE}); AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator, outSchema, inSchema, Collections.singletonList(agg), - Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF); + Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF, null, null); TransferableBlock result = sum0GroupBy1.getNextBlock(); while (result.isNoOpBlock()) { result = sum0GroupBy1.getNextBlock(); @@ -229,7 +229,7 @@ public class AggregateOperatorTest { // When: AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.INTERMEDIATE); + AggType.INTERMEDIATE, null, null); } @Test @@ -248,7 +248,7 @@ public class AggregateOperatorTest { DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group, - AggType.INTERMEDIATE); + AggType.INTERMEDIATE, null, null); // When: TransferableBlock block = operator.nextBlock(); diff --git a/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json b/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json new file mode 100644 index 0000000000..af308031ab --- /dev/null +++ b/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json @@ -0,0 +1,166 @@ +{ + "general_aggregate_with_filter_where": { + "tables": { + "tbl": { + "schema": [ + {"name": "int_col", "type": "INT"}, + {"name": "double_col", "type": "DOUBLE"}, + {"name": "string_col", "type": "STRING"}, + {"name": "bool_col", "type": "BOOLEAN"} + ], + "inputs": [ + [2, 300, "a", true], + [2, 400, "a", true], + [3, 100, "b", false], + [100, 1, "b", false], + [101, 1.01, "c", false], + [150, 1.5, "c", false], + [175, 1.75, "c", true] + ] + } + }, + "queries": [ + { + "ignored": true, + "comments": "FILTER WHERE clause with IN hard-wired to translate into subquery, which in this case should not happen.", + "sql": "SELECT min(double_col) FILTER (WHERE string_col IN ('a', 'b')), count(*) FROM {tbl}" + }, + { + "ignored": true, + "comments": "IS NULL and IS NOT NULL is not yet supported in filter conversion.", + "sql": "SELECT min(double_col) FILTER (WHERE string_col IS NOT NULL), count(*) FROM {tbl}" + }, + { + "ignored": true, + "comments": "agg with filter and group-by causes conversion issue on v1 if the group-by field is not in the select list", + "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} GROUP BY int_col" + }, + { + "ignored": true, + "comments": "agg with group by and filter will create NULL-able columns that are unsupported with current AGG FILTER WHERE semantics.", + "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl} GROUP BY int_col, string_col" + }, + { + "ignored": true, + "comments": 
"mixed/conflict filter that requires merging in v1 is not supported", + "sql": "SELECT double_col, bool_col, count(int_col) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} WHERE string_col = 'b' GROUP BY double_col, bool_col" + }, + { + "ignored": true, + "comments": "FILTER WHERE clause might omit group key entirely if nothing is being selected out, this is non-standard SQL behavior but it is v1 behavior", + "sql": "SELECT int_col, count(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} GROUP BY int_col" + }, + { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl}" }, + { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl}" }, + { "sql": "SELECT min(int_col) FILTER (WHERE bool_col IS TRUE), max(int_col) FILTER (WHERE bool_col AND int_col < 10), avg(int_col) FILTER (WHERE MOD(int_col, 3) = 0), sum(int_col), count(int_col), count(distinct(int_col)), count(*) FILTER (WHERE MOD(int_col, 3) = 0) FROM {tbl}" }, + { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} WHERE string_col='b'" }, + { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl} WHERE string_col='b'" }, + { "sql": "SELECT int_col, COALESCE(count(double_col) FILTER (WHERE string_col = 'a' OR int_col > 0), 0), count(*) FROM {tbl} GROUP BY int_col" }, + { + "ignored": true, + "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched", + "sql": "SELECT int_col, string_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), 0), avg(double_col), sum(double_col), count(double_col), COALESCE(count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), 0) FROM {tbl} GROUP BY int_col, string_col" + }, + { + "ignored": true, + "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched", + "sql": "SELECT double_col, COALESCE(min(int_col) FILTER (WHERE bool_col IS TRUE), 0), COALESCE(max(int_col) FILTER (WHERE bool_col AND int_col < 10), 0), COALESCE(avg(int_col) FILTER (WHERE MOD(int_col, 3) = 0), 0), sum(int_col), count(int_col), count(distinct(int_col)), count(string_col) FILTER (WHERE MOD(int_col, 3) = 0) FROM {tbl} GROUP BY double_col" + }, + { "sql": "SELECT double_col, bool_col, count(int_col) FILTER (WHERE string_col = 'a' OR int_col > 10), count(int_col) FROM {tbl} WHERE string_col IN ('a', 'b') GROUP BY double_col, bool_col" }, + { + "ignored": true, + "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched", + "sql": "SELECT bool_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), 0), avg(double_col), sum(double_col), 
count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(string_col) FROM {tbl} WHERE string_col='b' GROUP BY bool_col" + } + ] + }, + "general_aggregate_with_filter_where_after_join": { + "tables": { + "tbl1": { + "schema": [ + {"name": "int_col", "type": "INT"}, + {"name": "double_col", "type": "DOUBLE"}, + {"name": "string_col", "type": "STRING"}, + {"name": "bool_col", "type": "BOOLEAN"} + ], + "inputs": [ + [2, 300, "a", true], + [2, 400, "a", true], + [3, 100, "b", false], + [100, 1, "b", false], + [101, 1.01, "c", false], + [150, 1.5, "c", false], + [175, 1.75, "c", true] + ] + }, + "tbl2": { + "schema":[ + {"name": "int_col2", "type": "INT"}, + {"name": "string_col2", "type": "STRING"}, + {"name": "double_col2", "type": "DOUBLE"} + ], + "inputs": [ + [1, "apple", 1000.0], + [2, "a", 1.323], + [3, "b", 1212.12], + [3, "c", 341], + [4, "orange", 1212.121] + ] + } + }, + "queries": [ + { + "ignored": true, + "comments": "FILTER WHERE clause with IN hard-wired to translate into subquery, which in this case should not happen.", + "sql": "SELECT min(double_col) FILTER (WHERE string_col IN ('a', 'b')), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" + }, + { + "ignored": true, + "comments": "IS NULL and IS NOT NULL is not yet supported in filter conversion.", + "sql": "SELECT min(double_col) FILTER (WHERE string_col IS NOT NULL), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" + }, + { + "ignored": true, + "comments": "agg with filter and group-by causes conversion issue on v1 if the group-by field is not in the select list", + "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2" + }, + { + "ignored": true, + "comments": "agg with group by and filter will create NULL-able columns that are unsupported with current AGG FILTER WHERE semantics.", + "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2, string_col" + }, + { + "ignored": true, + "comments": "mixed/conflict filter that requires merging in v1 is not supported", + "sql": "SELECT double_col, bool_col, count(int_col2) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col = 'b' GROUP BY double_col, bool_col" + }, + { + "ignored": true, + "comments": "FILTER WHERE clause might omit group key entirely if nothing is being selected out, this is non-standard SQL behavior but it is v1 behavior", + "sql": "SELECT int_col2, count(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2" + }, + { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" }, + { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" }, + { "sql": "SELECT min(int_col2) FILTER (WHERE bool_col IS TRUE), max(int_col2) 
FILTER (WHERE bool_col AND int_col2 < 10), avg(int_col2) FILTER (WHERE MOD(int_col2, 3) = 0), sum(int_col2), count(int_col2), count(distinct(int_col2)), count(*) FILTER (WHERE MOD(int_col2, 3) = 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" }, + { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b'" }, + { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b'" }, + { "sql": "SELECT int_col2, COALESCE(count(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 0), 0), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2" }, + { + "ignored": true, + "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched", + "sql": "SELECT int_col2, string_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), 0), avg(double_col), sum(double_col), count(double_col), COALESCE(count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2, string_col" + }, + { + "ignored": true, + "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched", + "sql": "SELECT double_col, COALESCE(min(int_col2) FILTER (WHERE bool_col IS TRUE), 0), COALESCE(max(int_col2) FILTER (WHERE bool_col AND int_col2 < 10), 0), COALESCE(avg(int_col2) FILTER (WHERE MOD(int_col2, 3) = 0), 0), sum(int_col2), count(int_col2), count(distinct(int_col2)), count(string_col) FILTER (WHERE MOD(int_col2, 3) = 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY double_col" + }, + { "sql": "SELECT double_col, bool_col, count(int_col2) FILTER (WHERE string_col = 'a' OR int_col2 > 10), count(int_col2) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col IN ('a', 'b') GROUP BY double_col, bool_col" }, + { + "ignored": true, + "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched", + "sql": "SELECT bool_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), 0), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(string_col) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b' GROUP BY bool_col" + } + ] + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org
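The seven new extract*ValuesForColumn overloads in DataBlockUtils above all apply the same compact-then-truncate pattern: rows whose 0/1 filter column is 0 are dropped, filtered-in NULL rows keep a slot (so positions still line up with the null bitmap), and the oversized output array is trimmed with Arrays.copyOfRange. A minimal self-contained sketch of that pattern follows, with plain arrays standing in for DataBlock and its null bitmap; the class and method names below are illustrative, not Pinot API:

    import java.util.Arrays;

    public class FilteredExtractSketch {
      // Mirrors the shape of DataBlockUtils.extractIntValuesForColumn(dataBlock, columnIndex, filterArgIdx).
      static int[] extractIntValues(int[][] rows, int columnIndex, int filterArgIdx, boolean[] isNull) {
        int[] out = new int[rows.length];          // oversized; trimmed below
        int outRowId = 0;
        for (int inRowId = 0; inRowId < rows.length; inRowId++) {
          if (rows[inRowId][filterArgIdx] == 1) {  // row passes the FILTER predicate
            if (isNull[inRowId]) {
              outRowId++;                          // keep the slot for a filtered-in NULL row
              continue;
            }
            out[outRowId++] = rows[inRowId][columnIndex];
          }
        }
        return Arrays.copyOfRange(out, 0, outRowId);
      }

      public static void main(String[] args) {
        int[][] rows = {{7, 1}, {8, 0}, {9, 1}};   // column 0 = value, column 1 = filter flag
        boolean[] isNull = {false, false, false};
        System.out.println(Arrays.toString(extractIntValues(rows, 0, 1, isNull)));  // prints [7, 9]
      }
    }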
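On the group-by side, MultistageGroupByExecutor.generateGroupByKeys(rows, filterArgIndex) assigns dense 0-based group ids only to rows whose boolean filter column is true, so the key array it returns lines up row-for-row with the filtered value arrays produced above. A simplified sketch of that logic, assuming a single key column in place of the composite Key built over _groupSet (names again illustrative):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class GroupKeySketch {
      private final Map<Object, Integer> _groupKeyToIdMap = new HashMap<>();

      // Rows failing the filter get no slot at all; surviving rows get a dense id per distinct key.
      int[] generateGroupByKeys(List<Object[]> rows, int keyIndex, int filterArgIndex) {
        int[] rowIntKeys = new int[rows.size()];
        int outRowId = 0;
        for (Object[] row : rows) {
          if ((Boolean) row[filterArgIndex]) {
            rowIntKeys[outRowId++] =
                _groupKeyToIdMap.computeIfAbsent(row[keyIndex], k -> _groupKeyToIdMap.size());
          }
        }
        return Arrays.copyOfRange(rowIntKeys, 0, outRowId);
      }

      public static void main(String[] args) {
        GroupKeySketch sketch = new GroupKeySketch();
        List<Object[]> rows = List.of(
            new Object[]{"a", true}, new Object[]{"b", false}, new Object[]{"a", true});
        System.out.println(Arrays.toString(sketch.generateGroupByKeys(rows, 0, 1)));  // prints [0, 0]
      }
    }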