This is an automated email from the ASF dual-hosted git repository.

yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new a7fde6d78aa Fix function polymorphism for ScalarTransformFunctionWrapper (#16992)
a7fde6d78aa is described below

commit a7fde6d78aa8ba7a16a352f5bd89097193b04bf1
Author: Yash Mayya <[email protected]>
AuthorDate: Thu Oct 9 18:50:07 2025 -0700

    Fix function polymorphism for ScalarTransformFunctionWrapper (#16992)
---
 .../function/TransformFunctionFactory.java         | 26 ++++++++++++++++------
 .../tests/BaseClusterIntegrationTestSet.java       | 16 +++++++++++++
 2 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index f25f51d8d33..ff5c78caf0d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -33,6 +33,7 @@ import org.apache.pinot.common.function.TransformFunctionType;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.request.context.LiteralContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.HashUtil;
 import org.apache.pinot.core.geospatial.transform.function.GeoToH3Function;
 import org.apache.pinot.core.geospatial.transform.function.GridDiskFunction;
@@ -55,6 +56,7 @@ import 
org.apache.pinot.core.geospatial.transform.function.StPointFunction;
 import org.apache.pinot.core.geospatial.transform.function.StPolygonFunction;
 import org.apache.pinot.core.geospatial.transform.function.StWithinFunction;
 import org.apache.pinot.core.operator.ColumnContext;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
 import 
org.apache.pinot.core.operator.transform.function.SingleParamMathTransformFunction.AbsTransformFunction;
 import 
org.apache.pinot.core.operator.transform.function.SingleParamMathTransformFunction.CeilTransformFunction;
 import 
org.apache.pinot.core.operator.transform.function.SingleParamMathTransformFunction.ExpTransformFunction;
@@ -308,8 +310,6 @@ public class TransformFunctionFactory {
       case FUNCTION:
         FunctionContext function = expression.getFunction();
         String functionName = canonicalize(function.getFunctionName());
-        List<ExpressionContext> arguments = function.getArguments();
-        int numArguments = arguments.size();
 
         // Check if the function is ArrayValueConstructor transform function
         if 
(functionName.equalsIgnoreCase(ArrayLiteralTransformFunction.FUNCTION_NAME)) {
@@ -324,6 +324,15 @@ public class TransformFunctionFactory {
               GenerateArrayTransformFunction::new);
         }
 
+        List<ExpressionContext> arguments = function.getArguments();
+        int numArguments = arguments.size();
+
+        // Build child transform functions first to derive argument data types 
for scalar function polymorphism
+        List<TransformFunction> transformFunctionArguments = new 
ArrayList<>(numArguments);
+        for (ExpressionContext argument : arguments) {
+          
transformFunctionArguments.add(TransformFunctionFactory.get(argument, 
columnContextMap, queryContext));
+        }
+
         TransformFunction transformFunction;
         Class<? extends TransformFunction> transformFunctionClass = 
TRANSFORM_FUNCTION_MAP.get(functionName);
         if (transformFunctionClass != null) {
@@ -336,7 +345,14 @@ public class TransformFunctionFactory {
         } else {
           // Scalar function
           String canonicalName = FunctionRegistry.canonicalize(functionName);
-          FunctionInfo functionInfo = 
FunctionRegistry.lookupFunctionInfo(canonicalName, numArguments);
+          // Get data types for the arguments
+          ColumnDataType[] argumentDataTypes = new 
ColumnDataType[numArguments];
+          for (int i = 0; i < numArguments; i++) {
+            TransformResultMetadata resultMetadata = 
transformFunctionArguments.get(i).getResultMetadata();
+            argumentDataTypes[i] =
+                ColumnDataType.fromDataType(resultMetadata.getDataType(), 
resultMetadata.isSingleValue());
+          }
+          FunctionInfo functionInfo = 
FunctionRegistry.lookupFunctionInfo(canonicalName, argumentDataTypes);
           if (functionInfo == null) {
             if (FunctionRegistry.contains(canonicalName)) {
               throw new BadQueryRequestException(
@@ -348,10 +364,6 @@ public class TransformFunctionFactory {
           transformFunction = new ScalarTransformFunctionWrapper(functionInfo);
         }
 
-        List<TransformFunction> transformFunctionArguments = new 
ArrayList<>(numArguments);
-        for (ExpressionContext argument : arguments) {
-          
transformFunctionArguments.add(TransformFunctionFactory.get(argument, 
columnContextMap, queryContext));
-        }
         try {
           transformFunction.init(transformFunctionArguments, columnContextMap, 
queryContext.isNullHandlingEnabled());
         } catch (Exception e) {
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
index b9c9072b821..4de86eb3c6c 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
@@ -211,6 +211,22 @@ public abstract class BaseClusterIntegrationTestSet 
extends BaseClusterIntegrati
     query = "SELECT count(*) FROM mytable WHERE (NOT DaysSinceEpoch = 16312) 
AND Carrier = 'DL'";
     testQuery(query);
 
+    // BETWEEN
+    query = "SELECT count(*) FROM mytable WHERE OriginState BETWEEN 'DE' AND 
'PA'";
+    testQuery(query);
+
+    query = "SELECT count(*) FROM mytable WHERE OriginState BETWEEN 'PA' AND 
'DE'";
+    testQuery(query);
+
+    query = "SELECT count(*) FROM mytable WHERE DaysSinceEpoch BETWEEN 16312 
AND 16318";
+    testQuery(query);
+
+    query = "SELECT Carrier BETWEEN 'AA' AND 'QQ' FROM mytable";
+    testQuery(query);
+
+    query = "SELECT DaysSinceEpoch BETWEEN 16312 AND 16318 FROM mytable";
+    testQuery(query);
+
     // Post-aggregation in ORDER-BY
     query = "SELECT MAX(ArrTime) FROM mytable GROUP BY DaysSinceEpoch ORDER BY 
MAX(ArrTime) - MIN(ArrTime)";
     testQuery(query);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to