This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f6b29aa0a2 Improve PinotEvaluateLiteralRule (#11554)
f6b29aa0a2 is described below

commit f6b29aa0a250fa96a066d568dfbb22900db8b58d
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Sun Sep 10 11:20:08 2023 -0700

    Improve PinotEvaluateLiteralRule (#11554)
---
 .../rel/rules/PinotEvaluateLiteralRule.java        | 310 +++++++++++----------
 .../calcite/rel/rules/PinotQueryRuleSets.java      |   4 +-
 .../resources/queries/LiteralEvaluationPlans.json  |   2 +-
 3 files changed, 160 insertions(+), 156 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
index 7459dc18ec..e6636f75cf 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
@@ -18,23 +18,17 @@
  */
 package org.apache.calcite.rel.rules;
 
-import com.google.common.base.Preconditions;
 import java.math.BigDecimal;
 import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.GregorianCalendar;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 import org.apache.calcite.avatica.util.ByteString;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.RelVisitor;
-import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalProject;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
@@ -45,7 +39,6 @@ import org.apache.calcite.rex.RexShuttle;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.NlsString;
-import org.apache.calcite.util.TimestampString;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.function.FunctionInvoker;
 import org.apache.pinot.common.function.FunctionRegistry;
@@ -54,172 +47,181 @@ import org.apache.pinot.sql.parsers.SqlCompilationException;
 
 
 /**
- * SingleValueAggregateRemoveRule that matches an Aggregate function SINGLE_VALUE and remove it
- *
+ * PinotEvaluateLiteralRule that matches the literal only function calls and evaluates them.
  */
-public class PinotEvaluateLiteralRule extends RelOptRule {
-  public static final PinotEvaluateLiteralRule INSTANCE =
-      new PinotEvaluateLiteralRule(PinotRuleUtils.PINOT_REL_FACTORY);
+public class PinotEvaluateLiteralRule {
 
-  public PinotEvaluateLiteralRule(RelBuilderFactory factory) {
-    super(operand(RelNode.class, any()), factory, null);
-  }
+  public static class Project extends RelOptRule {
+    public static final Project INSTANCE = new Project(PinotRuleUtils.PINOT_REL_FACTORY);
 
-  @Override
-  public boolean matches(RelOptRuleCall call) {
-    // Traverse the relational expression using a RexShuttle visitor
-    AtomicBoolean hasLiteralOnlyCall = new AtomicBoolean(false);
-    try {
-      (new RelVisitor() {
-        @Override
-        public void visit(RelNode node, int ordinal, RelNode parent) {
-          // Check if all operands are RexLiteral
-          if (node.getInputs().stream().allMatch(operand -> (operand instanceof RexLiteral))) {
-            // If all operands are literals, this call can be evaluated
-            hasLiteralOnlyCall.set(true);
-            // early terminate if we found one evaluate call
-            throw new RuntimeException("Found one literal only call");
-          }
-          for (RelNode input : node.getInputs()) {
-            visit(input, ordinal, node);
-          }
-        }
-      }).go(call.rel(0));
-    } catch (RuntimeException e) {
-      // Found one literal only call
-    }
-    return hasLiteralOnlyCall.get();
-  }
+    private Project(RelBuilderFactory factory) {
+      super(operand(LogicalProject.class, any()), factory, null);
+    }
 
-  @Override
-  public void onMatch(RelOptRuleCall call) {
-    RelNode rootRelNode = call.rel(0);
-    RexBuilder rexBuilder = rootRelNode.getCluster().getRexBuilder();
-    Map<RexNode, RexNode> evaluatedResults = new HashMap<>();
-
-    // Recursively evaluate all the calls with Literal only operands from bottom up
-    // Traverse the relational expression using a RexShuttle visitor
-    RelNode newRelNode = rootRelNode.accept(new RexShuttle() {
-      @Override
-      public RexNode visitCall(RexCall call) {
-        // Check if all operands are RexLiteral
-        if (call.operands.stream().allMatch(operand -> operand instanceof RexLiteral)) {
-          // If all operands are literals or already evaluated, this call can be evaluated
-          return evaluateLiteralOnlyFunction(rexBuilder, call, evaluatedResults);
-        }
-        List<RexNode> newOperands = call.operands.stream().map(operand -> {
-          if (operand instanceof RexCall) {
-            return visitCall((RexCall) operand);
-          }
-          return operand;
-        }).collect(Collectors.toList());
-        return call.clone(call.getType(), newOperands);
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+      LogicalProject oldProject = call.rel(0);
+      RexBuilder rexBuilder = oldProject.getCluster().getRexBuilder();
+      LogicalProject newProject = (LogicalProject) oldProject.accept(new EvaluateLiteralShuttle(rexBuilder));
+      if (newProject != oldProject) {
+        call.transformTo(constructNewProject(oldProject, newProject, rexBuilder));
       }
-    });
-    if (newRelNode instanceof Project) {
-      newRelNode = constructNewProject(rexBuilder, (Project) rootRelNode, (Project) newRelNode);
     }
-    call.transformTo(newRelNode);
   }
 
-  private RelNode constructNewProject(RexBuilder rexBuilder, Project oldProjectNode, Project newProjectNode) {
-    List<RexNode> oldProjects = oldProjectNode.getProjects();
-    List<RexNode> newProjects = new ArrayList<>();
-    for (int i = 0; i < oldProjects.size(); i++) {
-      RexNode oldProject = oldProjects.get(i);
-      RexNode newProject = newProjectNode.getProjects().get(i);
+  /**
+   * Constructs a new LogicalProject that matches the type of the old LogicalProject.
+   */
+  private static LogicalProject constructNewProject(LogicalProject oldProject, LogicalProject newProject,
+      RexBuilder rexBuilder) {
+    List<RexNode> oldProjects = oldProject.getProjects();
+    List<RexNode> newProjects = newProject.getProjects();
+    int numProjects = oldProjects.size();
+    assert newProjects.size() == numProjects;
+    List<RexNode> castedNewProjects = new ArrayList<>(numProjects);
+    boolean needCast = false;
+    for (int i = 0; i < numProjects; i++) {
+      RexNode oldNode = oldProjects.get(i);
+      RexNode newNode = newProjects.get(i);
       // Need to cast the result to the original type if the literal type is changed, e.g. VARCHAR literal is typed as
       // CHAR(STRING_LENGTH) in Calcite, but we need to cast it back to VARCHAR.
-      newProject = (newProject.getType() == oldProject.getType()) ? newProject
-          : rexBuilder.makeCast(oldProject.getType(), newProject, true);
-      newProjects.add(newProject);
+      if (oldNode.getType() != newNode.getType()) {
+        needCast = true;
+        newNode = rexBuilder.makeCast(oldNode.getType(), newNode, true);
+      }
+      castedNewProjects.add(newNode);
+    }
+    return needCast ? LogicalProject.create(oldProject.getInput(), oldProject.getHints(), castedNewProjects,
+        oldProject.getRowType()) : newProject;
+  }
+
+  public static class Filter extends RelOptRule {
+    public static final Filter INSTANCE = new Filter(PinotRuleUtils.PINOT_REL_FACTORY);
+
+    private Filter(RelBuilderFactory factory) {
+      super(operand(LogicalFilter.class, any()), factory, null);
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+      LogicalFilter oldFilter = call.rel(0);
+      RexBuilder rexBuilder = oldFilter.getCluster().getRexBuilder();
+      LogicalFilter newFilter = (LogicalFilter) oldFilter.accept(new EvaluateLiteralShuttle(rexBuilder));
+      if (newFilter != oldFilter) {
+        call.transformTo(newFilter);
+      }
+    }
+  }
+
+  /**
+   * A RexShuttle that recursively evaluates all the calls with literal only operands.
+   */
+  private static class EvaluateLiteralShuttle extends RexShuttle {
+    final RexBuilder _rexBuilder;
+
+    EvaluateLiteralShuttle(RexBuilder rexBuilder) {
+      _rexBuilder = rexBuilder;
+    }
+
+    @Override
+    public RexNode visitCall(RexCall call) {
+      RexCall visitedCall = (RexCall) super.visitCall(call);
+      // Check if all operands are RexLiteral
+      if (visitedCall.operands.stream().allMatch(operand -> operand instanceof RexLiteral)) {
+        return evaluateLiteralOnlyFunction(visitedCall, _rexBuilder);
+      } else {
+        return visitedCall;
+      }
     }
-    return LogicalProject.create(oldProjectNode.getInput(), oldProjectNode.getHints(), newProjects,
-        oldProjectNode.getRowType());
   }
 
   /**
-   * Evaluates the literal only function and returns the result as a RexLiteral, null if the function cannot be
-   * evaluated.
+   * Evaluates the literal only function and returns the result as a RexLiteral if it can be evaluated, or the function
+   * itself (RexCall) if it cannot be evaluated.
    */
-  protected static RexNode evaluateLiteralOnlyFunction(RexBuilder rexBuilder, RexNode rexNode,
-      Map<RexNode, RexNode> evaluatedResults) {
-    if (rexNode instanceof RexLiteral) {
-      return rexNode;
-    }
-    Preconditions.checkArgument(rexNode instanceof RexCall, "Expected RexCall, got: " + rexNode);
-    RexNode resultRexNode = evaluatedResults.get(rexNode);
-    if (resultRexNode != null) {
-      return resultRexNode;
-    }
-    RexCall function = (RexCall) rexNode;
-    List<RexNode> operands = new ArrayList<>(function.getOperands());
+  private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder rexBuilder) {
+    String functionName = PinotRuleUtils.extractFunctionName(rexCall);
+    List<RexNode> operands = rexCall.getOperands();
+    assert operands.stream().allMatch(operand -> operand instanceof RexLiteral);
     int numOperands = operands.size();
+    FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numOperands);
+    if (functionInfo == null) {
+      // Function cannot be evaluated
+      return rexCall;
+    }
+    Object[] arguments = new Object[numOperands];
     for (int i = 0; i < numOperands; i++) {
-      // The RexCall is guaranteed to have all operands as RexLiteral or evaluated as RexLiteral.
-      // So recursively call evaluateLiteralOnlyFunction on all operands.
-      RexNode operand = evaluateLiteralOnlyFunction(rexBuilder, operands.get(i), evaluatedResults);
-      operands.set(i, operand);
+      arguments[i] = getLiteralValue((RexLiteral) operands.get(i));
+    }
+    RelDataType rexNodeType = rexCall.getType();
+    Object resultValue;
+    try {
+      FunctionInvoker invoker = new FunctionInvoker(functionInfo);
+      invoker.convertTypes(arguments);
+      resultValue = invoker.invoke(arguments);
+    } catch (Exception e) {
+      throw new SqlCompilationException(
+          "Caught exception while invoking method: " + functionInfo.getMethod() + " with arguments: " + Arrays.toString(
+              arguments), e);
+    }
+    try {
+      resultValue = convertResultValue(resultValue, rexNodeType);
+    } catch (Exception e) {
+      throw new SqlCompilationException(
+          "Caught exception while converting result value: " + resultValue + " to type: " + rexNodeType, e);
+    }
+    try {
+      return rexBuilder.makeLiteral(resultValue, rexNodeType, false);
+    } catch (Exception e) {
+      throw new SqlCompilationException(
+          "Caught exception while making literal with value: " + resultValue + " and type: " + rexNodeType, e);
     }
+  }
 
-    String functionName = PinotRuleUtils.extractFunctionName(function);
-    FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numOperands);
-    resultRexNode = rexNode;
-    if (functionInfo != null) {
-      Object[] arguments = new Object[numOperands];
-      for (int i = 0; i < numOperands; i++) {
-        RexNode operand = function.getOperands().get(i);
-        Preconditions.checkArgument(operand instanceof RexLiteral, "Expected all the operands to be RexLiteral");
-        Object value = ((RexLiteral) operand).getValue();
-        if (value instanceof NlsString) {
-          arguments[i] = ((NlsString) value).getValue();
-        } else if (value instanceof GregorianCalendar) {
-          arguments[i] = ((GregorianCalendar) value).getTimeInMillis();
-        } else if (value instanceof ByteString) {
-          arguments[i] = ((ByteString) value).getBytes();
-        } else {
-          arguments[i] = value;
-        }
-      }
-      RelDataType rexNodeType = rexNode.getType();
-      Object functionResult;
-      try {
-        FunctionInvoker invoker = new FunctionInvoker(functionInfo);
-        invoker.convertTypes(arguments);
-        functionResult = invoker.invoke(arguments);
-      } catch (Exception e) {
-        throw new SqlCompilationException(
-            "Caught exception while invoking method: " + functionInfo.getMethod() + " with arguments: "
-                + Arrays.toString(arguments), e);
-      }
-      if (functionResult == null) {
-        resultRexNode = rexBuilder.makeNullLiteral(rexNodeType);
-      } else if (rexNodeType.getSqlTypeName() == SqlTypeName.TIMESTAMP) {
-        long millis;
-        if (functionResult instanceof Timestamp) {
-          millis = ((Timestamp) functionResult).getTime();
-        } else if (functionResult instanceof Number) {
-          millis = ((Number) functionResult).longValue();
-        } else {
-          millis = TimestampUtils.toMillisSinceEpoch(functionResult.toString());
-        }
-        resultRexNode =
-            rexBuilder.makeTimestampLiteral(TimestampString.fromMillisSinceEpoch(millis), rexNodeType.getPrecision());
-      } else if (functionResult instanceof Byte || functionResult instanceof Short || functionResult instanceof Integer
-          || functionResult instanceof Long) {
-        resultRexNode =
-            rexBuilder.makeExactLiteral(BigDecimal.valueOf(((Number) functionResult).longValue()), rexNodeType);
-      } else if (functionResult instanceof Float || functionResult instanceof Double) {
-        resultRexNode = rexBuilder.makeExactLiteral(new BigDecimal(functionResult.toString()), rexNodeType);
-      } else if (functionResult instanceof BigDecimal) {
-        resultRexNode = rexBuilder.makeExactLiteral((BigDecimal) functionResult, rexNodeType);
-      } else if (functionResult instanceof byte[]) {
-        resultRexNode = rexBuilder.makeLiteral(new ByteString((byte[]) functionResult), rexNodeType, false);
+  @Nullable
+  private static Object getLiteralValue(RexLiteral rexLiteral) {
+    Object value = rexLiteral.getValue();
+    if (value instanceof NlsString) {
+      // STRING
+      return ((NlsString) value).getValue();
+    } else if (value instanceof GregorianCalendar) {
+      // TIMESTAMP
+      return ((GregorianCalendar) value).getTimeInMillis();
+    } else if (value instanceof ByteString) {
+      // BYTES
+      return ((ByteString) value).getBytes();
+    } else {
+      return value;
+    }
+  }
+
+  @Nullable
+  private static Object convertResultValue(@Nullable Object resultValue, RelDataType relDataType) {
+    if (resultValue == null) {
+      return null;
+    }
+    if (relDataType.getSqlTypeName() == SqlTypeName.TIMESTAMP) {
+      // Return millis since epoch for TIMESTAMP
+      if (resultValue instanceof Timestamp) {
+        return ((Timestamp) resultValue).getTime();
+      } else if (resultValue instanceof Number) {
+        return ((Number) resultValue).longValue();
       } else {
-        resultRexNode = rexBuilder.makeLiteral(functionResult, rexNodeType, false);
+        return TimestampUtils.toMillisSinceEpoch(resultValue.toString());
       }
     }
-    evaluatedResults.put(rexNode, resultRexNode);
-    return resultRexNode;
+    // Return BigDecimal for numbers
+    if (resultValue instanceof Integer || resultValue instanceof Long) {
+      return new BigDecimal(((Number) resultValue).longValue());
+    }
+    if (resultValue instanceof Float || resultValue instanceof Double) {
+      return new BigDecimal(resultValue.toString());
+    }
+    // Return ByteString for byte[]
+    if (resultValue instanceof byte[]) {
+      return new ByteString((byte[]) resultValue);
+    }
+    // TODO: Add more type handling
+    return resultValue;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index 7a533e95d7..56448a5f3d 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -66,7 +66,9 @@ public class PinotQueryRuleSets {
           // push project through WINDOW
           CoreRules.PROJECT_WINDOW_TRANSPOSE,
 
-          PinotEvaluateLiteralRule.INSTANCE,
+          // TODO: Revisit and see if they can be replaced with CoreRules.PROJECT_REDUCE_EXPRESSIONS and
+          //       CoreRules.FILTER_REDUCE_EXPRESSIONS
+          PinotEvaluateLiteralRule.Project.INSTANCE, PinotEvaluateLiteralRule.Filter.INSTANCE,
 
           // TODO: evaluate the SORT_JOIN_TRANSPOSE and SORT_JOIN_COPY rules
 
diff --git a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
index 4a29d734cb..4d82a770c9 100644
--- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
@@ -136,7 +136,7 @@
         "sql": "EXPLAIN PLAN FOR Select 
ST_Distance(X'8040340000000000004024000000000000', ST_Point(-122, 37.5, 1)) 
FROM a",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(EXPR$0=[13416951.966757335:DOUBLE])",
+          "\nLogicalProject(EXPR$0=[1.3416951966757335E7:DOUBLE])",
           "\n  LogicalTableScan(table=[[a]])",
           "\n"
         ]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to