This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new f8de958174 Use argument type to lookup function for literal only query (#13673) f8de958174 is described below commit f8de958174e209ab7f11572149ce1723d63b5af3 Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com> AuthorDate: Wed Aug 28 11:50:02 2024 -0700 Use argument type to lookup function for literal only query (#13673) --- .../BaseSingleStageBrokerRequestHandler.java | 120 ++++++-------------- .../common/request/context/LiteralContext.java | 36 +----- .../pinot/common/utils/request/RequestUtils.java | 98 ++++++++++++++-- .../rewriter/CompileTimeFunctionsInvoker.java | 66 ++++++----- .../pinot/sql/parsers/CalciteSqlCompilerTest.java | 125 ++++++++++----------- 5 files changed, 225 insertions(+), 220 deletions(-) diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java index 28eebf205b..83f83188dc 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java @@ -77,6 +77,7 @@ import org.apache.pinot.common.response.ProcessingException; import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.common.response.broker.ResultTable; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.DatabaseUtils; import org.apache.pinot.common.utils.config.QueryOptionsUtils; import org.apache.pinot.common.utils.request.RequestUtils; @@ -101,8 +102,6 @@ import org.apache.pinot.spi.exception.BadQueryRequestException; import org.apache.pinot.spi.exception.DatabaseConflictException; import org.apache.pinot.spi.trace.RequestContext; import org.apache.pinot.spi.trace.Tracing; -import org.apache.pinot.spi.utils.BigDecimalUtils; -import org.apache.pinot.spi.utils.BytesUtils; import org.apache.pinot.spi.utils.CommonConstants; import org.apache.pinot.spi.utils.CommonConstants.Broker; import org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey; @@ -1489,16 +1488,18 @@ public abstract class BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ private BrokerResponseNative processLiteralOnlyQuery(long requestId, PinotQuery pinotQuery, RequestContext requestContext) { BrokerResponseNative brokerResponse = new BrokerResponseNative(); - List<String> columnNames = new ArrayList<>(); - List<DataSchema.ColumnDataType> columnTypes = new ArrayList<>(); - List<Object> row = new ArrayList<>(); - for (Expression expression : pinotQuery.getSelectList()) { - computeResultsForExpression(expression, columnNames, columnTypes, row); - } - DataSchema dataSchema = - new DataSchema(columnNames.toArray(new String[0]), columnTypes.toArray(new DataSchema.ColumnDataType[0])); - List<Object[]> rows = new ArrayList<>(); - rows.add(row.toArray()); + List<Expression> selectList = pinotQuery.getSelectList(); + int numColumns = selectList.size(); + String[] columnNames = new String[numColumns]; + ColumnDataType[] columnTypes = new ColumnDataType[numColumns]; + Object[] values = new Object[numColumns]; + for (int i = 0; i < numColumns; i++) { + computeResultsForExpression(selectList.get(i), columnNames, columnTypes, values, i); + values[i] = columnTypes[i].format(values[i]); + } + DataSchema dataSchema = new DataSchema(columnNames, columnTypes); + List<Object[]> rows = new ArrayList<>(1); + rows.add(values); ResultTable resultTable = new ResultTable(dataSchema, rows); brokerResponse.setResultTable(resultTable); brokerResponse.setTimeUsedMs(System.currentTimeMillis() - requestContext.getRequestArrivalTimeMillis()); @@ -1510,87 +1511,30 @@ public abstract class BaseSingleStageBrokerRequestHandler extends BaseBrokerRequ } // TODO(xiangfu): Move Literal function computation here from Calcite Parser. - private void computeResultsForExpression(Expression e, List<String> columnNames, - List<DataSchema.ColumnDataType> columnTypes, List<Object> row) { - if (e.getType() == ExpressionType.LITERAL) { - computeResultsForLiteral(e.getLiteral(), columnNames, columnTypes, row); - } - if (e.getType() == ExpressionType.FUNCTION) { - if (e.getFunctionCall().getOperator().equals("as")) { - String columnName = e.getFunctionCall().getOperands().get(1).getIdentifier().getName(); - computeResultsForExpression(e.getFunctionCall().getOperands().get(0), columnNames, columnTypes, row); - columnNames.set(columnNames.size() - 1, columnName); + private void computeResultsForExpression(Expression expression, String[] columnNames, ColumnDataType[] columnTypes, + Object[] values, int index) { + ExpressionType type = expression.getType(); + if (type == ExpressionType.LITERAL) { + computeResultsForLiteral(expression.getLiteral(), columnNames, columnTypes, values, index); + } else if (type == ExpressionType.FUNCTION) { + Function function = expression.getFunctionCall(); + String operator = function.getOperator(); + if (operator.equals("as")) { + List<Expression> operands = function.getOperands(); + computeResultsForExpression(operands.get(0), columnNames, columnTypes, values, index); + columnNames[index] = operands.get(1).getIdentifier().getName(); } else { - throw new IllegalStateException( - "No able to compute results for function - " + e.getFunctionCall().getOperator()); + throw new IllegalStateException("No able to compute results for function - " + operator); } } } - private void computeResultsForLiteral(Literal literal, List<String> columnNames, - List<DataSchema.ColumnDataType> columnTypes, List<Object> row) { - columnNames.add(RequestUtils.prettyPrint(literal)); - switch (literal.getSetField()) { - case NULL_VALUE: - columnTypes.add(DataSchema.ColumnDataType.UNKNOWN); - row.add(null); - break; - case BOOL_VALUE: - columnTypes.add(DataSchema.ColumnDataType.BOOLEAN); - row.add(literal.getBoolValue()); - break; - case INT_VALUE: - columnTypes.add(DataSchema.ColumnDataType.INT); - row.add(literal.getIntValue()); - break; - case LONG_VALUE: - columnTypes.add(DataSchema.ColumnDataType.LONG); - row.add(literal.getLongValue()); - break; - case FLOAT_VALUE: - columnTypes.add(DataSchema.ColumnDataType.FLOAT); - row.add(Float.intBitsToFloat(literal.getFloatValue())); - break; - case DOUBLE_VALUE: - columnTypes.add(DataSchema.ColumnDataType.DOUBLE); - row.add(literal.getDoubleValue()); - break; - case BIG_DECIMAL_VALUE: - columnTypes.add(DataSchema.ColumnDataType.BIG_DECIMAL); - row.add(BigDecimalUtils.deserialize(literal.getBigDecimalValue())); - break; - case STRING_VALUE: - columnTypes.add(DataSchema.ColumnDataType.STRING); - row.add(literal.getStringValue()); - break; - case BINARY_VALUE: - columnTypes.add(DataSchema.ColumnDataType.BYTES); - row.add(BytesUtils.toHexString(literal.getBinaryValue())); - break; - // TODO: Revisit the array handling. Currently we are setting List into the row. - case INT_ARRAY_VALUE: - columnTypes.add(DataSchema.ColumnDataType.INT_ARRAY); - row.add(literal.getIntArrayValue()); - break; - case LONG_ARRAY_VALUE: - columnTypes.add(DataSchema.ColumnDataType.LONG_ARRAY); - row.add(literal.getLongArrayValue()); - break; - case FLOAT_ARRAY_VALUE: - columnTypes.add(DataSchema.ColumnDataType.FLOAT_ARRAY); - row.add(literal.getFloatArrayValue().stream().map(Float::intBitsToFloat).collect(Collectors.toList())); - break; - case DOUBLE_ARRAY_VALUE: - columnTypes.add(DataSchema.ColumnDataType.DOUBLE_ARRAY); - row.add(literal.getDoubleArrayValue()); - break; - case STRING_ARRAY_VALUE: - columnTypes.add(DataSchema.ColumnDataType.STRING_ARRAY); - row.add(literal.getStringArrayValue()); - break; - default: - throw new IllegalStateException("Unsupported literal: " + literal); - } + private void computeResultsForLiteral(Literal literal, String[] columnNames, ColumnDataType[] columnTypes, + Object[] values, int index) { + columnNames[index] = RequestUtils.prettyPrint(literal); + Pair<ColumnDataType, Object> typeAndValue = RequestUtils.getLiteralTypeAndValue(literal); + columnTypes[index] = typeAndValue.getLeft(); + values[index] = typeAndValue.getRight(); } /** diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java index 0a2b8ad6e1..c0a55f23f5 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java @@ -23,11 +23,11 @@ import com.google.common.base.Preconditions; import java.math.BigDecimal; import java.sql.Timestamp; import java.util.Arrays; -import java.util.List; import java.util.Objects; import javax.annotation.Nullable; import org.apache.pinot.common.request.Literal; import org.apache.pinot.common.utils.PinotDataType; +import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.BigDecimalUtils; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -105,55 +105,31 @@ public class LiteralContext { break; case INT_ARRAY_VALUE: { _type = DataType.INT; - List<Integer> valueList = literal.getIntArrayValue(); - int numValues = valueList.size(); - int[] values = new int[numValues]; - for (int i = 0; i < numValues; i++) { - values[i] = valueList.get(i); - } - _value = values; + _value = RequestUtils.getIntArrayValue(literal); _pinotDataType = PinotDataType.PRIMITIVE_INT_ARRAY; break; } case LONG_ARRAY_VALUE: { _type = DataType.LONG; - List<Long> valueList = literal.getLongArrayValue(); - int numValues = valueList.size(); - long[] values = new long[numValues]; - for (int i = 0; i < numValues; i++) { - values[i] = valueList.get(i); - } - _value = values; + _value = RequestUtils.getLongArrayValue(literal); _pinotDataType = PinotDataType.PRIMITIVE_LONG_ARRAY; break; } case FLOAT_ARRAY_VALUE: { _type = DataType.FLOAT; - List<Integer> valueList = literal.getFloatArrayValue(); - int numValues = valueList.size(); - float[] values = new float[numValues]; - for (int i = 0; i < numValues; i++) { - values[i] = Float.intBitsToFloat(valueList.get(i)); - } - _value = values; + _value = RequestUtils.getFloatArrayValue(literal); _pinotDataType = PinotDataType.PRIMITIVE_FLOAT_ARRAY; break; } case DOUBLE_ARRAY_VALUE: { _type = DataType.DOUBLE; - List<Double> valueList = literal.getDoubleArrayValue(); - int numValues = valueList.size(); - double[] values = new double[numValues]; - for (int i = 0; i < numValues; i++) { - values[i] = valueList.get(i); - } - _value = values; + _value = RequestUtils.getDoubleArrayValue(literal); _pinotDataType = PinotDataType.PRIMITIVE_DOUBLE_ARRAY; break; } case STRING_ARRAY_VALUE: _type = DataType.STRING; - _value = literal.getStringArrayValue().toArray(new String[0]); + _value = RequestUtils.getStringArrayValue(literal); _pinotDataType = PinotDataType.STRING_ARRAY; break; default: diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java index e8feaeeb07..6147b7f7ea 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java @@ -41,6 +41,7 @@ import javax.annotation.Nullable; import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNumericLiteral; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.DataSource; import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.ExpressionType; @@ -48,6 +49,7 @@ import org.apache.pinot.common.request.Function; import org.apache.pinot.common.request.Identifier; import org.apache.pinot.common.request.Literal; import org.apache.pinot.common.request.PinotQuery; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.spi.utils.BigDecimalUtils; import org.apache.pinot.spi.utils.BytesUtils; import org.apache.pinot.spi.utils.CommonConstants.Broker.Request; @@ -343,21 +345,95 @@ public class RequestUtils { case BINARY_VALUE: return literal.getBinaryValue(); case INT_ARRAY_VALUE: - return literal.getIntArrayValue().stream().mapToInt(Integer::intValue).toArray(); + return getIntArrayValue(literal); case LONG_ARRAY_VALUE: - return literal.getLongArrayValue().stream().mapToLong(Long::longValue).toArray(); + return getLongArrayValue(literal); case FLOAT_ARRAY_VALUE: - List<Integer> floatList = literal.getFloatArrayValue(); - int numFloats = floatList.size(); - float[] floatArray = new float[numFloats]; - for (int i = 0; i < numFloats; i++) { - floatArray[i] = Float.intBitsToFloat(floatList.get(i)); - } - return floatArray; + return getFloatArrayValue(literal); + case DOUBLE_ARRAY_VALUE: + return getDoubleArrayValue(literal); + case STRING_ARRAY_VALUE: + return getStringArrayValue(literal); + default: + throw new IllegalStateException("Unsupported field type: " + type); + } + } + + public static int[] getIntArrayValue(Literal literal) { + List<Integer> list = literal.getIntArrayValue(); + int size = list.size(); + int[] array = new int[size]; + for (int i = 0; i < size; i++) { + array[i] = list.get(i); + } + return array; + } + + public static long[] getLongArrayValue(Literal literal) { + List<Long> list = literal.getLongArrayValue(); + int size = list.size(); + long[] array = new long[size]; + for (int i = 0; i < size; i++) { + array[i] = list.get(i); + } + return array; + } + + public static float[] getFloatArrayValue(Literal literal) { + List<Integer> list = literal.getFloatArrayValue(); + int size = list.size(); + float[] array = new float[size]; + for (int i = 0; i < size; i++) { + array[i] = Float.intBitsToFloat(list.get(i)); + } + return array; + } + + public static double[] getDoubleArrayValue(Literal literal) { + List<Double> list = literal.getDoubleArrayValue(); + int size = list.size(); + double[] array = new double[size]; + for (int i = 0; i < size; i++) { + array[i] = list.get(i); + } + return array; + } + + public static String[] getStringArrayValue(Literal literal) { + return literal.getStringArrayValue().toArray(new String[0]); + } + + public static Pair<ColumnDataType, Object> getLiteralTypeAndValue(Literal literal) { + Literal._Fields type = literal.getSetField(); + switch (type) { + case NULL_VALUE: + return Pair.of(ColumnDataType.UNKNOWN, null); + case BOOL_VALUE: + return Pair.of(ColumnDataType.BOOLEAN, literal.getBoolValue()); + case INT_VALUE: + return Pair.of(ColumnDataType.INT, literal.getIntValue()); + case LONG_VALUE: + return Pair.of(ColumnDataType.LONG, literal.getLongValue()); + case FLOAT_VALUE: + return Pair.of(ColumnDataType.FLOAT, Float.intBitsToFloat(literal.getFloatValue())); + case DOUBLE_VALUE: + return Pair.of(ColumnDataType.DOUBLE, literal.getDoubleValue()); + case BIG_DECIMAL_VALUE: + return Pair.of(ColumnDataType.BIG_DECIMAL, BigDecimalUtils.deserialize(literal.getBigDecimalValue())); + case STRING_VALUE: + return Pair.of(ColumnDataType.STRING, literal.getStringValue()); + case BINARY_VALUE: + return Pair.of(ColumnDataType.BYTES, literal.getBinaryValue()); + case INT_ARRAY_VALUE: + return Pair.of(ColumnDataType.INT_ARRAY, getIntArrayValue(literal)); + case LONG_ARRAY_VALUE: + return Pair.of(ColumnDataType.LONG_ARRAY, getLongArrayValue(literal)); + case FLOAT_ARRAY_VALUE: + return Pair.of(ColumnDataType.FLOAT_ARRAY, getFloatArrayValue(literal)); case DOUBLE_ARRAY_VALUE: - return literal.getDoubleArrayValue().stream().mapToDouble(Double::doubleValue).toArray(); + return Pair.of(ColumnDataType.DOUBLE_ARRAY, getDoubleArrayValue(literal)); case STRING_ARRAY_VALUE: - return literal.getStringArrayValue().toArray(new String[0]); + return Pair.of(ColumnDataType.STRING_ARRAY, getStringArrayValue(literal)); default: throw new IllegalStateException("Unsupported field type: " + type); } diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java index 6a47fa827e..1e10fbed52 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java @@ -18,20 +18,25 @@ */ package org.apache.pinot.sql.parsers.rewriter; +import com.google.common.annotations.VisibleForTesting; import java.util.Arrays; import java.util.List; import javax.annotation.Nullable; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.function.FunctionInvoker; import org.apache.pinot.common.function.FunctionRegistry; import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.Function; +import org.apache.pinot.common.request.Literal; import org.apache.pinot.common.request.PinotQuery; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.sql.parsers.SqlCompilationException; public class CompileTimeFunctionsInvoker implements QueryRewriter { + @Override public PinotQuery rewrite(PinotQuery pinotQuery) { for (int i = 0; i < pinotQuery.getSelectListSize(); i++) { @@ -53,7 +58,8 @@ public class CompileTimeFunctionsInvoker implements QueryRewriter { return pinotQuery; } - protected static Expression invokeCompileTimeFunctionExpression(@Nullable Expression expression) { + @VisibleForTesting + public static Expression invokeCompileTimeFunctionExpression(@Nullable Expression expression) { if (expression == null || expression.getFunctionCall() == null) { return expression; } @@ -61,38 +67,44 @@ public class CompileTimeFunctionsInvoker implements QueryRewriter { List<Expression> operands = function.getOperands(); int numOperands = operands.size(); boolean compilable = true; + ColumnDataType[] argumentTypes = new ColumnDataType[numOperands]; + Object[] arguments = new Object[numOperands]; for (int i = 0; i < numOperands; i++) { Expression operand = invokeCompileTimeFunctionExpression(operands.get(i)); - if (operand.getLiteral() == null) { + operands.set(i, operand); + Literal literal = operand.getLiteral(); + if (compilable && literal != null) { + Pair<ColumnDataType, Object> typeAndValue = RequestUtils.getLiteralTypeAndValue(literal); + argumentTypes[i] = typeAndValue.getLeft(); + arguments[i] = typeAndValue.getRight(); + } else { + // NOTE: Do not directly 'return expression;' here because we want to compile all operands even if the current + // expression is not compilable. compilable = false; } - operands.set(i, operand); } - if (compilable) { - String canonicalName = FunctionRegistry.canonicalize(function.getOperator()); - FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, numOperands); - if (functionInfo != null) { - Object[] arguments = new Object[numOperands]; - for (int i = 0; i < numOperands; i++) { - arguments[i] = RequestUtils.getLiteralValue(function.getOperands().get(i).getLiteral()); - } - try { - FunctionInvoker invoker = new FunctionInvoker(functionInfo); - Object result; - if (invoker.getMethod().isVarArgs()) { - result = invoker.invoke(new Object[] {arguments}); - } else { - invoker.convertTypes(arguments); - result = invoker.invoke(arguments); - } - return RequestUtils.getLiteralExpression(result); - } catch (Exception e) { - throw new SqlCompilationException( - "Caught exception while invoking method: " + functionInfo.getMethod() + " with arguments: " - + Arrays.toString(arguments), e); - } + if (!compilable) { + return expression; + } + String canonicalName = FunctionRegistry.canonicalize(function.getOperator()); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes); + if (functionInfo == null) { + return expression; + } + try { + FunctionInvoker invoker = new FunctionInvoker(functionInfo); + Object result; + if (invoker.getMethod().isVarArgs()) { + result = invoker.invoke(new Object[]{arguments}); + } else { + invoker.convertTypes(arguments); + result = invoker.invoke(arguments); } + return RequestUtils.getLiteralExpression(result); + } catch (Exception e) { + throw new SqlCompilationException( + "Caught exception while invoking method: " + functionInfo.getMethod() + " with arguments: " + Arrays.toString( + arguments), e); } - return expression; } } diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java index 369dd8b886..35a625505a 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java @@ -1054,8 +1054,7 @@ public class CalciteSqlCompilerTest { pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(2) .getLiteral().getStringValue(), "SECONDS"); Assert.assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getIntValue(), - 1394323200); + pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getIntValue(), 1394323200); } @Test @@ -1379,8 +1378,8 @@ public class CalciteSqlCompilerTest { Assert.fail("Query should have failed compilation"); } catch (Exception e) { Assert.assertTrue(e instanceof SqlCompilationException); - Assert.assertTrue(e.getMessage().contains("'group_city' should be functionally dependent on the columns " - + "used in GROUP BY clause.")); + Assert.assertTrue(e.getMessage() + .contains("'group_city' should be functionally dependent on the columns " + "used in GROUP BY clause.")); } // Valid groupBy non-aggregate function should pass. @@ -1398,8 +1397,8 @@ public class CalciteSqlCompilerTest { Assert.fail("Query should have failed compilation"); } catch (Exception e) { Assert.assertTrue(e instanceof SqlCompilationException); - Assert.assertTrue(e.getMessage().contains("'secondsSinceEpoch' should be functionally dependent on the columns " - + "used in GROUP BY clause.")); + Assert.assertTrue(e.getMessage().contains( + "'secondsSinceEpoch' should be functionally dependent on the columns " + "used in GROUP BY clause.")); } // Invalid groupBy clause shouldn't contain aggregate expression, like sum(rsvp_count), count(*). @@ -2331,14 +2330,10 @@ public class CalciteSqlCompilerTest { @Test public void testCompileTimeExpression() { - final CompileTimeFunctionsInvoker compileTimeFunctionsInvoker = new CompileTimeFunctionsInvoker(); long lowerBound = System.currentTimeMillis(); Expression expression = compileToExpression("now()"); Assert.assertNotNull(expression.getFunctionCall()); - PinotQuery pinotQuery = new PinotQuery(); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); long upperBound = System.currentTimeMillis(); long result = expression.getLiteral().getLongValue(); @@ -2347,10 +2342,7 @@ public class CalciteSqlCompilerTest { lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1; expression = compileToExpression("to_epoch_hours(now() + 3600000)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); - Assert.assertNotNull(expression.getLiteral()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); upperBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1; result = expression.getLiteral().getLongValue(); Assert.assertTrue(result >= lowerBound && result <= upperBound); @@ -2358,9 +2350,7 @@ public class CalciteSqlCompilerTest { lowerBound = System.currentTimeMillis() - ONE_HOUR_IN_MS; expression = compileToExpression("ago('PT1H')"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); upperBound = System.currentTimeMillis() - ONE_HOUR_IN_MS; result = expression.getLiteral().getLongValue(); @@ -2369,9 +2359,7 @@ public class CalciteSqlCompilerTest { lowerBound = System.currentTimeMillis() + ONE_HOUR_IN_MS; expression = compileToExpression("ago('PT-1H')"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); upperBound = System.currentTimeMillis() + ONE_HOUR_IN_MS; result = expression.getLiteral().getLongValue(); @@ -2379,9 +2367,7 @@ public class CalciteSqlCompilerTest { expression = compileToExpression("toDateTime(millisSinceEpoch)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getFunctionCall()); Assert.assertEquals(expression.getFunctionCall().getOperator(), "todatetime"); Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), @@ -2389,88 +2375,105 @@ public class CalciteSqlCompilerTest { expression = compileToExpression("encodeUrl('key1=value 1&key2=value@!$2&key3=value%3')"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); Assert.assertEquals(expression.getLiteral().getStringValue(), "key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253"); expression = compileToExpression("decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253')"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); Assert.assertEquals(expression.getLiteral().getStringValue(), "key1=value 1&key2=value@!$2&key3=value%3"); expression = compileToExpression("reverse(playerName)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getFunctionCall()); Assert.assertEquals(expression.getFunctionCall().getOperator(), "reverse"); Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "playerName"); expression = compileToExpression("reverse('playerName')"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); Assert.assertEquals(expression.getLiteral().getStringValue(), "emaNreyalp"); expression = compileToExpression("reverse(123)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); Assert.assertEquals(expression.getLiteral().getStringValue(), "321"); expression = compileToExpression("count(*)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getFunctionCall()); Assert.assertEquals(expression.getFunctionCall().getOperator(), "count"); Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "*"); expression = compileToExpression("toBase64(toUtf8('hello!'))"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); Assert.assertEquals(expression.getLiteral().getStringValue(), "aGVsbG8h"); expression = compileToExpression("fromUtf8(fromBase64('aGVsbG8h'))"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getLiteral()); Assert.assertEquals(expression.getLiteral().getStringValue(), "hello!"); expression = compileToExpression("fromBase64(foo)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getFunctionCall()); Assert.assertEquals(expression.getFunctionCall().getOperator(), "frombase64"); Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "foo"); expression = compileToExpression("toBase64(foo)"); Assert.assertNotNull(expression.getFunctionCall()); - pinotQuery.setFilterExpression(expression); - pinotQuery = compileTimeFunctionsInvoker.rewrite(pinotQuery); - expression = pinotQuery.getFilterExpression(); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); Assert.assertNotNull(expression.getFunctionCall()); Assert.assertEquals(expression.getFunctionCall().getOperator(), "tobase64"); Assert.assertEquals(expression.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "foo"); + + expression = compileToExpression("'foo' > 'bar'"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + Assert.assertTrue(expression.getLiteral().getBoolValue()); + + expression = compileToExpression("toBase64(toUtf8('hello!')) = 'aGVsbG8h'"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + Assert.assertTrue(expression.getLiteral().getBoolValue()); + + expression = compileToExpression("fromUtf8(fromBase64('aGVsbG8h')) != 'hello!'"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + Assert.assertFalse(expression.getLiteral().getBoolValue()); + + expression = compileToExpression("123 < 123.000000000000000000001"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + Assert.assertFalse(expression.getLiteral().getBoolValue()); + + expression = compileToExpression("cast('123' as big_decimal) < cast('123.000000000000000000001' as big_decimal)"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + Assert.assertTrue(expression.getLiteral().getBoolValue()); + + // Should fall back to DOUBLE comparison + expression = compileToExpression("123 < cast('123.000000000000000000001' as big_decimal)"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + Assert.assertFalse(expression.getLiteral().getBoolValue()); } @Test @@ -2599,19 +2602,14 @@ public class CalciteSqlCompilerTest { String query = "SELECT col1 FROM foo GROUP BY col1, col2"; PinotQuery pinotQuery = compileToPinotQuery(query); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getIdentifier().getName(), "col1"); - Assert.assertEquals( - pinotQuery.getGroupByList().get(0).getIdentifier().getName(), "col1"); - Assert.assertEquals( - pinotQuery.getGroupByList().get(1).getIdentifier().getName(), "col2"); + Assert.assertEquals(pinotQuery.getSelectList().get(0).getIdentifier().getName(), "col1"); + Assert.assertEquals(pinotQuery.getGroupByList().get(0).getIdentifier().getName(), "col1"); + Assert.assertEquals(pinotQuery.getGroupByList().get(1).getIdentifier().getName(), "col2"); query = "SELECT col1+col2 FROM foo GROUP BY col1,col2"; pinotQuery = compileToPinotQuery(query); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), - "plus"); + Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "plus"); Assert.assertEquals( pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(), "col1"); Assert.assertEquals( @@ -3023,7 +3021,6 @@ public class CalciteSqlCompilerTest { public void testParserExtensionImpl() { String customSql = "INSERT INTO db.tbl FROM FILE 'file:///tmp/file1', FILE 'file:///tmp/file2'"; SqlNodeAndOptions sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(customSql); - ; Assert.assertTrue(sqlNodeAndOptions.getSqlNode() instanceof SqlInsertFromFile); Assert.assertEquals(sqlNodeAndOptions.getSqlType(), PinotSqlType.DML); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org