mayankshriv commented on a change in pull request #5832:
URL: https://github.com/apache/incubator-pinot/pull/5832#discussion_r467334796



##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -78,9 +102,9 @@ public 
DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> argum
       throws SqlParseException {
     int numArguments = arguments.size();
 
-    // NOTE: This function expects at least 3 arguments: theta-sketch column, 
parameters, post-aggregation expression.
-    Preconditions.checkArgument(numArguments >= 3,
-        "DistinctCountThetaSketch expects at least three arguments 
(theta-sketch column, parameters, post-aggregation expression), got: ",
+    // NOTE: This function expects at least 4 arguments: theta-sketch column, 
parameters, post-aggregation expression.

Review comment:
       The comment says 4 arguments, but it only lists three of them?

##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -97,6 +121,10 @@ public 
DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> argum
     // Initialize the theta-sketch set operation builder
     _setOperationBuilder = getSetOperationBuilder();
 
+    // index of the original input predicates
+    // index[0] = $1

Review comment:
       Remove unused code?

##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -526,28 +543,51 @@ private Predicate getPredicate(String predicateString) {
    * passed to this method to be used when evaluating the expression.
    *
    * @param postAggregationExpression Post-aggregation expression to evaluate 
(modeled as a filter)
+   * @param expressions list of aggregation function parameters
    * @param sketchMap Precomputed sketches for predicates that are part of the 
expression.
    * @return Overall evaluated sketch for the expression.
    */
-  private Sketch evalPostAggregationExpression(FilterContext 
postAggregationExpression,
-      Map<Predicate, Sketch> sketchMap) {
-    switch (postAggregationExpression.getType()) {
-      case AND:
-        Intersection intersection = _setOperationBuilder.buildIntersection();
-        for (FilterContext child : postAggregationExpression.getChildren()) {
-          intersection.update(evalPostAggregationExpression(child, sketchMap));
-        }
-        return intersection.getResult();
-      case OR:
-        Union union = _setOperationBuilder.buildUnion();
-        for (FilterContext child : postAggregationExpression.getChildren()) {
-          union.update(evalPostAggregationExpression(child, sketchMap));
+  private Sketch evalPostAggregationExpression(
+      final ExpressionContext postAggregationExpression,
+      final List<Predicate> expressions,
+      final Map<Predicate, Sketch> sketchMap) {
+    if (postAggregationExpression.getType() == ExpressionContext.Type.LITERAL) 
{
+      throw new IllegalArgumentException("Literal not supported in 
post-aggregation function");
+    }
+
+    if (postAggregationExpression.getType() == 
ExpressionContext.Type.IDENTIFIER) {
+      final Predicate exp =
+          
expressions.get(extractSubstitutionPosition(postAggregationExpression.getLiteral())
 - 1);
+      return sketchMap.get(exp);
+    }
+
+    // shouldn't throw exception because of the validation in the constructor
+    final MergeFunction func =
+        
MergeFunction.valueOf(postAggregationExpression.getFunction().getFunctionName().toUpperCase());
+
+    // handle functions recursively
+    switch(func) {
+      case SET_UNION:
+        final Union union = _setOperationBuilder.buildUnion();
+        for (final ExpressionContext exp : 
postAggregationExpression.getFunction().getArguments()) {
+          union.update(evalPostAggregationExpression(exp, expressions, 
sketchMap));
         }
         return union.getResult();
-      case PREDICATE:
-        return sketchMap.get(postAggregationExpression.getPredicate());
+      case SET_INTERSECT:
+        final Intersection intersection = 
_setOperationBuilder.buildIntersection();
+        for (final ExpressionContext exp : 
postAggregationExpression.getFunction().getArguments()) {
+          intersection.update(evalPostAggregationExpression(exp, expressions, 
sketchMap));
+        }
+        return intersection.getResult();
+      case SET_DIFF:
+        final List<ExpressionContext> args = 
postAggregationExpression.getFunction().getArguments();
+        final AnotB diff = _setOperationBuilder.buildANotB();
+        final Sketch a = evalPostAggregationExpression(args.get(0), 
expressions, sketchMap);
+        final Sketch b = evalPostAggregationExpression(args.get(1), 
expressions, sketchMap);
+        diff.update(a, b);
+        return diff.getResult();
       default:
-        throw new IllegalStateException();
+        throw new IllegalStateException("Invalid post-aggregation function.");

Review comment:
       Include the string representation of the function?

##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -526,28 +543,51 @@ private Predicate getPredicate(String predicateString) {
    * passed to this method to be used when evaluating the expression.
    *
    * @param postAggregationExpression Post-aggregation expression to evaluate 
(modeled as a filter)
+   * @param expressions list of aggregation function parameters
    * @param sketchMap Precomputed sketches for predicates that are part of the 
expression.
    * @return Overall evaluated sketch for the expression.
    */
-  private Sketch evalPostAggregationExpression(FilterContext 
postAggregationExpression,
-      Map<Predicate, Sketch> sketchMap) {
-    switch (postAggregationExpression.getType()) {
-      case AND:
-        Intersection intersection = _setOperationBuilder.buildIntersection();
-        for (FilterContext child : postAggregationExpression.getChildren()) {
-          intersection.update(evalPostAggregationExpression(child, sketchMap));
-        }
-        return intersection.getResult();
-      case OR:
-        Union union = _setOperationBuilder.buildUnion();
-        for (FilterContext child : postAggregationExpression.getChildren()) {
-          union.update(evalPostAggregationExpression(child, sketchMap));
+  private Sketch evalPostAggregationExpression(
+      final ExpressionContext postAggregationExpression,
+      final List<Predicate> expressions,
+      final Map<Predicate, Sketch> sketchMap) {
+    if (postAggregationExpression.getType() == ExpressionContext.Type.LITERAL) 
{
+      throw new IllegalArgumentException("Literal not supported in 
post-aggregation function");
+    }
+
+    if (postAggregationExpression.getType() == 
ExpressionContext.Type.IDENTIFIER) {
+      final Predicate exp =
+          
expressions.get(extractSubstitutionPosition(postAggregationExpression.getLiteral())
 - 1);

Review comment:
       Perhaps pre-substitution in the constructor would be better? For 
example, if the same $k arg is repeated multiple times, we might avoid the use 
of the matcher by using a temporary alias map?

##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -108,35 +136,24 @@ public 
DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> argum
     Preconditions.checkArgument(paramsExpression.getType() == 
ExpressionContext.Type.LITERAL,
         "Last argument of DistinctCountThetaSketch must be literal 
(post-aggregation expression)");
     _postAggregationExpression = QueryContextConverterUtils
-        
.getFilter(CalciteSqlParser.compileToExpression(postAggregationExpression.getLiteral()));
+        
.getExpression(CalciteSqlParser.compileToExpression(postAggregationExpression.getLiteral()));
 
     // Initialize the predicate map
     _predicateInfoMap = new HashMap<>();
-    if (numArguments > 3) {
-      // Predicates are explicitly specified
-      for (int i = 2; i < numArguments - 1; i++) {
-        ExpressionContext predicateExpression = arguments.get(i);
-        Preconditions.checkArgument(predicateExpression.getType() == 
ExpressionContext.Type.LITERAL,
-            "Third to second last argument of DistinctCountThetaSketch must be 
literal (predicate expression)");
-        Predicate predicate = getPredicate(predicateExpression.getLiteral());
-        _inputExpressions.add(predicate.getLhs());
-        _predicateInfoMap.put(predicate, new PredicateInfo(predicate));
-      }
-    } else {
-      // Auto-derive predicates from the post-aggregation expression
-      Stack<FilterContext> stack = new Stack<>();
-      stack.push(_postAggregationExpression);
-      while (!stack.isEmpty()) {
-        FilterContext filter = stack.pop();
-        if (filter.getType() == FilterContext.Type.PREDICATE) {
-          Predicate predicate = filter.getPredicate();
-          _inputExpressions.add(predicate.getLhs());
-          _predicateInfoMap.put(predicate, new PredicateInfo(predicate));
-        } else {
-          stack.addAll(filter.getChildren());
-        }
-      }
+
+    // Predicates are explicitly specified

Review comment:
       We should still keep the auto-deriving of predicates. Granted, we won't 
be able to use $ notation in that case though. What do you think?

##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -576,6 +616,69 @@ private SetOperationBuilder getSetOperationBuilder() {
         : 
SetOperation.builder().setNominalEntries(_thetaSketchParams.getNominalEntries());
   }
 
+  /**
+   * Validates that the function context's substitution parameters ($1, $2, 
etc) does not exceed the number
+   * of predicates passed into the post-aggregation function.
+   *
+   * For example, if the post aggregation function is:
+   * INTERSECT($1, $2, $3)
+   *
+   * But there are only 2 arguments passed into the aggregation function, 
throw an error
+   * @param context The parsed function context that's a tree structure
+   * @param numPredicates Max number of predicates available to be substituted
+   */
+  private static void validatePostAggregationExpression(final 
ExpressionContext context, final int numPredicates) {
+    if (context.getType() == ExpressionContext.Type.LITERAL) {
+      throw new IllegalArgumentException("Invalid post-aggregation function 
expression syntax.");
+    }
+
+    if (context.getType() == ExpressionContext.Type.IDENTIFIER) {
+      int id = extractSubstitutionPosition(context.getIdentifier());
+      if (id <= 0)
+        throw new IllegalArgumentException("Argument substitution starts at 
$1");
+      if (id > numPredicates)
+        throw new IllegalArgumentException("Argument substitution exceeded 
number of predicates");
+      // if none of the invalid conditions are met above, exit out early
+      return;
+    }
+
+    if (!MergeFunction.isValid(context.getFunction().getFunctionName())) {
+      final String allowed =
+          Arrays.stream(MergeFunction.values())
+              .map(MergeFunction::name)
+              .collect(Collectors.joining(","));
+      throw new IllegalArgumentException(
+          String.format("Invalid Theta Sketch aggregation function. Allowed: 
[%s]", allowed));
+    }
+
+    switch 
(MergeFunction.valueOf(context.getFunction().getFunctionName().toUpperCase())) {
+      case SET_DIFF:
+        // set diff can only have 2 arguments
+        if (context.getFunction().getArguments().size() != 2) {

Review comment:
       +1

##########
File path: 
pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountThetaSketchTest.java
##########
@@ -117,60 +121,77 @@ public void testGroupBySql() {
     testThetaSketches(true, true);
   }
 
+  @Test(expectedExceptions = BadQueryRequestException.class, dataProvider = 
"badQueries")
+  public void testInvalidNoPredicates(final String query) {
+    getBrokerResponseForSqlQuery(query);
+  }
+
+  @DataProvider(name = "badQueries")
+  public Object[][] badQueries() {
+    return new Object[][] {
+        // need at least 4 arguments in agg func
+        {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', '$0') 
from testTable"},
+        // substitution arguments should start at $1
+        {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 
1', '$0') from testTable"},
+        // substituting variable has numeric value higher than the number of 
predicates provided
+        {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 
1', '$5') from testTable"},
+        // SET_DIFF requires exactly 2 arguments
+        {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 
1', 'SET_DIFF($1)') from testTable"},
+        // invalid merging function
+        {"select distinctCountThetaSketch(colTS, 'nominalEntries=123', 'colA = 
1', 'asdf') from testTable"}
+    };
+  }
+
   private void testThetaSketches(boolean groupBy, boolean sql) {
     String tsQuery, distinctQuery;
     String thetaSketchParams = "nominalEntries=1001";
 
     List<String> predicateStrings = Collections.singletonList("colA = 1");
+    String substitution = "$1";
     String whereClause = Strings.join(predicateStrings, " or ");
-    tsQuery = buildQuery(whereClause, thetaSketchParams, predicateStrings, 
whereClause, groupBy, false);
+    tsQuery = buildQuery(whereClause, thetaSketchParams, predicateStrings, 
substitution, groupBy, false);

Review comment:
       Isn't the fourth argument postAggregationExpression? If so, it should 
look more like a set operation, as opposed to "$1"?

##########
File path: 
pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
##########
@@ -576,6 +616,69 @@ private SetOperationBuilder getSetOperationBuilder() {
         : 
SetOperation.builder().setNominalEntries(_thetaSketchParams.getNominalEntries());
   }
 
+  /**
+   * Validates that the function context's substitution parameters ($1, $2, 
etc) does not exceed the number
+   * of predicates passed into the post-aggregation function.
+   *
+   * For example, if the post aggregation function is:
+   * INTERSECT($1, $2, $3)
+   *
+   * But there are only 2 arguments passed into the aggregation function, 
throw an error
+   * @param context The parsed function context that's a tree structure
+   * @param numPredicates Max number of predicates available to be substituted
+   */
+  private static void validatePostAggregationExpression(final 
ExpressionContext context, final int numPredicates) {
+    if (context.getType() == ExpressionContext.Type.LITERAL) {
+      throw new IllegalArgumentException("Invalid post-aggregation function 
expression syntax.");
+    }
+
+    if (context.getType() == ExpressionContext.Type.IDENTIFIER) {
+      int id = extractSubstitutionPosition(context.getIdentifier());
+      if (id <= 0)
+        throw new IllegalArgumentException("Argument substitution starts at 
$1");
+      if (id > numPredicates)
+        throw new IllegalArgumentException("Argument substitution exceeded 
number of predicates");
+      // if none of the invalid conditions are met above, exit out early
+      return;
+    }
+
+    if (!MergeFunction.isValid(context.getFunction().getFunctionName())) {
+      final String allowed =
+          Arrays.stream(MergeFunction.values())

Review comment:
       We tend to avoid stream APIs in query execution, as they tend to have 
performance overhead.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to