This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 19b79f406c Polymorphic binary arithmetic scalar functions (#14089)
19b79f406c is described below

commit 19b79f406c36e8b075e9c1698a8752af5b4e6d23
Author: Yash Mayya <yash.ma...@gmail.com>
AuthorDate: Tue Oct 1 01:32:38 2024 +0530

    Polymorphic binary arithmetic scalar functions (#14089)
---
 .../function/scalar/ArithmeticFunctions.java       |  15 --
 .../scalar/arithmetic/MinusScalarFunction.java     |  66 +++++++++
 .../scalar/arithmetic/MultScalarFunction.java      |  66 +++++++++
 .../scalar/arithmetic/PlusScalarFunction.java      |  66 +++++++++
 .../PolymorphicBinaryArithmeticScalarFunction.java |  67 +++++++++
 .../scalar/comparison/EqualsScalarFunction.java    |   4 +-
 .../GreaterThanOrEqualScalarFunction.java          |   9 +-
 .../comparison/GreaterThanScalarFunction.java      |   4 +-
 .../comparison/LessThanOrEqualScalarFunction.java  |   4 +-
 .../scalar/comparison/LessThanScalarFunction.java  |   4 +-
 .../scalar/comparison/NotEqualsScalarFunction.java |   4 +-
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  |  24 ++++
 .../PostAggregationFunctionTest.java               |   4 +-
 .../tests/OfflineClusterIntegrationTest.java       | 151 +++++++++++++--------
 .../pinot/calcite/sql/fun/PinotOperatorTable.java  |   5 +-
 .../resources/queries/LiteralEvaluationPlans.json  |   4 +-
 .../ExpressionTransformerTest.java                 |   2 +-
 17 files changed, 406 insertions(+), 93 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
index 94489c92b1..27c4952b1f 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
@@ -30,21 +30,6 @@ public class ArithmeticFunctions {
   private ArithmeticFunctions() {
   }
 
-  @ScalarFunction(names = {"add", "plus"})
-  public static double plus(double a, double b) {
-    return a + b;
-  }
-
-  @ScalarFunction(names = {"sub", "minus"})
-  public static double minus(double a, double b) {
-    return a - b;
-  }
-
-  @ScalarFunction(names = {"mult", "times"})
-  public static double times(double a, double b) {
-    return a * b;
-  }
-
   @ScalarFunction(names = {"div", "divide"})
   public static double divide(double a, double b) {
     return a / b;
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java
new file mode 100644
index 0000000000..61488e58e7
--- /dev/null
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import java.util.EnumMap;
+import java.util.Map;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+@ScalarFunction(names = {"sub", "minus"})
+public class MinusScalarFunction extends 
PolymorphicBinaryArithmeticScalarFunction {
+
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
+
+  static {
+    try {
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG,
+          new FunctionInfo(MinusScalarFunction.class.getMethod("longMinus", 
long.class, long.class),
+              MinusScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE,
+          new FunctionInfo(MinusScalarFunction.class.getMethod("doubleMinus", 
double.class, double.class),
+              MinusScalarFunction.class, false));
+    } catch (NoSuchMethodException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+    FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType);
+
+    // Fall back to the double-based arithmetic implementation for types without a dedicated one
+    return functionInfo != null ? functionInfo : 
TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE);
+  }
+
+  @Override
+  public String getName() {
+    return "minus";
+  }
+
+  public static long longMinus(long a, long b) {
+    return a - b;
+  }
+
+  public static double doubleMinus(double a, double b) {
+    return a - b;
+  }
+}
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java
new file mode 100644
index 0000000000..a737045393
--- /dev/null
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import java.util.EnumMap;
+import java.util.Map;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+@ScalarFunction(names = {"mult", "times"})
+public class MultScalarFunction extends 
PolymorphicBinaryArithmeticScalarFunction {
+
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
+
+  static {
+    try {
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG,
+          new FunctionInfo(MultScalarFunction.class.getMethod("longMult", 
long.class, long.class),
+              MultScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE,
+          new FunctionInfo(MultScalarFunction.class.getMethod("doubleMult", 
double.class, double.class),
+              MultScalarFunction.class, false));
+    } catch (NoSuchMethodException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+    FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType);
+
+    // Fall back to the double-based arithmetic implementation for types without a dedicated one
+    return functionInfo != null ? functionInfo : 
TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE);
+  }
+
+  @Override
+  public String getName() {
+    return "mult";
+  }
+
+  public static long longMult(long a, long b) {
+    return a * b;
+  }
+
+  public static double doubleMult(double a, double b) {
+    return a * b;
+  }
+}
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java
new file mode 100644
index 0000000000..5951afa527
--- /dev/null
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import java.util.EnumMap;
+import java.util.Map;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+@ScalarFunction(names = {"add", "plus"})
+public class PlusScalarFunction extends 
PolymorphicBinaryArithmeticScalarFunction {
+
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
+
+  static {
+    try {
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG,
+          new FunctionInfo(PlusScalarFunction.class.getMethod("longPlus", 
long.class, long.class),
+              PlusScalarFunction.class, false));
+      TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE,
+          new FunctionInfo(PlusScalarFunction.class.getMethod("doublePlus", 
double.class, double.class),
+              PlusScalarFunction.class, false));
+    } catch (NoSuchMethodException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+    FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType);
+
+    // Fall back to the double-based arithmetic implementation for types without a dedicated one
+    return functionInfo != null ? functionInfo : 
TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE);
+  }
+
+  @Override
+  public String getName() {
+    return "plus";
+  }
+
+  public static long longPlus(long a, long b) {
+    return a + b;
+  }
+
+  public static double doublePlus(double a, double b) {
+    return a + b;
+  }
+}
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java
new file mode 100644
index 0000000000..10167161f9
--- /dev/null
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.function.PinotScalarFunction;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+
+
+/**
+ * Base class for polymorphic binary arithmetic scalar functions
+ */
+public abstract class PolymorphicBinaryArithmeticScalarFunction implements 
PinotScalarFunction {
+
+  @Nullable
+  @Override
+  public FunctionInfo getFunctionInfo(ColumnDataType[] argumentTypes) {
+    if (argumentTypes.length != 2) {
+      return null;
+    }
+
+    return functionInfoForTypes(argumentTypes[0].getStoredType(), 
argumentTypes[1].getStoredType());
+  }
+
+  @Nullable
+  @Override
+  public FunctionInfo getFunctionInfo(int numArguments) {
+    if (numArguments != 2) {
+      return null;
+    }
+
+    // For backward compatibility
+    return functionInfoForType(ColumnDataType.DOUBLE);
+  }
+
+  private FunctionInfo functionInfoForTypes(ColumnDataType argumentType1, 
ColumnDataType argumentType2) {
+    if ((argumentType1 == ColumnDataType.LONG || argumentType1 == 
ColumnDataType.INT) && (
+        argumentType2 == ColumnDataType.LONG || argumentType2 == 
ColumnDataType.INT)) {
+      return functionInfoForType(ColumnDataType.LONG);
+    }
+
+    // Fall back to double-based arithmetic unless both arguments are INT or LONG
+    return functionInfoForType(ColumnDataType.DOUBLE);
+  }
+
+  /**
+   * Get the binary arithmetic scalar function's {@link FunctionInfo} for the 
given argument type.
+   */
+  protected abstract FunctionInfo functionInfoForType(ColumnDataType 
argumentType);
+}
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
index 656722ccc8..0bc0fcb075 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
@@ -20,7 +20,7 @@ package org.apache.pinot.common.function.scalar.comparison;
 
 import java.math.BigDecimal;
 import java.util.Arrays;
-import java.util.HashMap;
+import java.util.EnumMap;
 import java.util.Map;
 import java.util.Objects;
 import org.apache.pinot.common.function.FunctionInfo;
@@ -33,7 +33,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
 @ScalarFunction
 public class EqualsScalarFunction extends PolymorphicComparisonScalarFunction {
 
-  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
   private static final FunctionInfo DOUBLE_EQUALS_WITH_TOLERANCE;
 
   static {
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
index cdf27b0f5e..d7782cf7e7 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.common.function.scalar.comparison;
 
 import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
 import java.util.Map;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
 @ScalarFunction
 public class GreaterThanOrEqualScalarFunction extends 
PolymorphicComparisonScalarFunction {
 
-  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
 
   static {
     try {
@@ -51,9 +51,8 @@ public class GreaterThanOrEqualScalarFunction extends 
PolymorphicComparisonScala
           
GreaterThanOrEqualScalarFunction.class.getMethod("doubleGreaterThanOrEqual", 
double.class, double.class),
           GreaterThanOrEqualScalarFunction.class, false));
       TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo(
-          
GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual",
-              BigDecimal.class, BigDecimal.class),
-          GreaterThanOrEqualScalarFunction.class, false));
+          
GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual",
 BigDecimal.class,
+              BigDecimal.class), GreaterThanOrEqualScalarFunction.class, 
false));
       TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo(
           
GreaterThanOrEqualScalarFunction.class.getMethod("stringGreaterThanOrEqual", 
String.class, String.class),
           GreaterThanOrEqualScalarFunction.class, false));
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
index be8775f549..a41ddb6823 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.common.function.scalar.comparison;
 
 import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
 import java.util.Map;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
 @ScalarFunction
 public class GreaterThanScalarFunction extends 
PolymorphicComparisonScalarFunction {
 
-  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
 
   static {
     try {
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
index 941c1a6d56..7ff076744e 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.common.function.scalar.comparison;
 
 import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
 import java.util.Map;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
 @ScalarFunction
 public class LessThanOrEqualScalarFunction extends 
PolymorphicComparisonScalarFunction {
 
-  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
 
   static {
     try {
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
index e9d722370e..d2d85d9bbf 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.common.function.scalar.comparison;
 
 import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
 import java.util.Map;
 import org.apache.pinot.common.function.FunctionInfo;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
 @ScalarFunction
 public class LessThanScalarFunction extends 
PolymorphicComparisonScalarFunction {
 
-  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
 
   static {
     try {
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
index 7f63a1eb9e..8344514646 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
@@ -20,7 +20,7 @@ package org.apache.pinot.common.function.scalar.comparison;
 
 import java.math.BigDecimal;
 import java.util.Arrays;
-import java.util.HashMap;
+import java.util.EnumMap;
 import java.util.Map;
 import java.util.Objects;
 import org.apache.pinot.common.function.FunctionInfo;
@@ -33,7 +33,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
 @ScalarFunction
 public class NotEqualsScalarFunction extends 
PolymorphicComparisonScalarFunction {
 
-  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+  private static final Map<ColumnDataType, FunctionInfo> 
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
   private static final FunctionInfo DOUBLE_NOT_EQUALS_WITH_TOLERANCE;
 
   static {
diff --git 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 35a625505a..34e2a6b5f5 100644
--- 
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ 
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -2339,6 +2339,30 @@ public class CalciteSqlCompilerTest {
     long result = expression.getLiteral().getLongValue();
     Assert.assertTrue(result >= lowerBound && result <= upperBound);
 
+    expression = compileToExpression("now() - 0");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    upperBound = System.currentTimeMillis();
+    result = expression.getLiteral().getLongValue();
+    Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
+    expression = compileToExpression("now() + 0");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    upperBound = System.currentTimeMillis();
+    result = expression.getLiteral().getLongValue();
+    Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
+    expression = compileToExpression("now() * 1");
+    Assert.assertNotNull(expression.getFunctionCall());
+    expression = 
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+    Assert.assertNotNull(expression.getLiteral());
+    upperBound = System.currentTimeMillis();
+    result = expression.getLiteral().getLongValue();
+    Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
     lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
     expression = compileToExpression("to_epoch_hours(now() + 3600000)");
     Assert.assertNotNull(expression.getFunctionCall());
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
index 0c7b0e3e52..6f4cd02a29 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
@@ -37,8 +37,8 @@ public class PostAggregationFunctionTest {
     // Plus
     PostAggregationFunction function =
         new PostAggregationFunction("plus", new 
ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG});
-    assertEquals(function.getResultType(), ColumnDataType.DOUBLE);
-    assertEquals(function.invoke(new Object[]{1, 2L}), 3.0);
+    assertEquals(function.getResultType(), ColumnDataType.LONG);
+    assertEquals(function.invoke(new Object[]{1, 2L}), 3L);
 
     // Minus
     function = new PostAggregationFunction("MINUS", new 
ColumnDataType[]{ColumnDataType.FLOAT, ColumnDataType.DOUBLE});
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index be438702bf..2bcfcabef1 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -2040,54 +2040,61 @@ public class OfflineClusterIntegrationTest extends 
BaseClusterIntegrationTestSet
     assertEquals(row.get(0).asLong(), 16138 * 24);
     assertEquals(row.get(1).asLong(), 605);
 
-    if (useMultiStageQueryEngine) {
-      query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), COUNT(*) 
FROM mytable "
-          + "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY 
COUNT(*) DESC";
-    } else {
-      query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) FROM 
mytable "
-          + "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*) 
DESC";
-    }
+    query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable "
+        + "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC";
     response = postQuery(query);
     resultTable = response.get("resultTable");
     dataSchema = resultTable.get("dataSchema");
-    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"DOUBLE\",\"LONG\"]");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
     rows = resultTable.get("rows");
     assertFalse(rows.isEmpty());
     row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15);
-    assertEquals(row.get(1).asLong(), 605);
+    assertEquals(row.get(0).asInt(), 5);
+    assertEquals(row.get(1).asLong(), 115545);
 
-    query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
-        + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
+    query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*) 
FROM mytable GROUP BY "
+        + "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) 
DESC";
     response = postQuery(query);
     resultTable = response.get("resultTable");
     dataSchema = resultTable.get("dataSchema");
-    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"DOUBLE\",\"LONG\"]");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
     rows = resultTable.get("rows");
-    assertFalse(rows.isEmpty());
+    assertEquals(rows.size(), 3);
     row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asDouble(), 16138.0 - 25);
-    assertEquals(row.get(1).asLong(), 605);
+    assertEquals(row.get(0).asInt(), 0);
+    assertEquals(row.get(1).asLong(), 114895);
+    row = rows.get(1);
+    assertEquals(row.size(), 2);
+    assertEquals(row.get(0).asInt(), 1);
+    assertEquals(row.get(1).asLong(), 648);
+    row = rows.get(2);
+    assertEquals(row.size(), 2);
+    assertEquals(row.get(0).asInt(), 2);
+    assertEquals(row.get(1).asLong(), 2);
 
-    if (useMultiStageQueryEngine) {
-      query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM 
mytable "
-          + "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) 
DESC";
+    if (useMultiStageQueryEngine()) {
+      query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*) 
FROM mytable "
+          + "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY 
COUNT(*) DESC";
     } else {
-      query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable "
-          + "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC";
+      query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable "
+          + "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC";
     }
     response = postQuery(query);
     resultTable = response.get("resultTable");
     dataSchema = resultTable.get("dataSchema");
-    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"DOUBLE\",\"LONG\"]");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"STRING\",\"LONG\"]");
     rows = resultTable.get("rows");
-    assertFalse(rows.isEmpty());
+    assertEquals(rows.size(), 2);
     row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600);
-    assertEquals(row.get(1).asLong(), 605);
+    assertEquals(row.get(0).asText(), "ORD");
+    assertEquals(row.get(1).asLong(), 336);
+    row = rows.get(1);
+    assertEquals(row.size(), 2);
+    assertEquals(row.get(0).asText(), "DFW");
+    assertEquals(row.get(1).asLong(), 316);
 
     query = "SELECT div(DaysSinceEpoch,2), COUNT(*) FROM mytable "
         + "GROUP BY div(DaysSinceEpoch,2) ORDER BY COUNT(*) DESC";
@@ -2101,62 +2108,92 @@ public class OfflineClusterIntegrationTest extends 
BaseClusterIntegrationTestSet
     assertEquals(row.size(), 2);
     assertEquals(row.get(0).asDouble(), 16138.0 / 2);
     assertEquals(row.get(1).asLong(), 605);
+  }
 
-    query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable "
-        + "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC";
+  @Test
+  public void testGroupByUDFV1() throws Exception {
+    setUseMultiStageQueryEngine(false);
+    String query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) 
FROM mytable "
+        + "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*) 
DESC";
+    JsonNode response = postQuery(query);
+    JsonNode resultTable = response.get("resultTable");
+    JsonNode dataSchema = resultTable.get("dataSchema");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"DOUBLE\",\"LONG\"]");
+    JsonNode rows = resultTable.get("rows");
+    assertFalse(rows.isEmpty());
+    JsonNode row = rows.get(0);
+    assertEquals(row.size(), 2);
+    assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15);
+    assertEquals(row.get(1).asLong(), 605);
+
+    query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+        + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
     response = postQuery(query);
     resultTable = response.get("resultTable");
     dataSchema = resultTable.get("dataSchema");
-    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"DOUBLE\",\"LONG\"]");
     rows = resultTable.get("rows");
     assertFalse(rows.isEmpty());
     row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asInt(), 5);
-    assertEquals(row.get(1).asLong(), 115545);
+    assertEquals(row.get(0).asDouble(), 16138.0 - 25);
+    assertEquals(row.get(1).asLong(), 605);
 
-    query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*) 
FROM mytable GROUP BY "
-        + "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*) 
DESC";
+    query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable "
+        + "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC";
     response = postQuery(query);
     resultTable = response.get("resultTable");
     dataSchema = resultTable.get("dataSchema");
-    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"DOUBLE\",\"LONG\"]");
     rows = resultTable.get("rows");
-    assertEquals(rows.size(), 3);
+    assertFalse(rows.isEmpty());
     row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asInt(), 0);
-    assertEquals(row.get(1).asLong(), 114895);
-    row = rows.get(1);
-    assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asInt(), 1);
-    assertEquals(row.get(1).asLong(), 648);
-    row = rows.get(2);
+    assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600);
+    assertEquals(row.get(1).asLong(), 605);
+  }
+
+  @Test
+  public void testGroupByUDFV2() throws Exception {
+    setUseMultiStageQueryEngine(true);
+    String query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), 
COUNT(*) FROM mytable "
+        + "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY 
COUNT(*) DESC";
+    JsonNode response = postQuery(query);
+    JsonNode resultTable = response.get("resultTable");
+    JsonNode dataSchema = resultTable.get("dataSchema");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
+    JsonNode rows = resultTable.get("rows");
+    assertFalse(rows.isEmpty());
+    JsonNode row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asInt(), 2);
-    assertEquals(row.get(1).asLong(), 2);
+    assertEquals(row.get(0).asInt(), 16138 + 16138 + 15);
+    assertEquals(row.get(1).asLong(), 605);
 
-    if (useMultiStageQueryEngine()) {
-      query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*) 
FROM mytable "
-          + "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY 
COUNT(*) DESC";
-    } else {
-      query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable "
-          + "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC";
-    }
+    query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+        + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
     response = postQuery(query);
     resultTable = response.get("resultTable");
     dataSchema = resultTable.get("dataSchema");
-    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"STRING\",\"LONG\"]");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
     rows = resultTable.get("rows");
-    assertEquals(rows.size(), 2);
+    assertFalse(rows.isEmpty());
     row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asText(), "ORD");
-    assertEquals(row.get(1).asLong(), 336);
-    row = rows.get(1);
+    assertEquals(row.get(0).asInt(), 16138 - 25);
+    assertEquals(row.get(1).asLong(), 605);
+
+    query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM mytable "
+        + "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) DESC";
+    response = postQuery(query);
+    resultTable = response.get("resultTable");
+    dataSchema = resultTable.get("dataSchema");
+    assertEquals(dataSchema.get("columnDataTypes").toString(), 
"[\"INT\",\"LONG\"]");
+    rows = resultTable.get("rows");
+    assertFalse(rows.isEmpty());
+    row = rows.get(0);
     assertEquals(row.size(), 2);
-    assertEquals(row.get(0).asText(), "DFW");
-    assertEquals(row.get(1).asLong(), 316);
+    assertEquals(row.get(0).asInt(), 16138 * 24 * 3600);
+    assertEquals(row.get(1).asLong(), 605);
   }
 
   @Test
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
index 5e282544d2..0c1a8d8a48 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
@@ -223,7 +223,10 @@ public class PinotOperatorTable implements 
SqlOperatorTable {
       Pair.of(SqlStdOperatorTable.GREATER_THAN, List.of("GREATER_THAN")),
       Pair.of(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, 
List.of("GREATER_THAN_OR_EQUAL")),
       Pair.of(SqlStdOperatorTable.LESS_THAN, List.of("LESS_THAN")),
-      Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, 
List.of("LESS_THAN_OR_EQUAL"))
+      Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, 
List.of("LESS_THAN_OR_EQUAL")),
+      Pair.of(SqlStdOperatorTable.MINUS, List.of("SUB", "MINUS")),
+      Pair.of(SqlStdOperatorTable.PLUS, List.of("ADD", "PLUS")),
+      Pair.of(SqlStdOperatorTable.MULTIPLY, List.of("MULT", "TIMES"))
   );
 
   /**
diff --git 
a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json 
b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
index 6298709bf5..8e513b76fa 100644
--- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
@@ -15,7 +15,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT 5*6,5+6 FROM d",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(EXPR$0=[30.0], EXPR$1=[11.0])",
+          "\nLogicalProject(EXPR$0=[30], EXPR$1=[11])",
           "\n  LogicalTableScan(table=[[default, d]])",
           "\n"
         ]
@@ -175,7 +175,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT 1 + 
ToEpochDays(fromDateTime('1970-01-15', 'yyyy-MM-dd')) FROM a",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(EXPR$0=[15.0:BIGINT])",
+          "\nLogicalProject(EXPR$0=[15:BIGINT])",
           "\n  LogicalTableScan(table=[[default, a]])",
           "\n"
         ]
diff --git 
a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
 
b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
index 55d8d7172f..58de9ec70c 100644
--- 
a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
+++ 
b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
@@ -365,7 +365,7 @@ public class ExpressionTransformerTest {
       expressionTransformer.transform(genericRow);
       Assert.fail();
     } catch (Exception e) {
-      Assert.assertEquals(e.getCause().getMessage(), "Caught exception while 
executing function: plus(x,'10')");
+      Assert.assertTrue(e.getCause().getMessage().contains("Caught exception 
while executing function"));
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to