This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 19b79f406c Polymorphic binary arithmetic scalar functions (#14089) 19b79f406c is described below commit 19b79f406c36e8b075e9c1698a8752af5b4e6d23 Author: Yash Mayya <yash.ma...@gmail.com> AuthorDate: Tue Oct 1 01:32:38 2024 +0530 Polymorphic binary arithmetic scalar functions (#14089) --- .../function/scalar/ArithmeticFunctions.java | 15 -- .../scalar/arithmetic/MinusScalarFunction.java | 66 +++++++++ .../scalar/arithmetic/MultScalarFunction.java | 66 +++++++++ .../scalar/arithmetic/PlusScalarFunction.java | 66 +++++++++ .../PolymorphicBinaryArithmeticScalarFunction.java | 67 +++++++++ .../scalar/comparison/EqualsScalarFunction.java | 4 +- .../GreaterThanOrEqualScalarFunction.java | 9 +- .../comparison/GreaterThanScalarFunction.java | 4 +- .../comparison/LessThanOrEqualScalarFunction.java | 4 +- .../scalar/comparison/LessThanScalarFunction.java | 4 +- .../scalar/comparison/NotEqualsScalarFunction.java | 4 +- .../pinot/sql/parsers/CalciteSqlCompilerTest.java | 24 ++++ .../PostAggregationFunctionTest.java | 4 +- .../tests/OfflineClusterIntegrationTest.java | 151 +++++++++++++-------- .../pinot/calcite/sql/fun/PinotOperatorTable.java | 5 +- .../resources/queries/LiteralEvaluationPlans.json | 4 +- .../ExpressionTransformerTest.java | 2 +- 17 files changed, 406 insertions(+), 93 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java index 94489c92b1..27c4952b1f 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java @@ -30,21 +30,6 @@ public class ArithmeticFunctions { private ArithmeticFunctions() { } - @ScalarFunction(names = {"add", "plus"}) - public static double plus(double a, double b) { - return a + b; - } - - @ScalarFunction(names = {"sub", "minus"}) - public static double minus(double a, double b) { - return a - b; - } - - @ScalarFunction(names = {"mult", "times"}) - public static double times(double a, double b) { - return a * b; - } - @ScalarFunction(names = {"div", "divide"}) public static double divide(double a, double b) { return a / b; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java new file mode 100644 index 0000000000..61488e58e7 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.arithmetic; + +import java.util.EnumMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +@ScalarFunction(names = {"sub", "minus"}) +public class MinusScalarFunction extends PolymorphicBinaryArithmeticScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); + + static { + try { + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, + new FunctionInfo(MinusScalarFunction.class.getMethod("longMinus", long.class, long.class), + MinusScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, + new FunctionInfo(MinusScalarFunction.class.getMethod("doubleMinus", double.class, double.class), + MinusScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType); + + // Fall back to double based comparison by default + return functionInfo != null ? functionInfo : TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE); + } + + @Override + public String getName() { + return "minus"; + } + + public static long longMinus(long a, long b) { + return a - b; + } + + public static double doubleMinus(double a, double b) { + return a - b; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java new file mode 100644 index 0000000000..a737045393 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.arithmetic; + +import java.util.EnumMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +@ScalarFunction(names = {"mult", "times"}) +public class MultScalarFunction extends PolymorphicBinaryArithmeticScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); + + static { + try { + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, + new FunctionInfo(MultScalarFunction.class.getMethod("longMult", long.class, long.class), + MultScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, + new FunctionInfo(MultScalarFunction.class.getMethod("doubleMult", double.class, double.class), + MultScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType); + + // Fall back to double based comparison by default + return functionInfo != null ? functionInfo : TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE); + } + + @Override + public String getName() { + return "mult"; + } + + public static long longMult(long a, long b) { + return a * b; + } + + public static double doubleMult(double a, double b) { + return a * b; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java new file mode 100644 index 0000000000..5951afa527 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.arithmetic; + +import java.util.EnumMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +@ScalarFunction(names = {"add", "plus"}) +public class PlusScalarFunction extends PolymorphicBinaryArithmeticScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); + + static { + try { + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, + new FunctionInfo(PlusScalarFunction.class.getMethod("longPlus", long.class, long.class), + PlusScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, + new FunctionInfo(PlusScalarFunction.class.getMethod("doublePlus", double.class, double.class), + PlusScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType); + + // Fall back to double based comparison by default + return functionInfo != null ? functionInfo : TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE); + } + + @Override + public String getName() { + return "plus"; + } + + public static long longPlus(long a, long b) { + return a + b; + } + + public static double doublePlus(double a, double b) { + return a + b; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java new file mode 100644 index 0000000000..10167161f9 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.arithmetic; + +import javax.annotation.Nullable; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.function.PinotScalarFunction; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; + + +/** + * Base class for polymorphic binary arithmetic scalar functions + */ +public abstract class PolymorphicBinaryArithmeticScalarFunction implements PinotScalarFunction { + + @Nullable + @Override + public FunctionInfo getFunctionInfo(ColumnDataType[] argumentTypes) { + if (argumentTypes.length != 2) { + return null; + } + + return functionInfoForTypes(argumentTypes[0].getStoredType(), argumentTypes[1].getStoredType()); + } + + @Nullable + @Override + public FunctionInfo getFunctionInfo(int numArguments) { + if (numArguments != 2) { + return null; + } + + // For backward compatibility + return functionInfoForType(ColumnDataType.DOUBLE); + } + + private FunctionInfo functionInfoForTypes(ColumnDataType argumentType1, ColumnDataType argumentType2) { + if ((argumentType1 == ColumnDataType.LONG || argumentType1 == ColumnDataType.INT) && ( + argumentType2 == ColumnDataType.LONG || argumentType2 == ColumnDataType.INT)) { + return functionInfoForType(ColumnDataType.LONG); + } + + // Fall back to double based comparison by default + return functionInfoForType(ColumnDataType.DOUBLE); + } + + /** + * Get the binary arithmetic scalar function's {@link FunctionInfo} for the given argument type. + */ + protected abstract FunctionInfo functionInfoForType(ColumnDataType argumentType); +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java index 656722ccc8..0bc0fcb075 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java @@ -20,7 +20,7 @@ package org.apache.pinot.common.function.scalar.comparison; import java.math.BigDecimal; import java.util.Arrays; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import java.util.Objects; import org.apache.pinot.common.function.FunctionInfo; @@ -33,7 +33,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction; @ScalarFunction public class EqualsScalarFunction extends PolymorphicComparisonScalarFunction { - private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); private static final FunctionInfo DOUBLE_EQUALS_WITH_TOLERANCE; static { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java index cdf27b0f5e..d7782cf7e7 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java @@ -19,7 +19,7 @@ package org.apache.pinot.common.function.scalar.comparison; import java.math.BigDecimal; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction; @ScalarFunction public class GreaterThanOrEqualScalarFunction extends PolymorphicComparisonScalarFunction { - private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); static { try { @@ -51,9 +51,8 @@ public class GreaterThanOrEqualScalarFunction extends PolymorphicComparisonScala GreaterThanOrEqualScalarFunction.class.getMethod("doubleGreaterThanOrEqual", double.class, double.class), GreaterThanOrEqualScalarFunction.class, false)); TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo( - GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual", - BigDecimal.class, BigDecimal.class), - GreaterThanOrEqualScalarFunction.class, false)); + GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual", BigDecimal.class, + BigDecimal.class), GreaterThanOrEqualScalarFunction.class, false)); TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo( GreaterThanOrEqualScalarFunction.class.getMethod("stringGreaterThanOrEqual", String.class, String.class), GreaterThanOrEqualScalarFunction.class, false)); diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java index be8775f549..a41ddb6823 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java @@ -19,7 +19,7 @@ package org.apache.pinot.common.function.scalar.comparison; import java.math.BigDecimal; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction; @ScalarFunction public class GreaterThanScalarFunction extends PolymorphicComparisonScalarFunction { - private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); static { try { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java index 941c1a6d56..7ff076744e 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java @@ -19,7 +19,7 @@ package org.apache.pinot.common.function.scalar.comparison; import java.math.BigDecimal; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction; @ScalarFunction public class LessThanOrEqualScalarFunction extends PolymorphicComparisonScalarFunction { - private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); static { try { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java index e9d722370e..d2d85d9bbf 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java @@ -19,7 +19,7 @@ package org.apache.pinot.common.function.scalar.comparison; import java.math.BigDecimal; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction; @ScalarFunction public class LessThanScalarFunction extends PolymorphicComparisonScalarFunction { - private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); static { try { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java index 7f63a1eb9e..8344514646 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java @@ -20,7 +20,7 @@ package org.apache.pinot.common.function.scalar.comparison; import java.math.BigDecimal; import java.util.Arrays; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import java.util.Objects; import org.apache.pinot.common.function.FunctionInfo; @@ -33,7 +33,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction; @ScalarFunction public class NotEqualsScalarFunction extends PolymorphicComparisonScalarFunction { - private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class); private static final FunctionInfo DOUBLE_NOT_EQUALS_WITH_TOLERANCE; static { diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java index 35a625505a..34e2a6b5f5 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java @@ -2339,6 +2339,30 @@ public class CalciteSqlCompilerTest { long result = expression.getLiteral().getLongValue(); Assert.assertTrue(result >= lowerBound && result <= upperBound); + expression = compileToExpression("now() - 0"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + upperBound = System.currentTimeMillis(); + result = expression.getLiteral().getLongValue(); + Assert.assertTrue(result >= lowerBound && result <= upperBound); + + expression = compileToExpression("now() + 0"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + upperBound = System.currentTimeMillis(); + result = expression.getLiteral().getLongValue(); + Assert.assertTrue(result >= lowerBound && result <= upperBound); + + expression = compileToExpression("now() * 1"); + Assert.assertNotNull(expression.getFunctionCall()); + expression = CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression); + Assert.assertNotNull(expression.getLiteral()); + upperBound = System.currentTimeMillis(); + result = expression.getLiteral().getLongValue(); + Assert.assertTrue(result >= lowerBound && result <= upperBound); + lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1; expression = compileToExpression("to_epoch_hours(now() + 3600000)"); Assert.assertNotNull(expression.getFunctionCall()); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java index 0c7b0e3e52..6f4cd02a29 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java @@ -37,8 +37,8 @@ public class PostAggregationFunctionTest { // Plus PostAggregationFunction function = new PostAggregationFunction("plus", new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG}); - assertEquals(function.getResultType(), ColumnDataType.DOUBLE); - assertEquals(function.invoke(new Object[]{1, 2L}), 3.0); + assertEquals(function.getResultType(), ColumnDataType.LONG); + assertEquals(function.invoke(new Object[]{1, 2L}), 3L); // Minus function = new PostAggregationFunction("MINUS", new ColumnDataType[]{ColumnDataType.FLOAT, ColumnDataType.DOUBLE}); diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java index be438702bf..2bcfcabef1 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java @@ -2040,54 +2040,61 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet assertEquals(row.get(0).asLong(), 16138 * 24); assertEquals(row.get(1).asLong(), 605); - if (useMultiStageQueryEngine) { - query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), COUNT(*) FROM mytable " - + "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY COUNT(*) DESC"; - } else { - query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) FROM mytable " - + "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*) DESC"; - } + query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable " + + "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC"; response = postQuery(query); resultTable = response.get("resultTable"); dataSchema = resultTable.get("dataSchema"); - assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); rows = resultTable.get("rows"); assertFalse(rows.isEmpty()); row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15); - assertEquals(row.get(1).asLong(), 605); + assertEquals(row.get(0).asInt(), 5); + assertEquals(row.get(1).asLong(), 115545); - query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable " - + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC"; + query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable GROUP BY " + + "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC"; response = postQuery(query); resultTable = response.get("resultTable"); dataSchema = resultTable.get("dataSchema"); - assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); rows = resultTable.get("rows"); - assertFalse(rows.isEmpty()); + assertEquals(rows.size(), 3); row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asDouble(), 16138.0 - 25); - assertEquals(row.get(1).asLong(), 605); + assertEquals(row.get(0).asInt(), 0); + assertEquals(row.get(1).asLong(), 114895); + row = rows.get(1); + assertEquals(row.size(), 2); + assertEquals(row.get(0).asInt(), 1); + assertEquals(row.get(1).asLong(), 648); + row = rows.get(2); + assertEquals(row.size(), 2); + assertEquals(row.get(0).asInt(), 2); + assertEquals(row.get(1).asLong(), 2); - if (useMultiStageQueryEngine) { - query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM mytable " - + "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) DESC"; + if (useMultiStageQueryEngine()) { + query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable " + + "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC"; } else { - query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable " - + "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC"; + query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable " + + "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC"; } response = postQuery(query); resultTable = response.get("resultTable"); dataSchema = resultTable.get("dataSchema"); - assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"STRING\",\"LONG\"]"); rows = resultTable.get("rows"); - assertFalse(rows.isEmpty()); + assertEquals(rows.size(), 2); row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600); - assertEquals(row.get(1).asLong(), 605); + assertEquals(row.get(0).asText(), "ORD"); + assertEquals(row.get(1).asLong(), 336); + row = rows.get(1); + assertEquals(row.size(), 2); + assertEquals(row.get(0).asText(), "DFW"); + assertEquals(row.get(1).asLong(), 316); query = "SELECT div(DaysSinceEpoch,2), COUNT(*) FROM mytable " + "GROUP BY div(DaysSinceEpoch,2) ORDER BY COUNT(*) DESC"; @@ -2101,62 +2108,92 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet assertEquals(row.size(), 2); assertEquals(row.get(0).asDouble(), 16138.0 / 2); assertEquals(row.get(1).asLong(), 605); + } - query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable " - + "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC"; + @Test + public void testGroupByUDFV1() throws Exception { + setUseMultiStageQueryEngine(false); + String query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) FROM mytable " + + "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*) DESC"; + JsonNode response = postQuery(query); + JsonNode resultTable = response.get("resultTable"); + JsonNode dataSchema = resultTable.get("dataSchema"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]"); + JsonNode rows = resultTable.get("rows"); + assertFalse(rows.isEmpty()); + JsonNode row = rows.get(0); + assertEquals(row.size(), 2); + assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15); + assertEquals(row.get(1).asLong(), 605); + + query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable " + + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC"; response = postQuery(query); resultTable = response.get("resultTable"); dataSchema = resultTable.get("dataSchema"); - assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]"); rows = resultTable.get("rows"); assertFalse(rows.isEmpty()); row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asInt(), 5); - assertEquals(row.get(1).asLong(), 115545); + assertEquals(row.get(0).asDouble(), 16138.0 - 25); + assertEquals(row.get(1).asLong(), 605); - query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable GROUP BY " - + "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC"; + query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable " + + "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC"; response = postQuery(query); resultTable = response.get("resultTable"); dataSchema = resultTable.get("dataSchema"); - assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"DOUBLE\",\"LONG\"]"); rows = resultTable.get("rows"); - assertEquals(rows.size(), 3); + assertFalse(rows.isEmpty()); row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asInt(), 0); - assertEquals(row.get(1).asLong(), 114895); - row = rows.get(1); - assertEquals(row.size(), 2); - assertEquals(row.get(0).asInt(), 1); - assertEquals(row.get(1).asLong(), 648); - row = rows.get(2); + assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600); + assertEquals(row.get(1).asLong(), 605); + } + + @Test + public void testGroupByUDFV2() throws Exception { + setUseMultiStageQueryEngine(true); + String query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), COUNT(*) FROM mytable " + + "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY COUNT(*) DESC"; + JsonNode response = postQuery(query); + JsonNode resultTable = response.get("resultTable"); + JsonNode dataSchema = resultTable.get("dataSchema"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); + JsonNode rows = resultTable.get("rows"); + assertFalse(rows.isEmpty()); + JsonNode row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asInt(), 2); - assertEquals(row.get(1).asLong(), 2); + assertEquals(row.get(0).asInt(), 16138 + 16138 + 15); + assertEquals(row.get(1).asLong(), 605); - if (useMultiStageQueryEngine()) { - query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*) FROM mytable " - + "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) DESC"; - } else { - query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable " - + "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC"; - } + query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable " + + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC"; response = postQuery(query); resultTable = response.get("resultTable"); dataSchema = resultTable.get("dataSchema"); - assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"STRING\",\"LONG\"]"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); rows = resultTable.get("rows"); - assertEquals(rows.size(), 2); + assertFalse(rows.isEmpty()); row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asText(), "ORD"); - assertEquals(row.get(1).asLong(), 336); - row = rows.get(1); + assertEquals(row.get(0).asInt(), 16138 - 25); + assertEquals(row.get(1).asLong(), 605); + + query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM mytable " + + "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) DESC"; + response = postQuery(query); + resultTable = response.get("resultTable"); + dataSchema = resultTable.get("dataSchema"); + assertEquals(dataSchema.get("columnDataTypes").toString(), "[\"INT\",\"LONG\"]"); + rows = resultTable.get("rows"); + assertFalse(rows.isEmpty()); + row = rows.get(0); assertEquals(row.size(), 2); - assertEquals(row.get(0).asText(), "DFW"); - assertEquals(row.get(1).asLong(), 316); + assertEquals(row.get(0).asInt(), 16138 * 24 * 3600); + assertEquals(row.get(1).asLong(), 605); } @Test diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java index 5e282544d2..0c1a8d8a48 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java @@ -223,7 +223,10 @@ public class PinotOperatorTable implements SqlOperatorTable { Pair.of(SqlStdOperatorTable.GREATER_THAN, List.of("GREATER_THAN")), Pair.of(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, List.of("GREATER_THAN_OR_EQUAL")), Pair.of(SqlStdOperatorTable.LESS_THAN, List.of("LESS_THAN")), - Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, List.of("LESS_THAN_OR_EQUAL")) + Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, List.of("LESS_THAN_OR_EQUAL")), + Pair.of(SqlStdOperatorTable.MINUS, List.of("SUB", "MINUS")), + Pair.of(SqlStdOperatorTable.PLUS, List.of("ADD", "PLUS")), + Pair.of(SqlStdOperatorTable.MULTIPLY, List.of("MULT", "TIMES")) ); /** diff --git a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json index 6298709bf5..8e513b76fa 100644 --- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json +++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json @@ -15,7 +15,7 @@ "sql": "EXPLAIN PLAN FOR SELECT 5*6,5+6 FROM d", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[30.0], EXPR$1=[11.0])", + "\nLogicalProject(EXPR$0=[30], EXPR$1=[11])", "\n LogicalTableScan(table=[[default, d]])", "\n" ] @@ -175,7 +175,7 @@ "sql": "EXPLAIN PLAN FOR SELECT 1 + ToEpochDays(fromDateTime('1970-01-15', 'yyyy-MM-dd')) FROM a", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[15.0:BIGINT])", + "\nLogicalProject(EXPR$0=[15:BIGINT])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java index 55d8d7172f..58de9ec70c 100644 --- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java @@ -365,7 +365,7 @@ public class ExpressionTransformerTest { expressionTransformer.transform(genericRow); Assert.fail(); } catch (Exception e) { - Assert.assertEquals(e.getCause().getMessage(), "Caught exception while executing function: plus(x,'10')"); + Assert.assertTrue(e.getCause().getMessage().contains("Caught exception while executing function")); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org