This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new f6b29aa0a2 Improve PinotEvaluateLiteralRule (#11554)
f6b29aa0a2 is described below
commit f6b29aa0a250fa96a066d568dfbb22900db8b58d
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Sun Sep 10 11:20:08 2023 -0700
Improve PinotEvaluateLiteralRule (#11554)
---
.../rel/rules/PinotEvaluateLiteralRule.java | 310 +++++++++++----------
.../calcite/rel/rules/PinotQueryRuleSets.java | 4 +-
.../resources/queries/LiteralEvaluationPlans.json | 2 +-
3 files changed, 160 insertions(+), 156 deletions(-)
diff --git
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
index 7459dc18ec..e6636f75cf 100644
---
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
+++
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java
@@ -18,23 +18,17 @@
*/
package org.apache.calcite.rel.rules;
-import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.GregorianCalendar;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.stream.Collectors;
+import javax.annotation.Nullable;
import org.apache.calcite.avatica.util.ByteString;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.RelVisitor;
-import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
@@ -45,7 +39,6 @@ import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.NlsString;
-import org.apache.calcite.util.TimestampString;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionRegistry;
@@ -54,172 +47,181 @@ import
org.apache.pinot.sql.parsers.SqlCompilationException;
/**
- * SingleValueAggregateRemoveRule that matches an Aggregate function
SINGLE_VALUE and remove it
- *
+ * PinotEvaluateLiteralRule that matches the literal only function calls and
evaluates them.
*/
-public class PinotEvaluateLiteralRule extends RelOptRule {
- public static final PinotEvaluateLiteralRule INSTANCE =
- new PinotEvaluateLiteralRule(PinotRuleUtils.PINOT_REL_FACTORY);
+public class PinotEvaluateLiteralRule {
- public PinotEvaluateLiteralRule(RelBuilderFactory factory) {
- super(operand(RelNode.class, any()), factory, null);
- }
+ public static class Project extends RelOptRule {
+ public static final Project INSTANCE = new
Project(PinotRuleUtils.PINOT_REL_FACTORY);
- @Override
- public boolean matches(RelOptRuleCall call) {
- // Traverse the relational expression using a RexShuttle visitor
- AtomicBoolean hasLiteralOnlyCall = new AtomicBoolean(false);
- try {
- (new RelVisitor() {
- @Override
- public void visit(RelNode node, int ordinal, RelNode parent) {
- // Check if all operands are RexLiteral
- if (node.getInputs().stream().allMatch(operand -> (operand
instanceof RexLiteral))) {
- // If all operands are literals, this call can be evaluated
- hasLiteralOnlyCall.set(true);
- // early terminate if we found one evaluate call
- throw new RuntimeException("Found one literal only call");
- }
- for (RelNode input : node.getInputs()) {
- visit(input, ordinal, node);
- }
- }
- }).go(call.rel(0));
- } catch (RuntimeException e) {
- // Found one literal only call
- }
- return hasLiteralOnlyCall.get();
- }
+ private Project(RelBuilderFactory factory) {
+ super(operand(LogicalProject.class, any()), factory, null);
+ }
- @Override
- public void onMatch(RelOptRuleCall call) {
- RelNode rootRelNode = call.rel(0);
- RexBuilder rexBuilder = rootRelNode.getCluster().getRexBuilder();
- Map<RexNode, RexNode> evaluatedResults = new HashMap<>();
-
- // Recursively evaluate all the calls with Literal only operands from
bottom up
- // Traverse the relational expression using a RexShuttle visitor
- RelNode newRelNode = rootRelNode.accept(new RexShuttle() {
- @Override
- public RexNode visitCall(RexCall call) {
- // Check if all operands are RexLiteral
- if (call.operands.stream().allMatch(operand -> operand instanceof
RexLiteral)) {
- // If all operands are literals or already evaluated, this call can
be evaluated
- return evaluateLiteralOnlyFunction(rexBuilder, call,
evaluatedResults);
- }
- List<RexNode> newOperands = call.operands.stream().map(operand -> {
- if (operand instanceof RexCall) {
- return visitCall((RexCall) operand);
- }
- return operand;
- }).collect(Collectors.toList());
- return call.clone(call.getType(), newOperands);
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ LogicalProject oldProject = call.rel(0);
+ RexBuilder rexBuilder = oldProject.getCluster().getRexBuilder();
+ LogicalProject newProject = (LogicalProject) oldProject.accept(new
EvaluateLiteralShuttle(rexBuilder));
+ if (newProject != oldProject) {
+ call.transformTo(constructNewProject(oldProject, newProject,
rexBuilder));
}
- });
- if (newRelNode instanceof Project) {
- newRelNode = constructNewProject(rexBuilder, (Project) rootRelNode,
(Project) newRelNode);
}
- call.transformTo(newRelNode);
}
- private RelNode constructNewProject(RexBuilder rexBuilder, Project
oldProjectNode, Project newProjectNode) {
- List<RexNode> oldProjects = oldProjectNode.getProjects();
- List<RexNode> newProjects = new ArrayList<>();
- for (int i = 0; i < oldProjects.size(); i++) {
- RexNode oldProject = oldProjects.get(i);
- RexNode newProject = newProjectNode.getProjects().get(i);
+ /**
+ * Constructs a new LogicalProject that matches the type of the old
LogicalProject.
+ */
+ private static LogicalProject constructNewProject(LogicalProject oldProject,
LogicalProject newProject,
+ RexBuilder rexBuilder) {
+ List<RexNode> oldProjects = oldProject.getProjects();
+ List<RexNode> newProjects = newProject.getProjects();
+ int numProjects = oldProjects.size();
+ assert newProjects.size() == numProjects;
+ List<RexNode> castedNewProjects = new ArrayList<>(numProjects);
+ boolean needCast = false;
+ for (int i = 0; i < numProjects; i++) {
+ RexNode oldNode = oldProjects.get(i);
+ RexNode newNode = newProjects.get(i);
// Need to cast the result to the original type if the literal type is
changed, e.g. VARCHAR literal is typed as
// CHAR(STRING_LENGTH) in Calcite, but we need to cast it back to
VARCHAR.
- newProject = (newProject.getType() == oldProject.getType()) ? newProject
- : rexBuilder.makeCast(oldProject.getType(), newProject, true);
- newProjects.add(newProject);
+ if (oldNode.getType() != newNode.getType()) {
+ needCast = true;
+ newNode = rexBuilder.makeCast(oldNode.getType(), newNode, true);
+ }
+ castedNewProjects.add(newNode);
+ }
+ return needCast ? LogicalProject.create(oldProject.getInput(),
oldProject.getHints(), castedNewProjects,
+ oldProject.getRowType()) : newProject;
+ }
+
+ public static class Filter extends RelOptRule {
+ public static final Filter INSTANCE = new
Filter(PinotRuleUtils.PINOT_REL_FACTORY);
+
+ private Filter(RelBuilderFactory factory) {
+ super(operand(LogicalFilter.class, any()), factory, null);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ LogicalFilter oldFilter = call.rel(0);
+ RexBuilder rexBuilder = oldFilter.getCluster().getRexBuilder();
+ LogicalFilter newFilter = (LogicalFilter) oldFilter.accept(new
EvaluateLiteralShuttle(rexBuilder));
+ if (newFilter != oldFilter) {
+ call.transformTo(newFilter);
+ }
+ }
+ }
+
+ /**
+ * A RexShuttle that recursively evaluates all the calls with literal only
operands.
+ */
+ private static class EvaluateLiteralShuttle extends RexShuttle {
+ final RexBuilder _rexBuilder;
+
+ EvaluateLiteralShuttle(RexBuilder rexBuilder) {
+ _rexBuilder = rexBuilder;
+ }
+
+ @Override
+ public RexNode visitCall(RexCall call) {
+ RexCall visitedCall = (RexCall) super.visitCall(call);
+ // Check if all operands are RexLiteral
+ if (visitedCall.operands.stream().allMatch(operand -> operand instanceof
RexLiteral)) {
+ return evaluateLiteralOnlyFunction(visitedCall, _rexBuilder);
+ } else {
+ return visitedCall;
+ }
}
- return LogicalProject.create(oldProjectNode.getInput(),
oldProjectNode.getHints(), newProjects,
- oldProjectNode.getRowType());
}
/**
- * Evaluates the literal only function and returns the result as a
RexLiteral, null if the function cannot be
- * evaluated.
+ * Evaluates the literal only function and returns the result as a
RexLiteral if it can be evaluated, or the function
+ * itself (RexCall) if it cannot be evaluated.
*/
- protected static RexNode evaluateLiteralOnlyFunction(RexBuilder rexBuilder,
RexNode rexNode,
- Map<RexNode, RexNode> evaluatedResults) {
- if (rexNode instanceof RexLiteral) {
- return rexNode;
- }
- Preconditions.checkArgument(rexNode instanceof RexCall, "Expected RexCall,
got: " + rexNode);
- RexNode resultRexNode = evaluatedResults.get(rexNode);
- if (resultRexNode != null) {
- return resultRexNode;
- }
- RexCall function = (RexCall) rexNode;
- List<RexNode> operands = new ArrayList<>(function.getOperands());
+ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall,
RexBuilder rexBuilder) {
+ String functionName = PinotRuleUtils.extractFunctionName(rexCall);
+ List<RexNode> operands = rexCall.getOperands();
+ assert operands.stream().allMatch(operand -> operand instanceof
RexLiteral);
int numOperands = operands.size();
+ FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName,
numOperands);
+ if (functionInfo == null) {
+ // Function cannot be evaluated
+ return rexCall;
+ }
+ Object[] arguments = new Object[numOperands];
for (int i = 0; i < numOperands; i++) {
- // The RexCall is guaranteed to have all operands as RexLiteral or
evaluated as RexLiteral.
- // So recursively call evaluateLiteralOnlyFunction on all operands.
- RexNode operand = evaluateLiteralOnlyFunction(rexBuilder,
operands.get(i), evaluatedResults);
- operands.set(i, operand);
+ arguments[i] = getLiteralValue((RexLiteral) operands.get(i));
+ }
+ RelDataType rexNodeType = rexCall.getType();
+ Object resultValue;
+ try {
+ FunctionInvoker invoker = new FunctionInvoker(functionInfo);
+ invoker.convertTypes(arguments);
+ resultValue = invoker.invoke(arguments);
+ } catch (Exception e) {
+ throw new SqlCompilationException(
+ "Caught exception while invoking method: " +
functionInfo.getMethod() + " with arguments: " + Arrays.toString(
+ arguments), e);
+ }
+ try {
+ resultValue = convertResultValue(resultValue, rexNodeType);
+ } catch (Exception e) {
+ throw new SqlCompilationException(
+ "Caught exception while converting result value: " + resultValue + "
to type: " + rexNodeType, e);
+ }
+ try {
+ return rexBuilder.makeLiteral(resultValue, rexNodeType, false);
+ } catch (Exception e) {
+ throw new SqlCompilationException(
+ "Caught exception while making literal with value: " + resultValue +
" and type: " + rexNodeType, e);
}
+ }
- String functionName = PinotRuleUtils.extractFunctionName(function);
- FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName,
numOperands);
- resultRexNode = rexNode;
- if (functionInfo != null) {
- Object[] arguments = new Object[numOperands];
- for (int i = 0; i < numOperands; i++) {
- RexNode operand = function.getOperands().get(i);
- Preconditions.checkArgument(operand instanceof RexLiteral, "Expected
all the operands to be RexLiteral");
- Object value = ((RexLiteral) operand).getValue();
- if (value instanceof NlsString) {
- arguments[i] = ((NlsString) value).getValue();
- } else if (value instanceof GregorianCalendar) {
- arguments[i] = ((GregorianCalendar) value).getTimeInMillis();
- } else if (value instanceof ByteString) {
- arguments[i] = ((ByteString) value).getBytes();
- } else {
- arguments[i] = value;
- }
- }
- RelDataType rexNodeType = rexNode.getType();
- Object functionResult;
- try {
- FunctionInvoker invoker = new FunctionInvoker(functionInfo);
- invoker.convertTypes(arguments);
- functionResult = invoker.invoke(arguments);
- } catch (Exception e) {
- throw new SqlCompilationException(
- "Caught exception while invoking method: " +
functionInfo.getMethod() + " with arguments: "
- + Arrays.toString(arguments), e);
- }
- if (functionResult == null) {
- resultRexNode = rexBuilder.makeNullLiteral(rexNodeType);
- } else if (rexNodeType.getSqlTypeName() == SqlTypeName.TIMESTAMP) {
- long millis;
- if (functionResult instanceof Timestamp) {
- millis = ((Timestamp) functionResult).getTime();
- } else if (functionResult instanceof Number) {
- millis = ((Number) functionResult).longValue();
- } else {
- millis =
TimestampUtils.toMillisSinceEpoch(functionResult.toString());
- }
- resultRexNode =
-
rexBuilder.makeTimestampLiteral(TimestampString.fromMillisSinceEpoch(millis),
rexNodeType.getPrecision());
- } else if (functionResult instanceof Byte || functionResult instanceof
Short || functionResult instanceof Integer
- || functionResult instanceof Long) {
- resultRexNode =
- rexBuilder.makeExactLiteral(BigDecimal.valueOf(((Number)
functionResult).longValue()), rexNodeType);
- } else if (functionResult instanceof Float || functionResult instanceof
Double) {
- resultRexNode = rexBuilder.makeExactLiteral(new
BigDecimal(functionResult.toString()), rexNodeType);
- } else if (functionResult instanceof BigDecimal) {
- resultRexNode = rexBuilder.makeExactLiteral((BigDecimal)
functionResult, rexNodeType);
- } else if (functionResult instanceof byte[]) {
- resultRexNode = rexBuilder.makeLiteral(new ByteString((byte[])
functionResult), rexNodeType, false);
+ @Nullable
+ private static Object getLiteralValue(RexLiteral rexLiteral) {
+ Object value = rexLiteral.getValue();
+ if (value instanceof NlsString) {
+ // STRING
+ return ((NlsString) value).getValue();
+ } else if (value instanceof GregorianCalendar) {
+ // TIMESTAMP
+ return ((GregorianCalendar) value).getTimeInMillis();
+ } else if (value instanceof ByteString) {
+ // BYTES
+ return ((ByteString) value).getBytes();
+ } else {
+ return value;
+ }
+ }
+
+ @Nullable
+ private static Object convertResultValue(@Nullable Object resultValue,
RelDataType relDataType) {
+ if (resultValue == null) {
+ return null;
+ }
+ if (relDataType.getSqlTypeName() == SqlTypeName.TIMESTAMP) {
+ // Return millis since epoch for TIMESTAMP
+ if (resultValue instanceof Timestamp) {
+ return ((Timestamp) resultValue).getTime();
+ } else if (resultValue instanceof Number) {
+ return ((Number) resultValue).longValue();
} else {
- resultRexNode = rexBuilder.makeLiteral(functionResult, rexNodeType,
false);
+ return TimestampUtils.toMillisSinceEpoch(resultValue.toString());
}
}
- evaluatedResults.put(rexNode, resultRexNode);
- return resultRexNode;
+ // Return BigDecimal for numbers
+ if (resultValue instanceof Integer || resultValue instanceof Long) {
+ return new BigDecimal(((Number) resultValue).longValue());
+ }
+ if (resultValue instanceof Float || resultValue instanceof Double) {
+ return new BigDecimal(resultValue.toString());
+ }
+ // Return ByteString for byte[]
+ if (resultValue instanceof byte[]) {
+ return new ByteString((byte[]) resultValue);
+ }
+ // TODO: Add more type handling
+ return resultValue;
}
}
diff --git
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index 7a533e95d7..56448a5f3d 100644
---
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -66,7 +66,9 @@ public class PinotQueryRuleSets {
// push project through WINDOW
CoreRules.PROJECT_WINDOW_TRANSPOSE,
- PinotEvaluateLiteralRule.INSTANCE,
+ // TODO: Revisit and see if they can be replaced with
CoreRules.PROJECT_REDUCE_EXPRESSIONS and
+ // CoreRules.FILTER_REDUCE_EXPRESSIONS
+ PinotEvaluateLiteralRule.Project.INSTANCE,
PinotEvaluateLiteralRule.Filter.INSTANCE,
// TODO: evaluate the SORT_JOIN_TRANSPOSE and SORT_JOIN_COPY rules
diff --git
a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
index 4a29d734cb..4d82a770c9 100644
--- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
@@ -136,7 +136,7 @@
"sql": "EXPLAIN PLAN FOR Select
ST_Distance(X'8040340000000000004024000000000000', ST_Point(-122, 37.5, 1))
FROM a",
"output": [
"Execution Plan",
- "\nLogicalProject(EXPR$0=[13416951.966757335:DOUBLE])",
+ "\nLogicalProject(EXPR$0=[1.3416951966757335E7:DOUBLE])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org