This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f8de958174 Use argument type to lookup function for literal only query (#13673)
f8de958174 is described below

commit f8de958174e209ab7f11572149ce1723d63b5af3
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Wed Aug 28 11:50:02 2024 -0700

    Use argument type to lookup function for literal only query (#13673)
---
 .../BaseSingleStageBrokerRequestHandler.java       | 120 ++++++--------------
 .../common/request/context/LiteralContext.java     |  36 +-----
 .../pinot/common/utils/request/RequestUtils.java   |  98 ++++++++++++++--
 .../rewriter/CompileTimeFunctionsInvoker.java      |  66 ++++++-----
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  | 125 ++++++++++-----------
 5 files changed, 225 insertions(+), 220 deletions(-)

diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
index 28eebf205b..83f83188dc 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java
@@ -77,6 +77,7 @@ import org.apache.pinot.common.response.ProcessingException;
 import org.apache.pinot.common.response.broker.BrokerResponseNative;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.DatabaseUtils;
 import org.apache.pinot.common.utils.config.QueryOptionsUtils;
 import org.apache.pinot.common.utils.request.RequestUtils;
@@ -101,8 +102,6 @@ import 
org.apache.pinot.spi.exception.BadQueryRequestException;
 import org.apache.pinot.spi.exception.DatabaseConflictException;
 import org.apache.pinot.spi.trace.RequestContext;
 import org.apache.pinot.spi.trace.Tracing;
-import org.apache.pinot.spi.utils.BigDecimalUtils;
-import org.apache.pinot.spi.utils.BytesUtils;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.CommonConstants.Broker;
 import 
org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
@@ -1489,16 +1488,18 @@ public abstract class 
BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ
   private BrokerResponseNative processLiteralOnlyQuery(long requestId, 
PinotQuery pinotQuery,
       RequestContext requestContext) {
     BrokerResponseNative brokerResponse = new BrokerResponseNative();
-    List<String> columnNames = new ArrayList<>();
-    List<DataSchema.ColumnDataType> columnTypes = new ArrayList<>();
-    List<Object> row = new ArrayList<>();
-    for (Expression expression : pinotQuery.getSelectList()) {
-      computeResultsForExpression(expression, columnNames, columnTypes, row);
-    }
-    DataSchema dataSchema =
-        new DataSchema(columnNames.toArray(new String[0]), 
columnTypes.toArray(new DataSchema.ColumnDataType[0]));
-    List<Object[]> rows = new ArrayList<>();
-    rows.add(row.toArray());
+    List<Expression> selectList = pinotQuery.getSelectList();
+    int numColumns = selectList.size();
+    String[] columnNames = new String[numColumns];
+    ColumnDataType[] columnTypes = new ColumnDataType[numColumns];
+    Object[] values = new Object[numColumns];
+    for (int i = 0; i < numColumns; i++) {
+      computeResultsForExpression(selectList.get(i), columnNames, columnTypes, 
values, i);
+      values[i] = columnTypes[i].format(values[i]);
+    }
+    DataSchema dataSchema = new DataSchema(columnNames, columnTypes);
+    List<Object[]> rows = new ArrayList<>(1);
+    rows.add(values);
     ResultTable resultTable = new ResultTable(dataSchema, rows);
     brokerResponse.setResultTable(resultTable);
     brokerResponse.setTimeUsedMs(System.currentTimeMillis() - 
requestContext.getRequestArrivalTimeMillis());
@@ -1510,87 +1511,30 @@ public abstract class 
BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ
   }
 
   // TODO(xiangfu): Move Literal function computation here from Calcite Parser.
-  private void computeResultsForExpression(Expression e, List<String> 
columnNames,
-      List<DataSchema.ColumnDataType> columnTypes, List<Object> row) {
-    if (e.getType() == ExpressionType.LITERAL) {
-      computeResultsForLiteral(e.getLiteral(), columnNames, columnTypes, row);
-    }
-    if (e.getType() == ExpressionType.FUNCTION) {
-      if (e.getFunctionCall().getOperator().equals("as")) {
-        String columnName = 
e.getFunctionCall().getOperands().get(1).getIdentifier().getName();
-        computeResultsForExpression(e.getFunctionCall().getOperands().get(0), 
columnNames, columnTypes, row);
-        columnNames.set(columnNames.size() - 1, columnName);
+  private void computeResultsForExpression(Expression expression, String[] 
columnNames, ColumnDataType[] columnTypes,
+      Object[] values, int index) {
+    ExpressionType type = expression.getType();
+    if (type == ExpressionType.LITERAL) {
+      computeResultsForLiteral(expression.getLiteral(), columnNames, 
columnTypes, values, index);
+    } else if (type == ExpressionType.FUNCTION) {
+      Function function = expression.getFunctionCall();
+      String operator = function.getOperator();
+      if (operator.equals("as")) {
+        List<Expression> operands = function.getOperands();
+        computeResultsForExpression(operands.get(0), columnNames, columnTypes, 
values, index);
+        columnNames[index] = operands.get(1).getIdentifier().getName();
       } else {
-        throw new IllegalStateException(
-            "No able to compute results for function - " + 
e.getFunctionCall().getOperator());
+        throw new IllegalStateException("No able to compute results for 
function - " + operator);
       }
     }
   }
 
-  private void computeResultsForLiteral(Literal literal, List<String> 
columnNames,
-      List<DataSchema.ColumnDataType> columnTypes, List<Object> row) {
-    columnNames.add(RequestUtils.prettyPrint(literal));
-    switch (literal.getSetField()) {
-      case NULL_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.UNKNOWN);
-        row.add(null);
-        break;
-      case BOOL_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.BOOLEAN);
-        row.add(literal.getBoolValue());
-        break;
-      case INT_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.INT);
-        row.add(literal.getIntValue());
-        break;
-      case LONG_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.LONG);
-        row.add(literal.getLongValue());
-        break;
-      case FLOAT_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.FLOAT);
-        row.add(Float.intBitsToFloat(literal.getFloatValue()));
-        break;
-      case DOUBLE_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.DOUBLE);
-        row.add(literal.getDoubleValue());
-        break;
-      case BIG_DECIMAL_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.BIG_DECIMAL);
-        row.add(BigDecimalUtils.deserialize(literal.getBigDecimalValue()));
-        break;
-      case STRING_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.STRING);
-        row.add(literal.getStringValue());
-        break;
-      case BINARY_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.BYTES);
-        row.add(BytesUtils.toHexString(literal.getBinaryValue()));
-        break;
-      // TODO: Revisit the array handling. Currently we are setting List into 
the row.
-      case INT_ARRAY_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.INT_ARRAY);
-        row.add(literal.getIntArrayValue());
-        break;
-      case LONG_ARRAY_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.LONG_ARRAY);
-        row.add(literal.getLongArrayValue());
-        break;
-      case FLOAT_ARRAY_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.FLOAT_ARRAY);
-        
row.add(literal.getFloatArrayValue().stream().map(Float::intBitsToFloat).collect(Collectors.toList()));
-        break;
-      case DOUBLE_ARRAY_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.DOUBLE_ARRAY);
-        row.add(literal.getDoubleArrayValue());
-        break;
-      case STRING_ARRAY_VALUE:
-        columnTypes.add(DataSchema.ColumnDataType.STRING_ARRAY);
-        row.add(literal.getStringArrayValue());
-        break;
-      default:
-        throw new IllegalStateException("Unsupported literal: " + literal);
-    }
+  private void computeResultsForLiteral(Literal literal, String[] columnNames, 
ColumnDataType[] columnTypes,
+      Object[] values, int index) {
+    columnNames[index] = RequestUtils.prettyPrint(literal);
+    Pair<ColumnDataType, Object> typeAndValue = 
RequestUtils.getLiteralTypeAndValue(literal);
+    columnTypes[index] = typeAndValue.getLeft();
+    values[index] = typeAndValue.getRight();
   }
 
   /**
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
index 0a2b8ad6e1..c0a55f23f5 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
@@ -23,11 +23,11 @@ import com.google.common.base.Preconditions;
 import java.math.BigDecimal;
 import java.sql.Timestamp;
 import java.util.Arrays;
-import java.util.List;
 import java.util.Objects;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.request.Literal;
 import org.apache.pinot.common.utils.PinotDataType;
+import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.apache.pinot.spi.utils.BigDecimalUtils;
 import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
@@ -105,55 +105,31 @@ public class LiteralContext {
         break;
       case INT_ARRAY_VALUE: {
         _type = DataType.INT;
-        List<Integer> valueList = literal.getIntArrayValue();
-        int numValues = valueList.size();
-        int[] values = new int[numValues];
-        for (int i = 0; i < numValues; i++) {
-          values[i] = valueList.get(i);
-        }
-        _value = values;
+        _value = RequestUtils.getIntArrayValue(literal);
         _pinotDataType = PinotDataType.PRIMITIVE_INT_ARRAY;
         break;
       }
       case LONG_ARRAY_VALUE: {
         _type = DataType.LONG;
-        List<Long> valueList = literal.getLongArrayValue();
-        int numValues = valueList.size();
-        long[] values = new long[numValues];
-        for (int i = 0; i < numValues; i++) {
-          values[i] = valueList.get(i);
-        }
-        _value = values;
+        _value = RequestUtils.getLongArrayValue(literal);
         _pinotDataType = PinotDataType.PRIMITIVE_LONG_ARRAY;
         break;
       }
       case FLOAT_ARRAY_VALUE: {
         _type = DataType.FLOAT;
-        List<Integer> valueList = literal.getFloatArrayValue();
-        int numValues = valueList.size();
-        float[] values = new float[numValues];
-        for (int i = 0; i < numValues; i++) {
-          values[i] = Float.intBitsToFloat(valueList.get(i));
-        }
-        _value = values;
+        _value = RequestUtils.getFloatArrayValue(literal);
         _pinotDataType = PinotDataType.PRIMITIVE_FLOAT_ARRAY;
         break;
       }
       case DOUBLE_ARRAY_VALUE: {
         _type = DataType.DOUBLE;
-        List<Double> valueList = literal.getDoubleArrayValue();
-        int numValues = valueList.size();
-        double[] values = new double[numValues];
-        for (int i = 0; i < numValues; i++) {
-          values[i] = valueList.get(i);
-        }
-        _value = values;
+        _value = RequestUtils.getDoubleArrayValue(literal);
         _pinotDataType = PinotDataType.PRIMITIVE_DOUBLE_ARRAY;
         break;
       }
       case STRING_ARRAY_VALUE:
         _type = DataType.STRING;
-        _value = literal.getStringArrayValue().toArray(new String[0]);
+        _value = RequestUtils.getStringArrayValue(literal);
         _pinotDataType = PinotDataType.STRING_ARRAY;
         break;
       default:
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
index e8feaeeb07..6147b7f7ea 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
@@ -41,6 +41,7 @@ import javax.annotation.Nullable;
 import org.apache.calcite.sql.SqlLiteral;
 import org.apache.calcite.sql.SqlNumericLiteral;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.DataSource;
 import org.apache.pinot.common.request.Expression;
 import org.apache.pinot.common.request.ExpressionType;
@@ -48,6 +49,7 @@ import org.apache.pinot.common.request.Function;
 import org.apache.pinot.common.request.Identifier;
 import org.apache.pinot.common.request.Literal;
 import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.spi.utils.BigDecimalUtils;
 import org.apache.pinot.spi.utils.BytesUtils;
 import org.apache.pinot.spi.utils.CommonConstants.Broker.Request;
@@ -343,21 +345,95 @@ public class RequestUtils {
       case BINARY_VALUE:
         return literal.getBinaryValue();
       case INT_ARRAY_VALUE:
-        return 
literal.getIntArrayValue().stream().mapToInt(Integer::intValue).toArray();
+        return getIntArrayValue(literal);
       case LONG_ARRAY_VALUE:
-        return 
literal.getLongArrayValue().stream().mapToLong(Long::longValue).toArray();
+        return getLongArrayValue(literal);
       case FLOAT_ARRAY_VALUE:
-        List<Integer> floatList = literal.getFloatArrayValue();
-        int numFloats = floatList.size();
-        float[] floatArray = new float[numFloats];
-        for (int i = 0; i < numFloats; i++) {
-          floatArray[i] = Float.intBitsToFloat(floatList.get(i));
-        }
-        return floatArray;
+        return getFloatArrayValue(literal);
+      case DOUBLE_ARRAY_VALUE:
+        return getDoubleArrayValue(literal);
+      case STRING_ARRAY_VALUE:
+        return getStringArrayValue(literal);
+      default:
+        throw new IllegalStateException("Unsupported field type: " + type);
+    }
+  }
+
+  public static int[] getIntArrayValue(Literal literal) {
+    List<Integer> list = literal.getIntArrayValue();
+    int size = list.size();
+    int[] array = new int[size];
+    for (int i = 0; i < size; i++) {
+      array[i] = list.get(i);
+    }
+    return array;
+  }
+
+  public static long[] getLongArrayValue(Literal literal) {
+    List<Long> list = literal.getLongArrayValue();
+    int size = list.size();
+    long[] array = new long[size];
+    for (int i = 0; i < size; i++) {
+      array[i] = list.get(i);
+    }
+    return array;
+  }
+
+  public static float[] getFloatArrayValue(Literal literal) {
+    List<Integer> list = literal.getFloatArrayValue();
+    int size = list.size();
+    float[] array = new float[size];
+    for (int i = 0; i < size; i++) {
+      array[i] = Float.intBitsToFloat(list.get(i));
+    }
+    return array;
+  }
+
+  public static double[] getDoubleArrayValue(Literal literal) {
+    List<Double> list = literal.getDoubleArrayValue();
+    int size = list.size();
+    double[] array = new double[size];
+    for (int i = 0; i < size; i++) {
+      array[i] = list.get(i);
+    }
+    return array;
+  }
+
+  public static String[] getStringArrayValue(Literal literal) {
+    return literal.getStringArrayValue().toArray(new String[0]);
+  }
+
+  public static Pair<ColumnDataType, Object> getLiteralTypeAndValue(Literal 
literal) {
+    Literal._Fields type = literal.getSetField();
+    switch (type) {
+      case NULL_VALUE:
+        return Pair.of(ColumnDataType.UNKNOWN, null);
+      case BOOL_VALUE:
+        return Pair.of(ColumnDataType.BOOLEAN, literal.getBoolValue());
+      case INT_VALUE:
+        return Pair.of(ColumnDataType.INT, literal.getIntValue());
+      case LONG_VALUE:
+        return Pair.of(ColumnDataType.LONG, literal.getLongValue());
+      case FLOAT_VALUE:
+        return Pair.of(ColumnDataType.FLOAT, 
Float.intBitsToFloat(literal.getFloatValue()));
+      case DOUBLE_VALUE:
+        return Pair.of(ColumnDataType.DOUBLE, literal.getDoubleValue());
+      case BIG_DECIMAL_VALUE:
+        return Pair.of(ColumnDataType.BIG_DECIMAL, 
BigDecimalUtils.deserialize(literal.getBigDecimalValue()));
+      case STRING_VALUE:
+        return Pair.of(ColumnDataType.STRING, literal.getStringValue());
+      case BINARY_VALUE:
+        return Pair.of(ColumnDataType.BYTES, literal.getBinaryValue());
+      case INT_ARRAY_VALUE:
+        return Pair.of(ColumnDataType.INT_ARRAY, getIntArrayValue(literal));
+      case LONG_ARRAY_VALUE:
+        return Pair.of(ColumnDataType.LONG_ARRAY, getLongArrayValue(literal));
+      case FLOAT_ARRAY_VALUE:
+        return Pair.of(ColumnDataType.FLOAT_ARRAY, 
getFloatArrayValue(literal));
       case DOUBLE_ARRAY_VALUE:
-        return 
literal.getDoubleArrayValue().stream().mapToDouble(Double::doubleValue).toArray();
+        return Pair.of(ColumnDataType.DOUBLE_ARRAY, 
getDoubleArrayValue(literal));
       case STRING_ARRAY_VALUE:
-        return literal.getStringArrayValue().toArray(new String[0]);
+        return Pair.of(ColumnDataType.STRING_ARRAY, 
getStringArrayValue(literal));
       default:
         throw new IllegalStateException("Unsupported field type: " + type);
     }
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
index 6a47fa827e..1e10fbed52 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java
@@ -18,20 +18,25 @@
  */
 package org.apache.pinot.sql.parsers.rewriter;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.util.Arrays;
 import java.util.List;
 import javax.annotation.Nullable;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.function.FunctionInvoker;
 import org.apache.pinot.common.function.FunctionRegistry;
 import org.apache.pinot.common.request.Expression;
 import org.apache.pinot.common.request.Function;
+import org.apache.pinot.common.request.Literal;
 import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.sql.parsers.SqlCompilationException;
 
 
 public class CompileTimeFunctionsInvoker implements QueryRewriter {
+
   @Override
   public PinotQuery rewrite(PinotQuery pinotQuery) {
     for (int i = 0; i < pinotQuery.getSelectListSize(); i++) {
@@ -53,7 +58,8 @@ public class CompileTimeFunctionsInvoker implements 
QueryRewriter {
     return pinotQuery;
   }
 
-  protected static Expression invokeCompileTimeFunctionExpression(@Nullable 
Expression expression) {
+  @VisibleForTesting
+  public static Expression invokeCompileTimeFunctionExpression(@Nullable 
Expression expression) {
     if (expression == null || expression.getFunctionCall() == null) {
       return expression;
     }
@@ -61,38 +67,44 @@ public class CompileTimeFunctionsInvoker implements 
QueryRewriter {
     List<Expression> operands = function.getOperands();
     int numOperands = operands.size();
     boolean compilable = true;
+    ColumnDataType[] argumentTypes = new ColumnDataType[numOperands];
+    Object[] arguments = new Object[numOperands];
     for (int i = 0; i < numOperands; i++) {
       Expression operand = 
invokeCompileTimeFunctionExpression(operands.get(i));
-      if (operand.getLiteral() == null) {
+      operands.set(i, operand);
+      Literal literal = operand.getLiteral();
+      if (compilable && literal != null) {
+        Pair<ColumnDataType, Object> typeAndValue = 
RequestUtils.getLiteralTypeAndValue(literal);
+        argumentTypes[i] = typeAndValue.getLeft();
+        arguments[i] = typeAndValue.getRight();
+      } else {
+        // NOTE: Do not directly 'return expression;' here because we want to 
compile all operands even if the current
+        //       expression is not compilable.
         compilable = false;
       }
-      operands.set(i, operand);
     }
-    if (compilable) {
-      String canonicalName = 
FunctionRegistry.canonicalize(function.getOperator());
-      FunctionInfo functionInfo = 
FunctionRegistry.lookupFunctionInfo(canonicalName, numOperands);
-      if (functionInfo != null) {
-        Object[] arguments = new Object[numOperands];
-        for (int i = 0; i < numOperands; i++) {
-          arguments[i] = 
RequestUtils.getLiteralValue(function.getOperands().get(i).getLiteral());
-        }
-        try {
-          FunctionInvoker invoker = new FunctionInvoker(functionInfo);
-          Object result;
-          if (invoker.getMethod().isVarArgs()) {
-            result = invoker.invoke(new Object[] {arguments});
-          } else {
-            invoker.convertTypes(arguments);
-            result = invoker.invoke(arguments);
-          }
-          return RequestUtils.getLiteralExpression(result);
-        } catch (Exception e) {
-          throw new SqlCompilationException(
-              "Caught exception while invoking method: " + 
functionInfo.getMethod() + " with arguments: "
-                  + Arrays.toString(arguments), e);
-        }
+    if (!compilable) {
+      return expression;
+    }
+    String canonicalName = 
FunctionRegistry.canonicalize(function.getOperator());
+    FunctionInfo functionInfo = 
FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes);
+    if (functionInfo == null) {
+      return expression;
+    }
+    try {
+      FunctionInvoker invoker = new FunctionInvoker(functionInfo);
+      Object result;
+      if (invoker.getMethod().isVarArgs()) {
+        result = invoker.invoke(new Object[]{arguments});
+      } else {
+        invoker.convertTypes(arguments);
+        result = invoker.invoke(arguments);
       }
+      return RequestUtils.getLiteralExpression(result);
+    } catch (Exception e) {
+      throw new SqlCompilationException(
+          "Caught exception while invoking method: " + 
functionInfo.getMethod() + " with arguments: " + Arrays.toString(
+              arguments), e);
     }
-    return expression;
   }
 }
diff --git 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 369dd8b886..35a625505a 100644
--- 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -1054,8 +1054,7 @@ public class CalciteSqlCompilerTest {
         
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(2)
             .getLiteral().getStringValue(), "SECONDS");
     Assert.assertEquals(
-        
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getIntValue(),
-        1394323200);
+        
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getIntValue(),
 1394323200);
   }
 
   @Test
@@ -1379,8 +1378,8 @@ public class CalciteSqlCompilerTest {
       Assert.fail("Query should have failed compilation");
     } catch (Exception e) {
       Assert.assertTrue(e instanceof SqlCompilationException);
-      Assert.assertTrue(e.getMessage().contains("'group_city' should be 
functionally dependent on the columns "
-          + "used in GROUP BY clause."));
+      Assert.assertTrue(e.getMessage()
+          .contains("'group_city' should be functionally dependent on the 
columns " + "used in GROUP BY clause."));
     }
 
     // Valid groupBy non-aggregate function should pass.
@@ -1398,8 +1397,8 @@ public class CalciteSqlCompilerTest {
       Assert.fail("Query should have failed compilation");
     } catch (Exception e) {
       Assert.assertTrue(e instanceof SqlCompilationException);
-      Assert.assertTrue(e.getMessage().contains("'secondsSinceEpoch' should be 
functionally dependent on the columns "
-          + "used in GROUP BY clause."));
+      Assert.assertTrue(e.getMessage().contains(
+          "'secondsSinceEpoch' should be functionally dependent on the columns 
" + "used in GROUP BY clause."));
     }
 
     // Invalid groupBy clause shouldn't contain aggregate expression, like 
sum(rsvp_count), count(*).
@@ -2331,14 +2330,10 @@ public class CalciteSqlCompilerTest {
 
   @Test
   public void testCompileTimeExpression() {
-    final CompileTimeFunctionsInvoker compileTimeFunctionsInvoker = new 
CompileTimeFunctionsInvoker();
     long lowerBound = System.currentTimeMillis();
     Expression expression = compileToExpression("now()");
     Assert.assertNotNull(expression.getFunctionCall());
-    PinotQuery pinotQuery = new PinotQuery();
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     long upperBound = System.currentTimeMillis();
     long result = expression.getLiteral().getLongValue();
@@ -2347,10 +2342,7 @@ public class CalciteSqlCompilerTest {
     lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
     expression = compileToExpression("to_epoch_hours(now() + 3600000)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
-    Assert.assertNotNull(expression.getLiteral());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     upperBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
     result = expression.getLiteral().getLongValue();
     Assert.assertTrue(result >= lowerBound && result <= upperBound);
@@ -2358,9 +2350,7 @@ public class CalciteSqlCompilerTest {
     lowerBound = System.currentTimeMillis() - ONE_HOUR_IN_MS;
     expression = compileToExpression("ago('PT1H')");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     upperBound = System.currentTimeMillis() - ONE_HOUR_IN_MS;
     result = expression.getLiteral().getLongValue();
@@ -2369,9 +2359,7 @@ public class CalciteSqlCompilerTest {
     lowerBound = System.currentTimeMillis() + ONE_HOUR_IN_MS;
     expression = compileToExpression("ago('PT-1H')");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     upperBound = System.currentTimeMillis() + ONE_HOUR_IN_MS;
     result = expression.getLiteral().getLongValue();
@@ -2379,9 +2367,7 @@ public class CalciteSqlCompilerTest {
 
     expression = compileToExpression("toDateTime(millisSinceEpoch)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getFunctionCall());
     Assert.assertEquals(expression.getFunctionCall().getOperator(), 
"todatetime");
     
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
@@ -2389,88 +2375,105 @@ public class CalciteSqlCompilerTest {
 
     expression = compileToExpression("encodeUrl('key1=value 
1&key2=value@!$2&key3=value%3')");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     Assert.assertEquals(expression.getLiteral().getStringValue(),
         "key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253");
 
     expression = 
compileToExpression("decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253')");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     Assert.assertEquals(expression.getLiteral().getStringValue(), "key1=value 
1&key2=value@!$2&key3=value%3");
 
     expression = compileToExpression("reverse(playerName)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getFunctionCall());
     Assert.assertEquals(expression.getFunctionCall().getOperator(), "reverse");
     
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 "playerName");
 
     expression = compileToExpression("reverse('playerName')");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     Assert.assertEquals(expression.getLiteral().getStringValue(), 
"emaNreyalp");
 
     expression = compileToExpression("reverse(123)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     Assert.assertEquals(expression.getLiteral().getStringValue(), "321");
 
     expression = compileToExpression("count(*)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getFunctionCall());
     Assert.assertEquals(expression.getFunctionCall().getOperator(), "count");
     
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 "*");
 
     expression = compileToExpression("toBase64(toUtf8('hello!'))");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     Assert.assertEquals(expression.getLiteral().getStringValue(), "aGVsbG8h");
 
     expression = compileToExpression("fromUtf8(fromBase64('aGVsbG8h'))");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getLiteral());
     Assert.assertEquals(expression.getLiteral().getStringValue(), "hello!");
 
     expression = compileToExpression("fromBase64(foo)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getFunctionCall());
     Assert.assertEquals(expression.getFunctionCall().getOperator(), 
"frombase64");
     
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 "foo");
 
     expression = compileToExpression("toBase64(foo)");
     Assert.assertNotNull(expression.getFunctionCall());
-    pinotQuery.setFilterExpression(expression);
-    pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery);
-    expression = pinotQuery.getFilterExpression();
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
     Assert.assertNotNull(expression.getFunctionCall());
     Assert.assertEquals(expression.getFunctionCall().getOperator(), 
"tobase64");
     
Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 "foo");
+
+    expression = compileToExpression("'foo' > 'bar'");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    Assert.assertTrue(expression.getLiteral().getBoolValue());
+
+    expression = compileToExpression("toBase64(toUtf8('hello!')) = 
'aGVsbG8h'");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    Assert.assertTrue(expression.getLiteral().getBoolValue());
+
+    expression = compileToExpression("fromUtf8(fromBase64('aGVsbG8h')) != 
'hello!'");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    Assert.assertFalse(expression.getLiteral().getBoolValue());
+
+    expression = compileToExpression("123 < 123.000000000000000000001");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    Assert.assertFalse(expression.getLiteral().getBoolValue());
+
+    expression = compileToExpression("cast('123' as big_decimal) < 
cast('123.000000000000000000001' as big_decimal)");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    Assert.assertTrue(expression.getLiteral().getBoolValue());
+
+    // Should fall back to DOUBLE comparison
+    expression = compileToExpression("123 < cast('123.000000000000000000001' 
as big_decimal)");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    Assert.assertFalse(expression.getLiteral().getBoolValue());
   }
 
   @Test
@@ -2599,19 +2602,14 @@ public class CalciteSqlCompilerTest {
     String query = "SELECT col1 FROM foo GROUP BY col1, col2";
     PinotQuery pinotQuery = compileToPinotQuery(query);
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getIdentifier().getName(), "col1");
-    Assert.assertEquals(
-        pinotQuery.getGroupByList().get(0).getIdentifier().getName(), "col1");
-    Assert.assertEquals(
-        pinotQuery.getGroupByList().get(1).getIdentifier().getName(), "col2");
+    
Assert.assertEquals(pinotQuery.getSelectList().get(0).getIdentifier().getName(),
 "col1");
+    
Assert.assertEquals(pinotQuery.getGroupByList().get(0).getIdentifier().getName(),
 "col1");
+    
Assert.assertEquals(pinotQuery.getGroupByList().get(1).getIdentifier().getName(),
 "col2");
 
     query = "SELECT col1+col2 FROM foo GROUP BY col1,col2";
     pinotQuery = compileToPinotQuery(query);
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(),
-        "plus");
+    
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(),
 "plus");
     Assert.assertEquals(
         
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 "col1");
     Assert.assertEquals(
@@ -3023,7 +3021,6 @@ public class CalciteSqlCompilerTest {
   public void testParserExtensionImpl() {
     String customSql = "INSERT INTO db.tbl FROM FILE 'file:///tmp/file1', FILE 
'file:///tmp/file2'";
     SqlNodeAndOptions sqlNodeAndOptions = 
CalciteSqlParser.compileToSqlNodeAndOptions(customSql);
-    ;
     Assert.assertTrue(sqlNodeAndOptions.getSqlNode() instanceof 
SqlInsertFromFile);
     Assert.assertEquals(sqlNodeAndOptions.getSqlType(), PinotSqlType.DML);
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to