mayankshriv commented on a change in pull request #5832: URL: https://github.com/apache/incubator-pinot/pull/5832#discussion_r467334796
########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -78,9 +102,9 @@ public DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> argum throws SqlParseException { int numArguments = arguments.size(); - // NOTE: This function expects at least 3 arguments: theta-sketch column, parameters, post-aggregation expression. - Preconditions.checkArgument(numArguments >= 3, - "DistinctCountThetaSketch expects at least three arguments (theta-sketch column, parameters, post-aggregation expression), got: ", + // NOTE: This function expects at least 4 arguments: theta-sketch column, parameters, post-aggregation expression. Review comment: The comment says 4 arguments, but it only lists three of them? ########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -97,6 +121,10 @@ public DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> argum // Initialize the theta-sketch set operation builder _setOperationBuilder = getSetOperationBuilder(); + // index of the original input predicates + // index[0] = $1 Review comment: Remove unused code? ########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -526,28 +543,51 @@ private Predicate getPredicate(String predicateString) { * passed to this method to be used when evaluating the expression. * * @param postAggregationExpression Post-aggregation expression to evaluate (modeled as a filter) + * @param expressions list of aggregation function parameters * @param sketchMap Precomputed sketches for predicates that are part of the expression. * @return Overall evaluated sketch for the expression. 
*/ - private Sketch evalPostAggregationExpression(FilterContext postAggregationExpression, - Map<Predicate, Sketch> sketchMap) { - switch (postAggregationExpression.getType()) { - case AND: - Intersection intersection = _setOperationBuilder.buildIntersection(); - for (FilterContext child : postAggregationExpression.getChildren()) { - intersection.update(evalPostAggregationExpression(child, sketchMap)); - } - return intersection.getResult(); - case OR: - Union union = _setOperationBuilder.buildUnion(); - for (FilterContext child : postAggregationExpression.getChildren()) { - union.update(evalPostAggregationExpression(child, sketchMap)); + private Sketch evalPostAggregationExpression( + final ExpressionContext postAggregationExpression, + final List<Predicate> expressions, + final Map<Predicate, Sketch> sketchMap) { + if (postAggregationExpression.getType() == ExpressionContext.Type.LITERAL) { + throw new IllegalArgumentException("Literal not supported in post-aggregation function"); + } + + if (postAggregationExpression.getType() == ExpressionContext.Type.IDENTIFIER) { + final Predicate exp = + expressions.get(extractSubstitutionPosition(postAggregationExpression.getLiteral()) - 1); + return sketchMap.get(exp); + } + + // shouldn't throw exception because of the validation in the constructor + final MergeFunction func = + MergeFunction.valueOf(postAggregationExpression.getFunction().getFunctionName().toUpperCase()); + + // handle functions recursively + switch(func) { + case SET_UNION: + final Union union = _setOperationBuilder.buildUnion(); + for (final ExpressionContext exp : postAggregationExpression.getFunction().getArguments()) { + union.update(evalPostAggregationExpression(exp, expressions, sketchMap)); } return union.getResult(); - case PREDICATE: - return sketchMap.get(postAggregationExpression.getPredicate()); + case SET_INTERSECT: + final Intersection intersection = _setOperationBuilder.buildIntersection(); + for (final ExpressionContext exp : 
postAggregationExpression.getFunction().getArguments()) { + intersection.update(evalPostAggregationExpression(exp, expressions, sketchMap)); + } + return intersection.getResult(); + case SET_DIFF: + final List<ExpressionContext> args = postAggregationExpression.getFunction().getArguments(); + final AnotB diff = _setOperationBuilder.buildANotB(); + final Sketch a = evalPostAggregationExpression(args.get(0), expressions, sketchMap); + final Sketch b = evalPostAggregationExpression(args.get(1), expressions, sketchMap); + diff.update(a, b); + return diff.getResult(); default: - throw new IllegalStateException(); + throw new IllegalStateException("Invalid post-aggregation function."); Review comment: Include the string representation of function? ########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -526,28 +543,51 @@ private Predicate getPredicate(String predicateString) { * passed to this method to be used when evaluating the expression. * * @param postAggregationExpression Post-aggregation expression to evaluate (modeled as a filter) + * @param expressions list of aggregation function parameters * @param sketchMap Precomputed sketches for predicates that are part of the expression. * @return Overall evaluated sketch for the expression. 
*/ - private Sketch evalPostAggregationExpression(FilterContext postAggregationExpression, - Map<Predicate, Sketch> sketchMap) { - switch (postAggregationExpression.getType()) { - case AND: - Intersection intersection = _setOperationBuilder.buildIntersection(); - for (FilterContext child : postAggregationExpression.getChildren()) { - intersection.update(evalPostAggregationExpression(child, sketchMap)); - } - return intersection.getResult(); - case OR: - Union union = _setOperationBuilder.buildUnion(); - for (FilterContext child : postAggregationExpression.getChildren()) { - union.update(evalPostAggregationExpression(child, sketchMap)); + private Sketch evalPostAggregationExpression( + final ExpressionContext postAggregationExpression, + final List<Predicate> expressions, + final Map<Predicate, Sketch> sketchMap) { + if (postAggregationExpression.getType() == ExpressionContext.Type.LITERAL) { + throw new IllegalArgumentException("Literal not supported in post-aggregation function"); + } + + if (postAggregationExpression.getType() == ExpressionContext.Type.IDENTIFIER) { + final Predicate exp = + expressions.get(extractSubstitutionPosition(postAggregationExpression.getLiteral()) - 1); Review comment: Perhaps pre-substitution in the constructor would be better? For example, if the same $k arg is repeated multiple times, we might avoid the use of matcher using a temporary alias map? 
########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -108,35 +136,24 @@ public DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> argum Preconditions.checkArgument(paramsExpression.getType() == ExpressionContext.Type.LITERAL, "Last argument of DistinctCountThetaSketch must be literal (post-aggregation expression)"); _postAggregationExpression = QueryContextConverterUtils - .getFilter(CalciteSqlParser.compileToExpression(postAggregationExpression.getLiteral())); + .getExpression(CalciteSqlParser.compileToExpression(postAggregationExpression.getLiteral())); // Initialize the predicate map _predicateInfoMap = new HashMap<>(); - if (numArguments > 3) { - // Predicates are explicitly specified - for (int i = 2; i < numArguments - 1; i++) { - ExpressionContext predicateExpression = arguments.get(i); - Preconditions.checkArgument(predicateExpression.getType() == ExpressionContext.Type.LITERAL, - "Third to second last argument of DistinctCountThetaSketch must be literal (predicate expression)"); - Predicate predicate = getPredicate(predicateExpression.getLiteral()); - _inputExpressions.add(predicate.getLhs()); - _predicateInfoMap.put(predicate, new PredicateInfo(predicate)); - } - } else { - // Auto-derive predicates from the post-aggregation expression - Stack<FilterContext> stack = new Stack<>(); - stack.push(_postAggregationExpression); - while (!stack.isEmpty()) { - FilterContext filter = stack.pop(); - if (filter.getType() == FilterContext.Type.PREDICATE) { - Predicate predicate = filter.getPredicate(); - _inputExpressions.add(predicate.getLhs()); - _predicateInfoMap.put(predicate, new PredicateInfo(predicate)); - } else { - stack.addAll(filter.getChildren()); - } - } + + // Predicates are explicitly specified Review comment: We should still keep the auto-deriving of predicates. Granted, we won't be able to use $ notation in that case though. 
What do you think? ########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -576,6 +616,69 @@ private SetOperationBuilder getSetOperationBuilder() { : SetOperation.builder().setNominalEntries(_thetaSketchParams.getNominalEntries()); } + /** + * Validates that the function context's substitution parameters ($1, $2, etc) does not exceed the number + * of predicates passed into the post-aggregation function. + * + * For example, if the post aggregation function is: + * INTERSECT($1, $2, $3) + * + * But there are only 2 arguments passed into the aggregation function, throw an error + * @param context The parsed function context that's a tree structure + * @param numPredicates Max number of predicates available to be substituted + */ + private static void validatePostAggregationExpression(final ExpressionContext context, final int numPredicates) { + if (context.getType() == ExpressionContext.Type.LITERAL) { + throw new IllegalArgumentException("Invalid post-aggregation function expression syntax."); + } + + if (context.getType() == ExpressionContext.Type.IDENTIFIER) { + int id = extractSubstitutionPosition(context.getIdentifier()); + if (id <= 0) + throw new IllegalArgumentException("Argument substitution starts at $1"); + if (id > numPredicates) + throw new IllegalArgumentException("Argument substitution exceeded number of predicates"); + // if none of the invalid conditions are met above, exit out early + return; + } + + if (!MergeFunction.isValid(context.getFunction().getFunctionName())) { + final String allowed = + Arrays.stream(MergeFunction.values()) + .map(MergeFunction::name) + .collect(Collectors.joining(",")); + throw new IllegalArgumentException( + String.format("Invalid Theta Sketch aggregation function. 
Allowed: [%s]", allowed)); + } + + switch (MergeFunction.valueOf(context.getFunction().getFunctionName().toUpperCase())) { + case SET_DIFF: + // set diff can only have 2 arguments + if (context.getFunction().getArguments().size() != 2) { Review comment: +1 ########## File path: pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountThetaSketchTest.java ########## @@ -117,60 +121,77 @@ public void testGroupBySql() { testThetaSketches(true, true); } + @Test(expectedExceptions = BadQueryRequestException.class, dataProvider = "badQueries") + public void testInvalidNoPredicates(final String query) { + getBrokerResponseForSqlQuery(query); + } + + @DataProvider(name = "badQueries") + public Object[][] badQueries() { + return new Object[][] { + // need at least 4 arguments in agg func + {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', '$0') from testTable"}, + // substitution arguments should start at $1 + {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 1', '$0') from testTable"}, + // substituting variable has numeric value higher than the number of predicates provided + {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 1', '$5') from testTable"}, + // SET_DIFF requires exactly 2 arguments + {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 1', 'SET_DIFF($1)') from testTable"}, + // invalid merging function + {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 1', 'asdf') from testTable"} + }; + } + private void testThetaSketches(boolean groupBy, boolean sql) { String tsQuery, distinctQuery; String thetaSketchParams = "nominalEntries=1001"; List<String> predicateStrings = Collections.singletonList("colA = 1"); + String substitution = "$1"; String whereClause = Strings.join(predicateStrings, " or "); - tsQuery = buildQuery(whereClause, thetaSketchParams, predicateStrings, whereClause, groupBy, false); + tsQuery = buildQuery(whereClause, thetaSketchParams, 
predicateStrings, substitution, groupBy, false); Review comment: Isn't the fourth argument postAggregationExpression? If so, it should look more like a set operation, as opposed to "$1"? ########## File path: pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java ########## @@ -576,6 +616,69 @@ private SetOperationBuilder getSetOperationBuilder() { : SetOperation.builder().setNominalEntries(_thetaSketchParams.getNominalEntries()); } + /** + * Validates that the function context's substitution parameters ($1, $2, etc) does not exceed the number + * of predicates passed into the post-aggregation function. + * + * For example, if the post aggregation function is: + * INTERSECT($1, $2, $3) + * + * But there are only 2 arguments passed into the aggregation function, throw an error + * @param context The parsed function context that's a tree structure + * @param numPredicates Max number of predicates available to be substituted + */ + private static void validatePostAggregationExpression(final ExpressionContext context, final int numPredicates) { + if (context.getType() == ExpressionContext.Type.LITERAL) { + throw new IllegalArgumentException("Invalid post-aggregation function expression syntax."); + } + + if (context.getType() == ExpressionContext.Type.IDENTIFIER) { + int id = extractSubstitutionPosition(context.getIdentifier()); + if (id <= 0) + throw new IllegalArgumentException("Argument substitution starts at $1"); + if (id > numPredicates) + throw new IllegalArgumentException("Argument substitution exceeded number of predicates"); + // if none of the invalid conditions are met above, exit out early + return; + } + + if (!MergeFunction.isValid(context.getFunction().getFunctionName())) { + final String allowed = + Arrays.stream(MergeFunction.values()) Review comment: We tend to avoid Stream APIs in query execution as they tend to have performance overhead. 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org