This is an automated email from the ASF dual-hosted git repository.
yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new a7fde6d78aa Fix function polymorphism for ScalarTransformFunctionWrapper (#16992)
a7fde6d78aa is described below
commit a7fde6d78aa8ba7a16a352f5bd89097193b04bf1
Author: Yash Mayya <[email protected]>
AuthorDate: Thu Oct 9 18:50:07 2025 -0700
Fix function polymorphism for ScalarTransformFunctionWrapper (#16992)
---
.../function/TransformFunctionFactory.java | 26 ++++++++++++++++------
.../tests/BaseClusterIntegrationTestSet.java | 16 +++++++++++++
2 files changed, 35 insertions(+), 7 deletions(-)
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index f25f51d8d33..ff5c78caf0d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -33,6 +33,7 @@ import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.request.context.LiteralContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.common.utils.HashUtil;
import org.apache.pinot.core.geospatial.transform.function.GeoToH3Function;
import org.apache.pinot.core.geospatial.transform.function.GridDiskFunction;
@@ -55,6 +56,7 @@ import org.apache.pinot.core.geospatial.transform.function.StPointFunction;
import org.apache.pinot.core.geospatial.transform.function.StPolygonFunction;
import org.apache.pinot.core.geospatial.transform.function.StWithinFunction;
import org.apache.pinot.core.operator.ColumnContext;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import
org.apache.pinot.core.operator.transform.function.SingleParamMathTransformFunction.AbsTransformFunction;
import
org.apache.pinot.core.operator.transform.function.SingleParamMathTransformFunction.CeilTransformFunction;
import
org.apache.pinot.core.operator.transform.function.SingleParamMathTransformFunction.ExpTransformFunction;
@@ -308,8 +310,6 @@ public class TransformFunctionFactory {
case FUNCTION:
FunctionContext function = expression.getFunction();
String functionName = canonicalize(function.getFunctionName());
- List<ExpressionContext> arguments = function.getArguments();
- int numArguments = arguments.size();
// Check if the function is ArrayValueConstructor transform function
if
(functionName.equalsIgnoreCase(ArrayLiteralTransformFunction.FUNCTION_NAME)) {
@@ -324,6 +324,15 @@ public class TransformFunctionFactory {
GenerateArrayTransformFunction::new);
}
+ List<ExpressionContext> arguments = function.getArguments();
+ int numArguments = arguments.size();
+
+ // Build child transform functions first to derive argument data types
for scalar function polymorphism
+ List<TransformFunction> transformFunctionArguments = new
ArrayList<>(numArguments);
+ for (ExpressionContext argument : arguments) {
+
transformFunctionArguments.add(TransformFunctionFactory.get(argument,
columnContextMap, queryContext));
+ }
+
TransformFunction transformFunction;
Class<? extends TransformFunction> transformFunctionClass =
TRANSFORM_FUNCTION_MAP.get(functionName);
if (transformFunctionClass != null) {
@@ -336,7 +345,14 @@ public class TransformFunctionFactory {
} else {
// Scalar function
String canonicalName = FunctionRegistry.canonicalize(functionName);
- FunctionInfo functionInfo =
FunctionRegistry.lookupFunctionInfo(canonicalName, numArguments);
+ // Get data types for the arguments
+ ColumnDataType[] argumentDataTypes = new
ColumnDataType[numArguments];
+ for (int i = 0; i < numArguments; i++) {
+ TransformResultMetadata resultMetadata =
transformFunctionArguments.get(i).getResultMetadata();
+ argumentDataTypes[i] =
+ ColumnDataType.fromDataType(resultMetadata.getDataType(),
resultMetadata.isSingleValue());
+ }
+ FunctionInfo functionInfo =
FunctionRegistry.lookupFunctionInfo(canonicalName, argumentDataTypes);
if (functionInfo == null) {
if (FunctionRegistry.contains(canonicalName)) {
throw new BadQueryRequestException(
@@ -348,10 +364,6 @@ public class TransformFunctionFactory {
transformFunction = new ScalarTransformFunctionWrapper(functionInfo);
}
- List<TransformFunction> transformFunctionArguments = new
ArrayList<>(numArguments);
- for (ExpressionContext argument : arguments) {
-
transformFunctionArguments.add(TransformFunctionFactory.get(argument,
columnContextMap, queryContext));
- }
try {
transformFunction.init(transformFunctionArguments, columnContextMap,
queryContext.isNullHandlingEnabled());
} catch (Exception e) {
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
index b9c9072b821..4de86eb3c6c 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
@@ -211,6 +211,22 @@ public abstract class BaseClusterIntegrationTestSet extends BaseClusterIntegrati
query = "SELECT count(*) FROM mytable WHERE (NOT DaysSinceEpoch = 16312)
AND Carrier = 'DL'";
testQuery(query);
+ // BETWEEN
+ query = "SELECT count(*) FROM mytable WHERE OriginState BETWEEN 'DE' AND
'PA'";
+ testQuery(query);
+
+ query = "SELECT count(*) FROM mytable WHERE OriginState BETWEEN 'PA' AND
'DE'";
+ testQuery(query);
+
+ query = "SELECT count(*) FROM mytable WHERE DaysSinceEpoch BETWEEN 16312
AND 16318";
+ testQuery(query);
+
+ query = "SELECT Carrier BETWEEN 'AA' AND 'QQ' FROM mytable";
+ testQuery(query);
+
+ query = "SELECT DaysSinceEpoch BETWEEN 16312 AND 16318 FROM mytable";
+ testQuery(query);
+
// Post-aggregation in ORDER-BY
query = "SELECT MAX(ArrTime) FROM mytable GROUP BY DaysSinceEpoch ORDER BY
MAX(ArrTime) - MIN(ArrTime)";
testQuery(query);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]