walterddr commented on code in PR #11617: URL: https://github.com/apache/pinot/pull/11617#discussion_r1330404415
########## pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java: ########## @@ -205,90 +196,78 @@ private TransferableBlock consumeAggregation() { return block; } - private List<FunctionContext> getFunctionContexts(List<RexExpression> aggCalls) { + private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression> aggCalls, + Map<Integer, Map<Integer, Literal>> literalArgumentsMap) { int numFunctions = aggCalls.size(); - List<FunctionContext> functionContexts = new ArrayList<>(numFunctions); + AggregationFunction<?, ?>[] aggFunctions = new AggregationFunction[numFunctions]; for (int i = 0; i < numFunctions; i++) { RexExpression.FunctionCall functionCall = (RexExpression.FunctionCall) aggCalls.get(i); - FunctionContext funcContext = convertRexExpressionsToFunctionContext(i, functionCall); - functionContexts.add(funcContext); + Map<Integer, Literal> literalArguments = literalArgumentsMap.getOrDefault(i, Collections.emptyMap()); + aggFunctions[i] = getAggFunction(functionCall, literalArguments); } - return functionContexts; + return aggFunctions; } - private FunctionContext convertRexExpressionsToFunctionContext(int aggIdx, - RexExpression.FunctionCall aggFunctionCall) { - // Extract details from RexExpression aggFunctionCall. 
- String functionName = aggFunctionCall.getFunctionName(); - List<RexExpression> functionOperands = aggFunctionCall.getFunctionOperands(); - - List<ExpressionContext> aggArguments = new ArrayList<>(); - for (int argIdx = 0; argIdx < functionOperands.size(); argIdx++) { - RexExpression operand = functionOperands.get(argIdx); - ExpressionContext exprContext = convertRexExpressionToExpressionContext(aggIdx, argIdx, operand); - aggArguments.add(exprContext); + private AggregationFunction<?, ?> getAggFunction(RexExpression.FunctionCall functionCall, + Map<Integer, Literal> literalArguments) { + String functionName = functionCall.getFunctionName(); + List<RexExpression> operands = functionCall.getFunctionOperands(); + int numArguments = operands.size(); + if (numArguments == 0) { + Preconditions.checkState(functionName.equals("COUNT"), + "Aggregate function without argument must be COUNT, got: %s", functionName); + return COUNT_STAR_AGG_FUNCTION; } - // add additional arguments for aggFunctionCall + // For intermediate aggregation, we might need to append the arguments to match the signature of the aggregation + int numExpectedArguments = numArguments; if (_aggType.isInputIntermediateFormat()) { - rewriteAggArgumentForIntermediateInput(aggArguments, aggIdx); - } - // This can only be true for COUNT aggregation functions on intermediate stage. - // The literal value here does not matter. We create a dummy literal here just so that the count aggregation - // has some column to process. 
- if (aggArguments.isEmpty()) { - aggArguments.add(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "__PLACEHOLDER__")); - } - - return new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, aggArguments); - } - - private void rewriteAggArgumentForIntermediateInput(List<ExpressionContext> aggArguments, int aggIdx) { - Map<Integer, Literal> aggCallSignature = _aggCallSignatureMap.get(aggIdx); - if (aggCallSignature != null && !aggCallSignature.isEmpty()) { - int argListSize = aggCallSignature.get(-1).getIntValue(); - for (int argIdx = 1; argIdx < argListSize; argIdx++) { - Literal aggIdxLiteral = aggCallSignature.get(argIdx); - if (aggIdxLiteral != null) { - aggArguments.add(ExpressionContext.forLiteralContext(aggIdxLiteral)); - } else { - aggArguments.add(ExpressionContext.forIdentifier("__PLACEHOLDER__")); - } + Literal literal = literalArguments.get(-1); + if (literal != null) { + numExpectedArguments = literal.getIntValue(); } } - } - private ExpressionContext convertRexExpressionToExpressionContext(int aggIdx, int argIdx, RexExpression rexExpr) { - ExpressionContext exprContext; - if (_aggCallSignatureMap.get(aggIdx) != null && _aggCallSignatureMap.get(aggIdx).get(argIdx) != null) { - return ExpressionContext.forLiteralContext(_aggCallSignatureMap.get(aggIdx).get(argIdx)); - } - - // This is used only for aggregation arguments and groupby columns. The rexExpression can never be a function type. - switch (rexExpr.getKind()) { - case INPUT_REF: { - RexExpression.InputRef inputRef = (RexExpression.InputRef) rexExpr; - int identifierIndex = inputRef.getIndex(); - String columnName = _inputSchema.getColumnName(identifierIndex); - // Calcite generates unique column names for aggregation functions. For example, select avg(col1), sum(col1) - // will generate names $f0 and $f1 for avg and sum respectively. 
We use a map to store the name -> index - // mapping to extract the required column value from row-based container and fetch the input datatype for the - // column. - _colNameToIndexMap.put(columnName, identifierIndex); - exprContext = ExpressionContext.forIdentifier(columnName); - break; + List<ExpressionContext> arguments = new ArrayList<>(numExpectedArguments); + for (int i = 0; i < numExpectedArguments; i++) { + Literal literal = literalArguments.get(i); + if (literal != null) { + arguments.add(ExpressionContext.forLiteralContext(literal)); + continue; + } + if (i >= numArguments) { Review Comment: An intermediate aggregation always has exactly one real argument (it cannot even be 0, because COUNT would have been converted to SUM), and the rest are placeholders. Why do we need to check `numArguments` again here? Simply running the loop as `for (int i = 1 ...` should suffice. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org