This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new d0e041c18e Support polymorphic scalar comparison functions in the multi-stage query engine (#13711) d0e041c18e is described below commit d0e041c18ec1442031d69082b4e74d6ee3e25c97 Author: Yash Mayya <yash.ma...@gmail.com> AuthorDate: Tue Aug 20 02:09:53 2024 +0530 Support polymorphic scalar comparison functions in the multi-stage query engine (#13711) --- .../pinot/common/function/FunctionRegistry.java | 2 +- .../pinot/common/function/PinotScalarFunction.java | 4 +- .../function/scalar/ComparisonFunctions.java | 33 ---- .../scalar/comparison/EqualsScalarFunction.java | 176 +++++++++++++++++++++ .../GreaterThanOrEqualScalarFunction.java | 98 ++++++++++++ .../comparison/GreaterThanScalarFunction.java | 97 ++++++++++++ .../comparison/LessThanOrEqualScalarFunction.java | 98 ++++++++++++ .../scalar/comparison/LessThanScalarFunction.java | 97 ++++++++++++ .../scalar/comparison/NotEqualsScalarFunction.java | 176 +++++++++++++++++++++ .../PolymorphicComparisonScalarFunction.java | 66 ++++++++ .../tests/MultiStageEngineIntegrationTest.java | 98 ++++++++++++ .../pinot/calcite/sql/fun/PinotOperatorTable.java | 8 +- 12 files changed, 917 insertions(+), 36 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index a9d6f639b2..3e0da5733d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -312,7 +312,7 @@ public class FunctionRegistry { private static SqlTypeFamily getSqlTypeFamily(Class<?> clazz) { // NOTE: Pinot allows some non-standard type conversions such as Timestamp <-> long, boolean <-> int etc. Do not // restrict the type family for now. We only restrict the type family for String so that cast can be added. 
- // Explicit cast is required to correctly convert boolean and Timestamp to String. Without explicit case, + // Explicit cast is required to correctly convert boolean and Timestamp to String. Without explicit cast, // BOOLEAN and TIMESTAMP type will be converted with their internal stored format which is INT and LONG // respectively. E.g. true will be converted to "1", timestamp will be converted to long value string. // TODO: Revisit this. diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java index 7a87935bd9..6a2ed5e626 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java @@ -40,7 +40,9 @@ public interface PinotScalarFunction { * doesn't need to be registered (e.g. standard SqlFunction). */ @Nullable - PinotSqlFunction toPinotSqlFunction(); + default PinotSqlFunction toPinotSqlFunction() { + return null; + } /** * Returns the {@link FunctionInfo} for the given argument types, or {@code null} if there is no matching. 
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java index 3a4eef70e8..00a829e09c 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java @@ -22,42 +22,9 @@ import org.apache.pinot.spi.annotations.ScalarFunction; public class ComparisonFunctions { - private static final double DOUBLE_COMPARISON_TOLERANCE = 1e-7d; - private ComparisonFunctions() { } - @ScalarFunction - public static boolean greaterThan(double a, double b) { - return a > b; - } - - @ScalarFunction - public static boolean greaterThanOrEqual(double a, double b) { - return a >= b; - } - - @ScalarFunction - public static boolean lessThan(double a, double b) { - return a < b; - } - - @ScalarFunction - public static boolean lessThanOrEqual(double a, double b) { - return a <= b; - } - - @ScalarFunction - public static boolean notEquals(double a, double b) { - return Math.abs(a - b) >= DOUBLE_COMPARISON_TOLERANCE; - } - - @ScalarFunction - public static boolean equals(double a, double b) { - // To avoid approximation errors - return Math.abs(a - b) < DOUBLE_COMPARISON_TOLERANCE; - } - @ScalarFunction public static boolean between(double val, double a, double b) { return val >= a && val <= b; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java new file mode 100644 index 0000000000..75fc48320d --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java @@ -0,0 +1,176 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.comparison; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + +/** + * Polymorphic equals (=) scalar function implementation + */ +@ScalarFunction +public class EqualsScalarFunction extends PolymorphicComparisonScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + private static final FunctionInfo DOUBLE_EQUALS_WITH_TOLERANCE; + + static { + try { + DOUBLE_EQUALS_WITH_TOLERANCE = new FunctionInfo( + EqualsScalarFunction.class.getMethod("doubleEqualsWithTolerance", double.class, double.class), + EqualsScalarFunction.class, false); + + // Set nullable parameters to false for each function because the return value should be null if any argument + // is null + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT, new FunctionInfo( + EqualsScalarFunction.class.getMethod("intEquals", int.class, int.class), + EqualsScalarFunction.class, false)); + 
TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, new FunctionInfo( + EqualsScalarFunction.class.getMethod("longEquals", long.class, long.class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT, new FunctionInfo( + EqualsScalarFunction.class.getMethod("floatEquals", float.class, float.class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, new FunctionInfo( + EqualsScalarFunction.class.getMethod("doubleEquals", double.class, double.class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo( + EqualsScalarFunction.class.getMethod("bigDecimalEquals", BigDecimal.class, BigDecimal.class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo( + EqualsScalarFunction.class.getMethod("stringEquals", String.class, String.class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BYTES, new FunctionInfo( + EqualsScalarFunction.class.getMethod("bytesEquals", byte[].class, byte[].class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.OBJECT, new FunctionInfo( + EqualsScalarFunction.class.getMethod("objectEquals", Object.class, Object.class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT_ARRAY, new FunctionInfo( + EqualsScalarFunction.class.getMethod("intArrayEquals", int[].class, int[].class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG_ARRAY, new FunctionInfo( + EqualsScalarFunction.class.getMethod("longArrayEquals", long[].class, long[].class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT_ARRAY, new FunctionInfo( + EqualsScalarFunction.class.getMethod("floatArrayEquals", float[].class, float[].class), + EqualsScalarFunction.class, false)); + 
TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE_ARRAY, new FunctionInfo( + EqualsScalarFunction.class.getMethod("doubleArrayEquals", double[].class, double[].class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING_ARRAY, new FunctionInfo( + EqualsScalarFunction.class.getMethod("stringArrayEquals", String[].class, String[].class), + EqualsScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BYTES_ARRAY, new FunctionInfo( + EqualsScalarFunction.class.getMethod("bytesArrayEquals", byte[][].class, byte[][].class), + EqualsScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + return TYPE_FUNCTION_INFO_MAP.get(argumentType); + } + + @Nullable + @Override + public FunctionInfo getFunctionInfo(int numArguments) { + if (numArguments != 2) { + return null; + } + + // For backward compatibility + return DOUBLE_EQUALS_WITH_TOLERANCE; + } + + @Override + public String getName() { + return "equals"; + } + + public static boolean intEquals(int a, int b) { + return a == b; + } + + public static boolean longEquals(long a, long b) { + return a == b; + } + + public static boolean floatEquals(float a, float b) { + return a == b; + } + + public static boolean doubleEquals(double a, double b) { + return a == b; + } + + public static boolean doubleEqualsWithTolerance(double a, double b) { + // To avoid approximation errors + return Math.abs(a - b) < DOUBLE_COMPARISON_TOLERANCE; + } + + public static boolean bigDecimalEquals(BigDecimal a, BigDecimal b) { + return a.compareTo(b) == 0; + } + + public static boolean stringEquals(String a, String b) { + return a.equals(b); + } + + public static boolean bytesEquals(byte[] a, byte[] b) { + return Arrays.equals(a, b); + } + + public static boolean objectEquals(Object a, Object b) { + return Objects.equals(a, b); + } + + public static 
boolean intArrayEquals(int[] a, int[] b) { + return Arrays.equals(a, b); + } + + public static boolean longArrayEquals(long[] a, long[] b) { + return Arrays.equals(a, b); + } + + public static boolean floatArrayEquals(float[] a, float[] b) { + return Arrays.equals(a, b); + } + + public static boolean doubleArrayEquals(double[] a, double[] b) { + return Arrays.equals(a, b); + } + + public static boolean stringArrayEquals(String[] a, String[] b) { + return Arrays.equals(a, b); + } + + public static boolean bytesArrayEquals(byte[][] a, byte[][]b) { + return Arrays.deepEquals(a, b); + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java new file mode 100644 index 0000000000..cdf27b0f5e --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.common.function.scalar.comparison; + +import java.math.BigDecimal; +import java.util.HashMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Polymorphic greaterThanOrEqual (>=) scalar function implementation + */ +@ScalarFunction +public class GreaterThanOrEqualScalarFunction extends PolymorphicComparisonScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + + static { + try { + // Set nullable parameters to false for each function because the return value should be null if any argument + // is null + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT, new FunctionInfo( + GreaterThanOrEqualScalarFunction.class.getMethod("intGreaterThanOrEqual", int.class, int.class), + GreaterThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, new FunctionInfo( + GreaterThanOrEqualScalarFunction.class.getMethod("longGreaterThanOrEqual", long.class, long.class), + GreaterThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT, new FunctionInfo( + GreaterThanOrEqualScalarFunction.class.getMethod("floatGreaterThanOrEqual", float.class, float.class), + GreaterThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, new FunctionInfo( + GreaterThanOrEqualScalarFunction.class.getMethod("doubleGreaterThanOrEqual", double.class, double.class), + GreaterThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo( + GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual", + BigDecimal.class, BigDecimal.class), + GreaterThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo( + 
GreaterThanOrEqualScalarFunction.class.getMethod("stringGreaterThanOrEqual", String.class, String.class), + GreaterThanOrEqualScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + return TYPE_FUNCTION_INFO_MAP.get(argumentType); + } + + @Override + public String getName() { + return "greaterThanOrEqual"; + } + + public static boolean intGreaterThanOrEqual(int a, int b) { + return a >= b; + } + + public static boolean longGreaterThanOrEqual(long a, long b) { + return a >= b; + } + + public static boolean floatGreaterThanOrEqual(float a, float b) { + return a >= b; + } + + public static boolean doubleGreaterThanOrEqual(double a, double b) { + return a >= b; + } + + public static boolean bigDecimalGreaterThanOrEqual(BigDecimal a, BigDecimal b) { + return a.compareTo(b) >= 0; + } + + public static boolean stringGreaterThanOrEqual(String a, String b) { + return a.compareTo(b) >= 0; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java new file mode 100644 index 0000000000..be8775f549 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java @@ -0,0 +1,97 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.comparison; + +import java.math.BigDecimal; +import java.util.HashMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Polymorphic greaterThan (>) scalar function implementation + */ +@ScalarFunction +public class GreaterThanScalarFunction extends PolymorphicComparisonScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + + static { + try { + // Set nullable parameters to false for each function because the return value should be null if any argument + // is null + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT, new FunctionInfo( + GreaterThanScalarFunction.class.getMethod("intGreaterThan", int.class, int.class), + GreaterThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, new FunctionInfo( + GreaterThanScalarFunction.class.getMethod("longGreaterThan", long.class, long.class), + GreaterThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT, new FunctionInfo( + GreaterThanScalarFunction.class.getMethod("floatGreaterThan", float.class, float.class), + GreaterThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, new FunctionInfo( + GreaterThanScalarFunction.class.getMethod("doubleGreaterThan", double.class, double.class), + GreaterThanScalarFunction.class, false)); + 
TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo( + GreaterThanScalarFunction.class.getMethod("bigDecimalGreaterThan", BigDecimal.class, BigDecimal.class), + GreaterThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo( + GreaterThanScalarFunction.class.getMethod("stringGreaterThan", String.class, String.class), + GreaterThanScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + return TYPE_FUNCTION_INFO_MAP.get(argumentType); + } + + @Override + public String getName() { + return "greaterThan"; + } + + public static boolean intGreaterThan(int a, int b) { + return a > b; + } + + public static boolean longGreaterThan(long a, long b) { + return a > b; + } + + public static boolean floatGreaterThan(float a, float b) { + return a > b; + } + + public static boolean doubleGreaterThan(double a, double b) { + return a > b; + } + + public static boolean bigDecimalGreaterThan(BigDecimal a, BigDecimal b) { + return a.compareTo(b) > 0; + } + + public static boolean stringGreaterThan(String a, String b) { + return a.compareTo(b) > 0; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java new file mode 100644 index 0000000000..941c1a6d56 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.comparison; + +import java.math.BigDecimal; +import java.util.HashMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Polymorphic lessThanOrEqual (<=) scalar function implementation + */ +@ScalarFunction +public class LessThanOrEqualScalarFunction extends PolymorphicComparisonScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + + static { + try { + // Set nullable parameters to false for each function because the return value should be null if any argument + // is null + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT, new FunctionInfo( + LessThanOrEqualScalarFunction.class.getMethod("intLessThanOrEqual", int.class, int.class), + LessThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, new FunctionInfo( + LessThanOrEqualScalarFunction.class.getMethod("longLessThanOrEqual", long.class, long.class), + LessThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT, new FunctionInfo( + LessThanOrEqualScalarFunction.class.getMethod("floatLessThanOrEqual", float.class, float.class), + LessThanOrEqualScalarFunction.class, false)); + 
TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, new FunctionInfo( + LessThanOrEqualScalarFunction.class.getMethod("doubleLessThanOrEqual", double.class, double.class), + LessThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo( + LessThanOrEqualScalarFunction.class.getMethod("bigDecimalLessThanOrEqual", + BigDecimal.class, BigDecimal.class), + LessThanOrEqualScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo( + LessThanOrEqualScalarFunction.class.getMethod("stringLessThanOrEqual", String.class, String.class), + LessThanOrEqualScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + return TYPE_FUNCTION_INFO_MAP.get(argumentType); + } + + @Override + public String getName() { + return "lessThanOrEqual"; + } + + public static boolean intLessThanOrEqual(int a, int b) { + return a <= b; + } + + public static boolean longLessThanOrEqual(long a, long b) { + return a <= b; + } + + public static boolean floatLessThanOrEqual(float a, float b) { + return a <= b; + } + + public static boolean doubleLessThanOrEqual(double a, double b) { + return a <= b; + } + + public static boolean bigDecimalLessThanOrEqual(BigDecimal a, BigDecimal b) { + return a.compareTo(b) <= 0; + } + + public static boolean stringLessThanOrEqual(String a, String b) { + return a.compareTo(b) <= 0; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java new file mode 100644 index 0000000000..e9d722370e --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java @@ -0,0 +1,97 @@ +/** + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar.comparison; + +import java.math.BigDecimal; +import java.util.HashMap; +import java.util.Map; +import org.apache.pinot.common.function.FunctionInfo; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Polymorphic lessThan (<) scalar function implementation + */ +@ScalarFunction +public class LessThanScalarFunction extends PolymorphicComparisonScalarFunction { + + private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>(); + + static { + try { + // Set nullable parameters to false for each function because the return value should be null if any argument + // is null + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT, new FunctionInfo( + LessThanScalarFunction.class.getMethod("intLessThan", int.class, int.class), + LessThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, new FunctionInfo( + LessThanScalarFunction.class.getMethod("longLessThan", long.class, long.class), + LessThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT, new FunctionInfo( + 
LessThanScalarFunction.class.getMethod("floatLessThan", float.class, float.class), + LessThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, new FunctionInfo( + LessThanScalarFunction.class.getMethod("doubleLessThan", double.class, double.class), + LessThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo( + LessThanScalarFunction.class.getMethod("bigDecimalLessThan", BigDecimal.class, BigDecimal.class), + LessThanScalarFunction.class, false)); + TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo( + LessThanScalarFunction.class.getMethod("stringLessThan", String.class, String.class), + LessThanScalarFunction.class, false)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + protected FunctionInfo functionInfoForType(ColumnDataType argumentType) { + return TYPE_FUNCTION_INFO_MAP.get(argumentType); + } + + @Override + public String getName() { + return "lessThan"; + } + + public static boolean intLessThan(int a, int b) { + return a < b; + } + + public static boolean longLessThan(long a, long b) { + return a < b; + } + + public static boolean floatLessThan(float a, float b) { + return a < b; + } + + public static boolean doubleLessThan(double a, double b) { + return a < b; + } + + public static boolean bigDecimalLessThan(BigDecimal a, BigDecimal b) { + return a.compareTo(b) < 0; + } + + public static boolean stringLessThan(String a, String b) { + return a.compareTo(b) < 0; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java new file mode 100644 index 0000000000..18d9052432 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java @@ -0,0 +1,176 @@ +/** + * Licensed to the 
Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.comparison;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+/**
+ * Polymorphic notEquals (!=) scalar function implementation.
+ *
+ * <p>One {@link FunctionInfo} is registered per supported argument {@link ColumnDataType}; each one points at the
+ * matching static {@code *NotEquals} method of this class.
+ */
+@ScalarFunction
+public class NotEqualsScalarFunction extends PolymorphicComparisonScalarFunction {
+
+  private static final Map<ColumnDataType, FunctionInfo> TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final FunctionInfo DOUBLE_NOT_EQUALS_WITH_TOLERANCE;
+
+  static {
+    try {
+      DOUBLE_NOT_EQUALS_WITH_TOLERANCE = new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("doubleNotEqualsWithTolerance", double.class, double.class),
+          NotEqualsScalarFunction.class, false);
+
+      // Set nullable parameters to false for each function because the return value should be null if any argument
+      // is null
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("intNotEquals", int.class, int.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("longNotEquals", long.class, long.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("floatNotEquals", float.class, float.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("doubleNotEquals", double.class, double.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("bigDecimalNotEquals", BigDecimal.class, BigDecimal.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("stringNotEquals", String.class, String.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BYTES, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("bytesNotEquals", byte[].class, byte[].class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.OBJECT, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("objectNotEquals", Object.class, Object.class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.INT_ARRAY, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("intArrayNotEquals", int[].class, int[].class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG_ARRAY, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("longArrayNotEquals", long[].class, long[].class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.FLOAT_ARRAY, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("floatArrayNotEquals", float[].class, float[].class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE_ARRAY, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("doubleArrayNotEquals", double[].class, double[].class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING_ARRAY, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("stringArrayNotEquals", String[].class, String[].class),
+          NotEqualsScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BYTES_ARRAY, new FunctionInfo(
+          NotEqualsScalarFunction.class.getMethod("bytesArrayNotEquals", byte[][].class, byte[][].class),
+          NotEqualsScalarFunction.class, false));
+    } catch (NoSuchMethodException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Looks up the function registered for the given argument type.
+   *
+   * @return the registered {@link FunctionInfo}, or {@code null} when nothing is registered for {@code argumentType}
+   */
+  @Override
+  protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+    return TYPE_FUNCTION_INFO_MAP.get(argumentType);
+  }
+
+  /**
+   * Arity-only lookup.
+   *
+   * @return the tolerance-based double comparison when {@code numArguments} is 2, {@code null} otherwise
+   */
+  @Nullable
+  @Override
+  public FunctionInfo getFunctionInfo(int numArguments) {
+    if (numArguments != 2) {
+      return null;
+    }
+
+    // For backward compatibility
+    return DOUBLE_NOT_EQUALS_WITH_TOLERANCE;
+  }
+
+  @Override
+  public String getName() {
+    return "notEquals";
+  }
+
+  public static boolean intNotEquals(int a, int b) {
+    return a != b;
+  }
+
+  public static boolean longNotEquals(long a, long b) {
+    return a != b;
+  }
+
+  public static boolean floatNotEquals(float a, float b) {
+    return a != b;
+  }
+
+  public static boolean doubleNotEquals(double a, double b) {
+    return a != b;
+  }
+
+  /**
+   * Tolerance-based inequality: {@code a} and {@code b} are different when they differ by at least
+   * {@code DOUBLE_COMPARISON_TOLERANCE}.
+   *
+   * <p>Values that are {@code ==} (including two infinities of the same sign) are never reported as different, and a
+   * NaN operand is always reported as different, which matches {@link #doubleNotEquals(double, double)}.
+   */
+  public static boolean doubleNotEqualsWithTolerance(double a, double b) {
+    // To avoid approximation errors. Written as a negated "<" so that a NaN difference yields true instead of being
+    // silently treated as "equal"; the a != b guard keeps Infinity equal to itself although Infinity - Infinity is NaN.
+    return a != b && !(Math.abs(a - b) < DOUBLE_COMPARISON_TOLERANCE);
+  }
+
+  /**
+   * Numeric inequality based on {@link BigDecimal#compareTo(BigDecimal)}, so values that differ only in scale are not
+   * reported as different.
+   */
+  public static boolean bigDecimalNotEquals(BigDecimal a, BigDecimal b) {
+    return a.compareTo(b) != 0;
+  }
+
+  /**
+   * String inequality. {@code a} must not be null because it is dereferenced.
+   */
+  public static boolean stringNotEquals(String a, String b) {
+    return !a.equals(b);
+  }
+
+  public static boolean bytesNotEquals(byte[] a, byte[] b) {
+    return !Arrays.equals(a, b);
+  }
+
+  /**
+   * Null-safe inequality based on {@link Objects#equals(Object, Object)}.
+   */
+  public static boolean objectNotEquals(Object a, Object b) {
+    return !Objects.equals(a, b);
+  }
+
+  public static boolean intArrayNotEquals(int[] a, int[] b) {
+    return !Arrays.equals(a, b);
+  }
+
+  public static boolean longArrayNotEquals(long[] a, long[] b) {
+    return !Arrays.equals(a, b);
+  }
+
+  public static boolean floatArrayNotEquals(float[] a, float[] b) {
+    return !Arrays.equals(a, b);
+  }
+
+  public static boolean doubleArrayNotEquals(double[] a, double[] b) {
+    return !Arrays.equals(a, b);
+  }
+
+  public static boolean stringArrayNotEquals(String[] a, String[] b) {
+    return !Arrays.equals(a, b);
+  }
+
+  /**
+   * Element-wise inequality of two arrays of byte arrays, compared with {@link Arrays#deepEquals(Object[], Object[])}.
+   */
+  public static boolean bytesArrayNotEquals(byte[][] a, byte[][] b) {
+    return !Arrays.deepEquals(a, b);
+  }
+}
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/PolymorphicComparisonScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/PolymorphicComparisonScalarFunction.java
new file mode 100644
index 0000000000..eb029544c8
--- /dev/null
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/PolymorphicComparisonScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.comparison;
+
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.function.PinotScalarFunction;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+
+
+/**
+ * Base class for polymorphic comparison scalar functions.
+ *
+ * <p>Resolves the {@link FunctionInfo} to invoke from the argument types; subclasses supply the per-type
+ * {@link FunctionInfo} through {@link #functionInfoForType(ColumnDataType)}.
+ */
+public abstract class PolymorphicComparisonScalarFunction implements PinotScalarFunction {
+
+  /** Tolerance available to subclasses for approximate double comparison. */
+  protected static final double DOUBLE_COMPARISON_TOLERANCE = 1e-7d;
+
+  /**
+   * Resolves the comparison function for the given argument types.
+   *
+   * @param argumentTypes types of the call's arguments
+   * @return {@code null} unless exactly two argument types are given; otherwise the function for the arguments' common
+   *         stored type, or the DOUBLE function when the stored types differ
+   */
+  @Nullable
+  @Override
+  public FunctionInfo getFunctionInfo(ColumnDataType[] argumentTypes) {
+    if (argumentTypes.length != 2) {
+      return null;
+    }
+
+    // The lookup below is keyed by stored type, so heterogeneity is decided on stored types as well: two types that
+    // resolve to the same stored type can use that type's function directly and need no conversion.
+    ColumnDataType firstStoredType = argumentTypes[0].getStoredType();
+    ColumnDataType secondStoredType = argumentTypes[1].getStoredType();
+
+    // In case of heterogeneous argument types, fall back to double based comparison and allow FunctionInvoker to
+    // convert argument types for v1 engine support.
+    if (firstStoredType != secondStoredType) {
+      return functionInfoForType(ColumnDataType.DOUBLE);
+    }
+
+    return functionInfoForType(firstStoredType);
+  }
+
+  /**
+   * Arity-only lookup.
+   *
+   * @return the DOUBLE comparison function when {@code numArguments} is 2, {@code null} otherwise
+   */
+  @Nullable
+  @Override
+  public FunctionInfo getFunctionInfo(int numArguments) {
+    if (numArguments != 2) {
+      return null;
+    }
+
+    // For backward compatibility
+    return functionInfoForType(ColumnDataType.DOUBLE);
+  }
+
+  /**
+   * Get the comparison scalar function's {@link FunctionInfo} for the given argument type. Comparison scalar functions
+   * take two arguments of the same type.
+   *
+   * @param argumentType the type shared by both arguments
+   * @return the matching {@link FunctionInfo}; may be {@code null} when the implementation has nothing registered for
+   *         the type
+   */
+  protected abstract FunctionInfo functionInfoForType(ColumnDataType argumentType);
+}
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index 602979fc57..57d96385cd 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -43,10 +43,13 @@ import org.apache.pinot.spi.data.MetricFieldSpec;
 import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.util.TestUtils;
+import org.joda.time.DateTime;
+import org.joda.time.DateTimeZone;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.apache.pinot.common.function.scalar.StringFunctions.*;
@@ -704,6 +707,101 @@ public class MultiStageEngineIntegrationTest extends BaseClusterIntegrationTestS
     assertEquals(jsonNode.get("numRowsResultSet").asInt(), 3);
   }
 
+  @Test(dataProvider = "polymorphicScalarComparisonFunctionsDataProvider")
+  public void testPolymorphicScalarComparisonFunctions(String type, String literal, String lesserLiteral,
+      Object expectedValue) throws Exception {
+
+    // Queries written this way will trigger the PinotEvaluateLiteralRule which will call the scalar comparison function
+    // on the literals. Simpler queries like SELECT ... WHERE 'test' = 'test' will not trigger the optimization rule
+    // because the filter will be removed by Calcite in the SQL to Rel conversion phase even before the optimization
+    // rules are fired.
+    String sqlQueryPrefix = "WITH data as (SELECT " + literal + " as \"foo\" FROM mytable) " + "SELECT * FROM data ";
+
+    // Test equals
+    JsonNode result = postQuery(sqlQueryPrefix + "WHERE \"foo\" = " + literal);
+    assertNoError(result);
+    checkSingleColumnSameValueResult(result, DEFAULT_COUNT_STAR_RESULT, type, expectedValue);
+
+    // Test not equals
+    result = postQuery(sqlQueryPrefix + "WHERE \"foo\" != " + lesserLiteral);
+    assertNoError(result);
+    checkSingleColumnSameValueResult(result, DEFAULT_COUNT_STAR_RESULT, type, expectedValue);
+
+    // Test greater than
+    result = postQuery(sqlQueryPrefix + "WHERE \"foo\" > " + lesserLiteral);
+    assertNoError(result);
+    checkSingleColumnSameValueResult(result, DEFAULT_COUNT_STAR_RESULT, type, expectedValue);
+
+    // Test greater than or equals
+    result = postQuery(sqlQueryPrefix + "WHERE \"foo\" >= " + lesserLiteral);
+    assertNoError(result);
+    checkSingleColumnSameValueResult(result, DEFAULT_COUNT_STAR_RESULT, type, expectedValue);
+
+    // Test less than
+    result = postQuery(sqlQueryPrefix + "WHERE " + lesserLiteral + " < \"foo\"");
+    assertNoError(result);
+    checkSingleColumnSameValueResult(result, DEFAULT_COUNT_STAR_RESULT, type, expectedValue);
+
+    // Test less than or equals
+    result = postQuery(sqlQueryPrefix + "WHERE " + lesserLiteral + " <= \"foo\"");
+    assertNoError(result);
+    checkSingleColumnSameValueResult(result, DEFAULT_COUNT_STAR_RESULT, type, expectedValue);
+  }
+
+  /**
+   * Verifies that comparing the numeric literal column with a string literal produces a query exception for every
+   * comparison operator.
+   */
+  @Test
+  public void testPolymorphicScalarComparisonFunctionsDifferentType() throws Exception {
+    // Don't support comparison for literals with different types
+    String sqlQueryPrefix = "WITH data as (SELECT 1 as \"foo\" FROM mytable) " + "SELECT * FROM data WHERE \"foo\" ";
+
+    // One query per operator; the assertion message names the operator so a failure is attributable.
+    for (String operator : new String[]{"=", "!=", ">", ">=", "<", "<="}) {
+      JsonNode jsonNode = postQuery(sqlQueryPrefix + operator + " 'test'");
+      assertFalse(jsonNode.get("exceptions").isEmpty(), "Expected an exception for operator: " + operator);
+    }
+  }
+
+  /**
+   * Helper method to verify the result of a query that is assumed to return a single column with the same value for
+   * all the rows. Only the first row value is checked.
+   */
+  private void checkSingleColumnSameValueResult(JsonNode result, long expectedRows, String type,
+      Object expectedValue) {
+    assertEquals(result.get("resultTable").get("dataSchema").get("columnDataTypes").size(), 1);
+    assertEquals(result.get("resultTable").get("dataSchema").get("columnDataTypes").get(0).asText(), type);
+    assertEquals(result.get("numRowsResultSet").asLong(), expectedRows);
+    assertEquals(result.get("resultTable").get("rows").get(0).get(0).asText(), expectedValue);
+  }
+
+  /**
+   * Inputs for {@code testPolymorphicScalarComparisonFunctions}. Each row is: expected result column type, SQL text of
+   * the literal, SQL text of a smaller literal of the same type, and the expected text of the returned value.
+   */
+  @DataProvider(name = "polymorphicScalarComparisonFunctionsDataProvider")
+  Object[][] polymorphicScalarComparisonFunctionsDataProvider() {
+    List<Object[]> inputs = new ArrayList<>();
+
+    inputs.add(new Object[]{"STRING", "'test'", "'abc'", "test"});
+    inputs.add(new Object[]{"INT", "1", "0", "1"});
+    inputs.add(new Object[]{"LONG", "12345678999", "12345678998", "12345678999"});
+    inputs.add(new Object[]{"FLOAT", "CAST(1.234 AS FLOAT)", "CAST(1.23 AS FLOAT)", "1.234"});
+    inputs.add(new Object[]{"DOUBLE", "1.234", "1.23", "1.234"});
+    inputs.add(new Object[]{"BOOLEAN", "CAST(true AS BOOLEAN)", "CAST(FALSE AS BOOLEAN)", "true"});
+    inputs.add(new Object[]{"TIMESTAMP", "CAST(1723593600000 AS TIMESTAMP)", "CAST (1623593600000 AS TIMESTAMP)",
+        new DateTime(1723593600000L, DateTimeZone.getDefault()).toString("yyyy-MM-dd HH:mm:ss.S")});
+
+    return inputs.toArray(new Object[0][]);
+  }
+
   @Test
   public void skipArrayToMvOptimization()
       throws Exception {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
index fd6f9e6d87..5e282544d2 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
@@ -217,7 +217,13 @@ public class PinotOperatorTable implements SqlOperatorTable {
 
   private static final List<Pair<SqlOperator, List<String>>> STANDARD_OPERATORS_WITH_ALIASES = List.of(
       Pair.of(SqlStdOperatorTable.CASE, List.of("CASE", "CASE_WHEN")),
-      Pair.of(SqlStdOperatorTable.LN, List.of("LN", "LOG"))
+      Pair.of(SqlStdOperatorTable.LN, List.of("LN", "LOG")),
+      Pair.of(SqlStdOperatorTable.EQUALS, List.of("EQUALS")),
+      Pair.of(SqlStdOperatorTable.NOT_EQUALS, List.of("NOT_EQUALS")),
+      Pair.of(SqlStdOperatorTable.GREATER_THAN, List.of("GREATER_THAN")),
+      Pair.of(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, List.of("GREATER_THAN_OR_EQUAL")),
+      Pair.of(SqlStdOperatorTable.LESS_THAN, List.of("LESS_THAN")),
+      Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, List.of("LESS_THAN_OR_EQUAL"))
   );
 
   /**
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org