This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e5e8d915e Fix array literal handling (#13345)
8e5e8d915e is described below

commit 8e5e8d915eff1f769d5b4cd2ad666fd7ff90166e
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Sat Jun 8 21:23:37 2024 -0700

    Fix array literal handling (#13345)
---
 .../common/request/context/ExpressionContext.java  |   2 +
 .../common/request/context/LiteralContext.java     | 135 ++++++++++++-----
 .../pinot/common/utils/request/RequestUtils.java   |  30 ++--
 pinot-common/src/main/proto/expressions.proto      |  25 ++++
 .../function/ArrayLiteralTransformFunction.java    |  37 ++---
 .../function/TransformFunctionFactory.java         |   7 +-
 .../function/HistogramAggregationFunction.java     |  46 ++++--
 .../query/executor/ServerQueryExecutorV1Impl.java  |   9 +-
 .../query/parser/CalciteRexExpressionParser.java   |  17 ++-
 .../pinot/query/planner/logical/RexExpression.java |   5 +-
 .../serde/ProtoExpressionToRexExpression.java      |  58 ++++++++
 .../serde/RexExpressionToProtoExpression.java      |  25 ++++
 .../planner/serde/RexExpressionSerDeTest.java      | 165 +++++++++++++++++++++
 13 files changed, 457 insertions(+), 104 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java
index 927ab4eb69..d52c0091e0 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.common.request.context;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.util.Objects;
 import java.util.Set;
 import javax.annotation.Nullable;
@@ -51,6 +52,7 @@ public class ExpressionContext {
     return forLiteral(new LiteralContext(literal));
   }
 
+  @VisibleForTesting
   public static ExpressionContext forLiteral(DataType type, @Nullable Object 
value) {
     return forLiteral(new LiteralContext(type, value));
   }
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
index eb0667296f..0a2b8ad6e1 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java
@@ -18,8 +18,12 @@
  */
 package org.apache.pinot.common.request.context;
 
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import java.math.BigDecimal;
 import java.sql.Timestamp;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Objects;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.request.Literal;
@@ -52,101 +56,147 @@ public class LiteralContext {
   private String _stringValue;
   private byte[] _bytesValue;
 
-  public LiteralContext(DataType type, Object value) {
-    _type = type;
-    _value = value;
-    _pinotDataType = getPinotDataType(type);
-  }
-
   public LiteralContext(Literal literal) {
     switch (literal.getSetField()) {
+      case NULL_VALUE:
+        _type = DataType.UNKNOWN;
+        _value = null;
+        _pinotDataType = null;
+        break;
       case BOOL_VALUE:
         _type = DataType.BOOLEAN;
         _value = literal.getBoolValue();
+        _pinotDataType = PinotDataType.BOOLEAN;
         break;
       case INT_VALUE:
         _type = DataType.INT;
         _value = literal.getIntValue();
+        _pinotDataType = PinotDataType.INTEGER;
         break;
       case LONG_VALUE:
         _type = DataType.LONG;
         _value = literal.getLongValue();
+        _pinotDataType = PinotDataType.LONG;
         break;
       case FLOAT_VALUE:
         _type = DataType.FLOAT;
         _value = Float.intBitsToFloat(literal.getFloatValue());
+        _pinotDataType = PinotDataType.FLOAT;
         break;
       case DOUBLE_VALUE:
         _type = DataType.DOUBLE;
         _value = literal.getDoubleValue();
+        _pinotDataType = PinotDataType.DOUBLE;
         break;
       case BIG_DECIMAL_VALUE:
         _type = DataType.BIG_DECIMAL;
         _value = BigDecimalUtils.deserialize(literal.getBigDecimalValue());
+        _pinotDataType = PinotDataType.BIG_DECIMAL;
         break;
       case STRING_VALUE:
         _type = DataType.STRING;
         _value = literal.getStringValue();
+        _pinotDataType = PinotDataType.STRING;
         break;
       case BINARY_VALUE:
         _type = DataType.BYTES;
         _value = literal.getBinaryValue();
+        _pinotDataType = PinotDataType.BYTES;
         break;
-      // TODO: Revisit the type handling and whether we should convert value 
to primitive array for ARRAY types
-      case INT_ARRAY_VALUE:
+      case INT_ARRAY_VALUE: {
         _type = DataType.INT;
-        _value = literal.getIntArrayValue();
+        List<Integer> valueList = literal.getIntArrayValue();
+        int numValues = valueList.size();
+        int[] values = new int[numValues];
+        for (int i = 0; i < numValues; i++) {
+          values[i] = valueList.get(i);
+        }
+        _value = values;
+        _pinotDataType = PinotDataType.PRIMITIVE_INT_ARRAY;
         break;
-      case LONG_ARRAY_VALUE:
+      }
+      case LONG_ARRAY_VALUE: {
         _type = DataType.LONG;
-        _value = literal.getLongArrayValue();
+        List<Long> valueList = literal.getLongArrayValue();
+        int numValues = valueList.size();
+        long[] values = new long[numValues];
+        for (int i = 0; i < numValues; i++) {
+          values[i] = valueList.get(i);
+        }
+        _value = values;
+        _pinotDataType = PinotDataType.PRIMITIVE_LONG_ARRAY;
         break;
-      // TODO: Revisit the FLOAT_ARRAY handling. Currently the values are 
stored as int bits.
-      case FLOAT_ARRAY_VALUE:
+      }
+      case FLOAT_ARRAY_VALUE: {
         _type = DataType.FLOAT;
-        _value = literal.getFloatArrayValue();
+        List<Integer> valueList = literal.getFloatArrayValue();
+        int numValues = valueList.size();
+        float[] values = new float[numValues];
+        for (int i = 0; i < numValues; i++) {
+          values[i] = Float.intBitsToFloat(valueList.get(i));
+        }
+        _value = values;
+        _pinotDataType = PinotDataType.PRIMITIVE_FLOAT_ARRAY;
         break;
-      case DOUBLE_ARRAY_VALUE:
+      }
+      case DOUBLE_ARRAY_VALUE: {
         _type = DataType.DOUBLE;
-        _value = literal.getDoubleArrayValue();
+        List<Double> valueList = literal.getDoubleArrayValue();
+        int numValues = valueList.size();
+        double[] values = new double[numValues];
+        for (int i = 0; i < numValues; i++) {
+          values[i] = valueList.get(i);
+        }
+        _value = values;
+        _pinotDataType = PinotDataType.PRIMITIVE_DOUBLE_ARRAY;
         break;
+      }
       case STRING_ARRAY_VALUE:
         _type = DataType.STRING;
-        _value = literal.getStringArrayValue();
-        break;
-      case NULL_VALUE:
-        _type = DataType.UNKNOWN;
-        _value = null;
+        _value = literal.getStringArrayValue().toArray(new String[0]);
+        _pinotDataType = PinotDataType.STRING_ARRAY;
         break;
       default:
         throw new IllegalStateException("Unsupported field type: " + 
literal.getSetField());
     }
-    _pinotDataType = getPinotDataType(_type);
+  }
+
+  @VisibleForTesting
+  public LiteralContext(DataType type, @Nullable Object value) {
+    _type = type;
+    _value = value;
+    _pinotDataType = getPinotDataType(type, value);
   }
 
   @Nullable
-  private static PinotDataType getPinotDataType(DataType type) {
+  private static PinotDataType getPinotDataType(DataType type, @Nullable 
Object value) {
+    if (value == null) {
+      return null;
+    }
+    if (type == DataType.BYTES) {
+      Preconditions.checkState(value.getClass().getComponentType() == 
byte.class, "Bytes array is not supported");
+      return PinotDataType.BYTES;
+    }
+    boolean singleValue = !value.getClass().isArray();
     switch (type) {
       case BOOLEAN:
+        Preconditions.checkState(singleValue, "Boolean array is not 
supported");
         return PinotDataType.BOOLEAN;
       case INT:
-        return PinotDataType.INTEGER;
+        return singleValue ? PinotDataType.INTEGER : 
PinotDataType.PRIMITIVE_INT_ARRAY;
       case LONG:
-        return PinotDataType.LONG;
+        return singleValue ? PinotDataType.LONG : 
PinotDataType.PRIMITIVE_LONG_ARRAY;
       case FLOAT:
-        return PinotDataType.FLOAT;
+        return singleValue ? PinotDataType.FLOAT : 
PinotDataType.PRIMITIVE_FLOAT_ARRAY;
       case DOUBLE:
-        return PinotDataType.DOUBLE;
+        return singleValue ? PinotDataType.DOUBLE : 
PinotDataType.PRIMITIVE_DOUBLE_ARRAY;
       case BIG_DECIMAL:
+        Preconditions.checkState(singleValue, "BigDecimal array is not 
supported");
         return PinotDataType.BIG_DECIMAL;
       case STRING:
-        return PinotDataType.STRING;
-      case BYTES:
-        return PinotDataType.BYTES;
-      case UNKNOWN:
-        return null;
+        return singleValue ? PinotDataType.STRING : PinotDataType.STRING_ARRAY;
       default:
-        throw new IllegalStateException("Unsupported data type: " + type);
+        throw new IllegalStateException("Unsupported DataType: " + type);
     }
   }
 
@@ -159,6 +209,10 @@ public class LiteralContext {
     return _value;
   }
 
+  public boolean isSingleValue() {
+    return _pinotDataType == null || _pinotDataType.isSingleValue();
+  }
+
   public boolean getBooleanValue() {
     Boolean booleanValue = _booleanValue;
     if (booleanValue == null) {
@@ -281,8 +335,21 @@ public class LiteralContext {
     //       https://github.com/apache/pinot/pull/11762)
     if (isNull()) {
       return "'null'";
-    } else {
+    }
+    if (isSingleValue()) {
       return "'" + getStringValue() + "'";
     }
+    switch (_pinotDataType) {
+      case PRIMITIVE_INT_ARRAY:
+        return "'" + Arrays.toString((int[]) _value) + "'";
+      case PRIMITIVE_LONG_ARRAY:
+        return "'" + Arrays.toString((long[]) _value) + "'";
+      case PRIMITIVE_FLOAT_ARRAY:
+        return "'" + Arrays.toString((float[]) _value) + "'";
+      case PRIMITIVE_DOUBLE_ARRAY:
+        return "'" + Arrays.toString((double[]) _value) + "'";
+      default:
+        throw new IllegalStateException("Unsupported PinotDataType: " + 
_pinotDataType);
+    }
   }
 }
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
index 44a0931957..7cc8387731 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java
@@ -170,48 +170,48 @@ public class RequestUtils {
       return getNullLiteral();
     }
     if (object instanceof Boolean) {
-      return RequestUtils.getLiteral((boolean) object);
+      return getLiteral((boolean) object);
     }
     if (object instanceof Integer) {
-      return RequestUtils.getLiteral((int) object);
+      return getLiteral((int) object);
     }
     if (object instanceof Long) {
-      return RequestUtils.getLiteral((long) object);
+      return getLiteral((long) object);
     }
     if (object instanceof Float) {
-      return RequestUtils.getLiteral((float) object);
+      return getLiteral((float) object);
     }
     if (object instanceof Double) {
-      return RequestUtils.getLiteral((double) object);
+      return getLiteral((double) object);
     }
     if (object instanceof BigDecimal) {
-      return RequestUtils.getLiteral((BigDecimal) object);
+      return getLiteral((BigDecimal) object);
     }
     if (object instanceof Timestamp) {
-      return RequestUtils.getLiteral(((Timestamp) object).getTime());
+      return getLiteral(((Timestamp) object).getTime());
     }
     if (object instanceof String) {
-      return RequestUtils.getLiteral((String) object);
+      return getLiteral((String) object);
     }
     if (object instanceof byte[]) {
-      return RequestUtils.getLiteral((byte[]) object);
+      return getLiteral((byte[]) object);
     }
     if (object instanceof int[]) {
-      return RequestUtils.getLiteral((int[]) object);
+      return getLiteral((int[]) object);
     }
     if (object instanceof long[]) {
-      return RequestUtils.getLiteral((long[]) object);
+      return getLiteral((long[]) object);
     }
     if (object instanceof float[]) {
-      return RequestUtils.getLiteral((float[]) object);
+      return getLiteral((float[]) object);
     }
     if (object instanceof double[]) {
-      return RequestUtils.getLiteral((double[]) object);
+      return getLiteral((double[]) object);
     }
     if (object instanceof String[]) {
-      return RequestUtils.getLiteral((String[]) object);
+      return getLiteral((String[]) object);
     }
-    return RequestUtils.getLiteral(object.toString());
+    return getLiteral(object.toString());
   }
 
   public static Literal getLiteral(SqlLiteral node) {
diff --git a/pinot-common/src/main/proto/expressions.proto b/pinot-common/src/main/proto/expressions.proto
index ebc164a2ad..17cf4ac115 100644
--- a/pinot-common/src/main/proto/expressions.proto
+++ b/pinot-common/src/main/proto/expressions.proto
@@ -58,9 +58,34 @@ message Literal {
     double double = 6;
     string string = 7;
     bytes bytes = 8;
+    IntArray intArray = 9;
+    LongArray longArray = 10;
+    FloatArray floatArray = 11;
+    DoubleArray doubleArray = 12;
+    StringArray stringArray = 13;
   }
 }
 
+message IntArray {
+  repeated int32 values = 1;
+}
+
+message LongArray {
+  repeated int64 values = 1;
+}
+
+message FloatArray {
+  repeated float values = 1;
+}
+
+message DoubleArray {
+  repeated double values = 1;
+}
+
+message StringArray {
+  repeated string values = 1;
+}
+
 message FunctionCall {
   ColumnDataType dataType = 1;
   string functionName = 2;
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
index b2065e20d3..084d34bf2e 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
@@ -56,10 +56,9 @@ public class ArrayLiteralTransformFunction implements 
TransformFunction {
   private String[][] _stringArrayResult;
 
   public ArrayLiteralTransformFunction(LiteralContext literalContext) {
-    List literalArray = (List) literalContext.getValue();
-    Preconditions.checkNotNull(literalArray);
-    if (literalArray.isEmpty()) {
-      _dataType = DataType.UNKNOWN;
+    _dataType = literalContext.getType();
+    Object value = literalContext.getValue();
+    if (value == null) {
       _intArrayLiteral = new int[0];
       _longArrayLiteral = new long[0];
       _floatArrayLiteral = new float[0];
@@ -67,53 +66,37 @@ public class ArrayLiteralTransformFunction implements TransformFunction {
       _stringArrayLiteral = new String[0];
       return;
     }
-    _dataType = literalContext.getType();
     switch (_dataType) {
       case INT:
-        _intArrayLiteral = new int[literalArray.size()];
-        for (int i = 0; i < _intArrayLiteral.length; i++) {
-          _intArrayLiteral[i] = (int) literalArray.get(i);
-        }
+        _intArrayLiteral = (int[]) value;
         _longArrayLiteral = null;
         _floatArrayLiteral = null;
         _doubleArrayLiteral = null;
         _stringArrayLiteral = null;
         break;
       case LONG:
-        _longArrayLiteral = new long[literalArray.size()];
-        for (int i = 0; i < _longArrayLiteral.length; i++) {
-          _longArrayLiteral[i] = (long) literalArray.get(i);
-        }
+        _longArrayLiteral = (long[]) value;
         _intArrayLiteral = null;
         _floatArrayLiteral = null;
         _doubleArrayLiteral = null;
         _stringArrayLiteral = null;
         break;
       case FLOAT:
-        _floatArrayLiteral = new float[literalArray.size()];
-        for (int i = 0; i < _floatArrayLiteral.length; i++) {
-          _floatArrayLiteral[i] = (float) literalArray.get(i);
-        }
+        _floatArrayLiteral = (float[]) value;
         _intArrayLiteral = null;
         _longArrayLiteral = null;
         _doubleArrayLiteral = null;
         _stringArrayLiteral = null;
         break;
       case DOUBLE:
-        _doubleArrayLiteral = new double[literalArray.size()];
-        for (int i = 0; i < _doubleArrayLiteral.length; i++) {
-          _doubleArrayLiteral[i] = (double) literalArray.get(i);
-        }
+        _doubleArrayLiteral = (double[]) value;
         _intArrayLiteral = null;
         _longArrayLiteral = null;
         _floatArrayLiteral = null;
         _stringArrayLiteral = null;
         break;
       case STRING:
-        _stringArrayLiteral = new String[literalArray.size()];
-        for (int i = 0; i < _stringArrayLiteral.length; i++) {
-          _stringArrayLiteral[i] = (String) literalArray.get(i);
-        }
+        _stringArrayLiteral = (String[]) value;
         _intArrayLiteral = null;
         _longArrayLiteral = null;
         _floatArrayLiteral = null;
@@ -121,8 +104,8 @@ public class ArrayLiteralTransformFunction implements TransformFunction {
         break;
       default:
         throw new IllegalStateException(
-            "Illegal data type for ArrayLiteralTransformFunction: " + 
_dataType + ", literal contexts: "
-                + Arrays.toString(literalArray.toArray()));
+            "Illegal data type for ArrayLiteralTransformFunction: " + 
_dataType + ", literal context: "
+                + literalContext);
     }
   }
 
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index d5e4d9d481..de7668ca26 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -338,12 +338,13 @@ public class TransformFunctionFactory {
         return new IdentifierTransformFunction(columnName, 
columnContextMap.get(columnName));
       case LITERAL:
         LiteralContext literal = expression.getLiteral();
-        if (literal.getValue() != null && literal.getValue() instanceof 
ArrayList) {
+        if (literal.isSingleValue()) {
+          return 
queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, literal,
+              LiteralTransformFunction::new);
+        } else {
           return 
queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class, 
literal,
               ArrayLiteralTransformFunction::new);
         }
-        return 
queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, literal,
-            LiteralTransformFunction::new);
       default:
         throw new IllegalStateException();
     }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
index 078420bd60..77f5363fb8 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/HistogramAggregationFunction.java
@@ -31,6 +31,7 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import 
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
 import org.apache.pinot.core.query.aggregation.utils.DoubleVectorOpUtils;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.spi.utils.ArrayCopyUtils;
 
 
 /**
@@ -60,15 +61,15 @@ public class HistogramAggregationFunction extends BaseSingleInputAggregationFunc
       ExpressionContext arrayExpression = arguments.get(1);
       Preconditions.checkArgument(
           // ARRAY function
-          ((arrayExpression.getType() == ExpressionContext.Type.FUNCTION)
-              && 
(arrayExpression.getFunction().getFunctionName().equals(ARRAY_CONSTRUCTOR)))
-              || ((arrayExpression.getType() == ExpressionContext.Type.LITERAL)
-              && (arrayExpression.getLiteral().getValue() instanceof List)),
+          (arrayExpression.getType() == ExpressionContext.Type.FUNCTION && 
arrayExpression.getFunction()
+              .getFunctionName().equals(ARRAY_CONSTRUCTOR)) || (
+              arrayExpression.getType() == ExpressionContext.Type.LITERAL && 
!arrayExpression.getLiteral()
+                  .isSingleValue()),
           "Please use the format of `Histogram(columnName, ARRAY[1,10,100])` 
to specify the bin edges");
       if (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) {
         _bucketEdges = 
parseVector(arrayExpression.getFunction().getArguments());
       } else {
-        _bucketEdges = parseVectorLiteral((List) 
arrayExpression.getLiteral().getValue());
+        _bucketEdges = 
parseVectorLiteral(arrayExpression.getLiteral().getValue());
       }
       _lower = _bucketEdges[0];
       _upper = _bucketEdges[_bucketEdges.length - 1];
@@ -111,22 +112,35 @@ public class HistogramAggregationFunction extends BaseSingleInputAggregationFunc
         ret[i] = arrayStr.get(i).getLiteral().getDoubleValue();
       }
       if (i > 0) {
-        Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be 
strictly increasing");
+        Preconditions.checkArgument(ret[i] > ret[i - 1], "The bin edges must 
be strictly increasing");
       }
     }
     return ret;
   }
 
-  private double[] parseVectorLiteral(List arrayStr) {
-    int len = arrayStr.size();
-    Preconditions.checkArgument(len > 1, "The number of bin edges must be 
greater than 1");
-    double[] ret = new double[len];
-    for (int i = 0; i < len; i++) {
-      // TODO: Represent infinity as literal instead of identifier
-      ret[i] = Double.parseDouble(arrayStr.get(i).toString());
-      if (i > 0) {
-        Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be 
strictly increasing");
-      }
+  private double[] parseVectorLiteral(Object array) {
+    Preconditions.checkArgument(array != null, "The bin edges must not be 
null");
+    double[] ret;
+    if (array instanceof int[]) {
+      int[] intArray = (int[]) array;
+      ret = new double[intArray.length];
+      ArrayCopyUtils.copy(intArray, ret, intArray.length);
+    } else if (array instanceof long[]) {
+      long[] longArray = (long[]) array;
+      ret = new double[longArray.length];
+      ArrayCopyUtils.copy(longArray, ret, longArray.length);
+    } else if (array instanceof float[]) {
+      float[] floatArray = (float[]) array;
+      ret = new double[floatArray.length];
+      ArrayCopyUtils.copy(floatArray, ret, floatArray.length);
+    } else if (array instanceof double[]) {
+      ret = (double[]) array;
+    } else {
+      throw new IllegalArgumentException("Unsupported array type: " + 
array.getClass());
+    }
+    Preconditions.checkArgument(ret.length > 1, "The number of bin edges must 
be greater than 1");
+    for (int i = 1; i < ret.length; i++) {
+      Preconditions.checkArgument(ret[i] > ret[i - 1], "The bin edges must be 
strictly increasing");
     }
     return ret;
   }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java b/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java
index 19f9421a84..8edc7b4970 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java
@@ -41,6 +41,7 @@ import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.utils.config.QueryOptionsUtils;
+import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.common.ExplainPlanRowData;
 import org.apache.pinot.core.common.ExplainPlanRows;
 import org.apache.pinot.core.common.Operator;
@@ -72,7 +73,6 @@ import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.MutableSegment;
 import org.apache.pinot.segment.spi.SegmentContext;
 import org.apache.pinot.segment.spi.SegmentMetadata;
-import org.apache.pinot.spi.data.FieldSpec;
 import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.exception.BadQueryRequestException;
 import org.apache.pinot.spi.exception.QueryCancelledException;
@@ -236,8 +236,8 @@ public class ServerQueryExecutorV1Impl implements QueryExecutor {
           if (indexTimeMs > 0) {
             minIndexTimeMs = Math.min(minIndexTimeMs, indexTimeMs);
           }
-          long ingestionTimeMs = ((RealtimeTableDataManager)
-              
tableDataManager).getPartitionIngestionTimeMs(indexSegment.getSegmentName());
+          long ingestionTimeMs =
+              ((RealtimeTableDataManager) 
tableDataManager).getPartitionIngestionTimeMs(indexSegment.getSegmentName());
           if (ingestionTimeMs > 0) {
             minIngestionTimeMs = Math.min(minIngestionTimeMs, ingestionTimeMs);
           }
@@ -602,8 +602,7 @@ public class ServerQueryExecutorV1Impl implements QueryExecutor {
           result != null ? result.getClass().getSimpleName() : null);
       // Rewrite the expression
       function.setFunctionName(TransformFunctionType.IN_ID_SET.name());
-      arguments.set(1,
-          ExpressionContext.forLiteral(FieldSpec.DataType.STRING, ((IdSet) 
result).toBase64String()));
+      arguments.set(1, 
ExpressionContext.forLiteral(RequestUtils.getLiteral(((IdSet) 
result).toBase64String())));
     } else {
       for (ExpressionContext argument : arguments) {
         handleSubquery(argument, tableDataManager, indexSegments, 
timerContext, executorService, endTimeMs);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index 67992d4dfc..a20b2479d4 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -26,9 +26,12 @@ import org.apache.calcite.rel.RelFieldCollation.NullDirection;
 import org.apache.pinot.common.request.Expression;
 import org.apache.pinot.common.request.Literal;
 import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.plannode.SortNode;
+import org.apache.pinot.spi.utils.BooleanUtils;
+import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.sql.parsers.ParserUtils;
 
 
@@ -137,9 +140,19 @@ public class CalciteRexExpressionParser {
 
   public static Literal toLiteral(RexExpression.Literal literal) {
     Object value = literal.getValue();
+    if (value == null) {
+      return RequestUtils.getNullLiteral();
+    }
     // NOTE: Value is stored in internal format in RexExpression.Literal.
-    return value != null ? 
RequestUtils.getLiteral(literal.getDataType().toExternal(value))
-        : RequestUtils.getNullLiteral();
+    //       Do not convert TIMESTAMP/BOOLEAN_ARRAY/TIMESTAMP_ARRAY to 
external format because they are not explicitly
+    //       supported in single-stage engine Literal.
+    ColumnDataType dataType = literal.getDataType();
+    if (dataType == ColumnDataType.BOOLEAN) {
+      value = BooleanUtils.isTrueInternalValue(value);
+    } else if (dataType == ColumnDataType.BYTES) {
+      value = ((ByteArray) value).getBytes();
+    }
+    return RequestUtils.getLiteral(value);
   }
 
   private static Expression 
compileFunctionExpression(RexExpression.FunctionCall rexCall, PinotQuery 
pinotQuery) {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
index b81177877f..d06ee0473a 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.query.planner.logical;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import org.apache.calcite.rex.RexNode;
@@ -92,12 +93,12 @@ public interface RexExpression {
         return false;
       }
       Literal literal = (Literal) o;
-      return _dataType == literal._dataType && Objects.equals(_value, 
literal._value);
+      return _dataType == literal._dataType && Objects.deepEquals(_value, 
literal._value);
     }
 
     @Override
     public int hashCode() {
-      return Objects.hash(_dataType, _value);
+      return Arrays.deepHashCode(new Object[]{_dataType, _value});
     }
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java
index 206f9dcd2a..e197276d75 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java
@@ -23,6 +23,7 @@ import java.util.List;
 import org.apache.pinot.common.proto.Expressions;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.spi.utils.BigDecimalUtils;
 import org.apache.pinot.spi.utils.ByteArray;
 
 
@@ -74,10 +75,67 @@ public class ProtoExpressionToRexExpression {
         return new RexExpression.Literal(dataType, literal.getFloat());
       case DOUBLE:
         return new RexExpression.Literal(dataType, literal.getDouble());
+      case BIG_DECIMAL:
+        return new RexExpression.Literal(dataType, 
BigDecimalUtils.deserialize(literal.getBytes().toByteArray()));
       case STRING:
         return new RexExpression.Literal(dataType, literal.getString());
       case BYTES:
         return new RexExpression.Literal(dataType, new 
ByteArray(literal.getBytes().toByteArray()));
+      case INT_ARRAY: {
+        Expressions.IntArray intArray = literal.getIntArray();
+        int numValues = intArray.getValuesCount();
+        int[] values = new int[numValues];
+        {
+          for (int i = 0; i < numValues; i++) {
+            values[i] = intArray.getValues(i);
+          }
+        }
+        return new RexExpression.Literal(dataType, values);
+      }
+      case LONG_ARRAY: {
+        Expressions.LongArray longArray = literal.getLongArray();
+        int numValues = longArray.getValuesCount();
+        long[] values = new long[numValues];
+        {
+          for (int i = 0; i < numValues; i++) {
+            values[i] = longArray.getValues(i);
+          }
+        }
+        return new RexExpression.Literal(dataType, values);
+      }
+      case FLOAT_ARRAY: {
+        Expressions.FloatArray floatArray = literal.getFloatArray();
+        int numValues = floatArray.getValuesCount();
+        float[] values = new float[numValues];
+        {
+          for (int i = 0; i < numValues; i++) {
+            values[i] = floatArray.getValues(i);
+          }
+        }
+        return new RexExpression.Literal(dataType, values);
+      }
+      case DOUBLE_ARRAY: {
+        Expressions.DoubleArray doubleArray = literal.getDoubleArray();
+        int numValues = doubleArray.getValuesCount();
+        double[] values = new double[numValues];
+        {
+          for (int i = 0; i < numValues; i++) {
+            values[i] = doubleArray.getValues(i);
+          }
+        }
+        return new RexExpression.Literal(dataType, values);
+      }
+      case STRING_ARRAY: {
+        Expressions.StringArray stringArray = literal.getStringArray();
+        int numValues = stringArray.getValuesCount();
+        String[] values = new String[numValues];
+        {
+          for (int i = 0; i < numValues; i++) {
+            values[i] = stringArray.getValues(i);
+          }
+        }
+        return new RexExpression.Literal(dataType, values);
+      }
       default:
         throw new IllegalStateException("Unsupported ColumnDataType: " + 
dataType);
     }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java
index 0ff66c0c38..0350d8ba8c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java
@@ -19,8 +19,13 @@
 package org.apache.pinot.query.planner.serde;
 
 import com.google.protobuf.ByteString;
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
 import java.math.BigDecimal;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.pinot.common.proto.Expressions;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -94,6 +99,26 @@ public class RexExpressionToProtoExpression {
         case BYTES:
           literalBuilder.setBytes(ByteString.copyFrom(((ByteArray) 
value).getBytes()));
           break;
+        case INT_ARRAY:
+          literalBuilder.setIntArray(
+              
Expressions.IntArray.newBuilder().addAllValues(IntArrayList.wrap((int[]) 
value)).build());
+          break;
+        case LONG_ARRAY:
+          literalBuilder.setLongArray(
+              
Expressions.LongArray.newBuilder().addAllValues(LongArrayList.wrap((long[]) 
value)).build());
+          break;
+        case FLOAT_ARRAY:
+          literalBuilder.setFloatArray(
+              
Expressions.FloatArray.newBuilder().addAllValues(FloatArrayList.wrap((float[]) 
value)).build());
+          break;
+        case DOUBLE_ARRAY:
+          literalBuilder.setDoubleArray(
+              
Expressions.DoubleArray.newBuilder().addAllValues(DoubleArrayList.wrap((double[])
 value)).build());
+          break;
+        case STRING_ARRAY:
+          literalBuilder.setStringArray(
+              
Expressions.StringArray.newBuilder().addAllValues(Arrays.asList((String[]) 
value)).build());
+          break;
         default:
           throw new IllegalStateException("Unsupported ColumnDataType: " + 
dataType);
       }
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java
new file mode 100644
index 0000000000..b933f5c990
--- /dev/null
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java
@@ -0,0 +1,165 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.serde;
+
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.Random;
+import org.apache.commons.lang.RandomStringUtils;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.spi.utils.BooleanUtils;
+import org.apache.pinot.spi.utils.ByteArray;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+/**
+ * Round-trip tests for literal serialization: each case converts a {@link RexExpression.Literal} to its protobuf form
+ * with {@code RexExpressionToProtoExpression.convertLiteral}, converts it back with
+ * {@code ProtoExpressionToRexExpression.convertLiteral}, and asserts the result equals the original literal.
+ */
+public class RexExpressionSerDeTest {
+  // Every data type round-tripped with a null value in testNullLiteral().
+  private static final List<ColumnDataType> SUPPORTED_DATA_TYPES =
+      List.of(ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.FLOAT, ColumnDataType.DOUBLE,
+          ColumnDataType.BIG_DECIMAL, ColumnDataType.BOOLEAN, ColumnDataType.TIMESTAMP, ColumnDataType.STRING,
+          ColumnDataType.BYTES, ColumnDataType.INT_ARRAY, ColumnDataType.LONG_ARRAY, ColumnDataType.FLOAT_ARRAY,
+          ColumnDataType.DOUBLE_ARRAY, ColumnDataType.BOOLEAN_ARRAY, ColumnDataType.TIMESTAMP_ARRAY,
+          ColumnDataType.STRING_ARRAY, ColumnDataType.UNKNOWN);
+  private static final Random RANDOM = new Random();
+
+  @Test
+  public void testNullLiteral() {
+    for (ColumnDataType dataType : SUPPORTED_DATA_TYPES) {
+      verifyLiteralSerDe(new RexExpression.Literal(dataType, null));
+    }
+  }
+
+  @Test
+  public void testIntLiteral() {
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.INT, RANDOM.nextInt()));
+  }
+
+  @Test
+  public void testLongLiteral() {
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.LONG, RANDOM.nextLong()));
+  }
+
+  @Test
+  public void testFloatLiteral() {
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.FLOAT, RANDOM.nextFloat()));
+  }
+
+  @Test
+  public void testDoubleLiteral() {
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.DOUBLE, RANDOM.nextDouble()));
+  }
+
+  @Test
+  public void testBigDecimalLiteral() {
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BIG_DECIMAL,
+        RANDOM.nextBoolean() ? BigDecimal.valueOf(RANDOM.nextLong()) : BigDecimal.valueOf(RANDOM.nextDouble())));
+  }
+
+  @Test
+  public void testBooleanLiteral() {
+    // BOOLEAN literal values are represented as ints here (via BooleanUtils.toInt).
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BOOLEAN, BooleanUtils.toInt(RANDOM.nextBoolean())));
+  }
+
+  @Test
+  public void testTimestampLiteral() {
+    // TIMESTAMP literal values are represented as longs here.
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.TIMESTAMP, RANDOM.nextLong()));
+  }
+
+  @Test
+  public void testStringLiteral() {
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.STRING, RandomStringUtils.random(RANDOM.nextInt(10))));
+  }
+
+  @Test
+  public void testBytesLiteral() {
+    byte[] bytes = new byte[RANDOM.nextInt(10)];
+    RANDOM.nextBytes(bytes);
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BYTES, new ByteArray(bytes)));
+  }
+
+  @Test
+  public void testIntArrayLiteral() {
+    int[] values = new int[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = RANDOM.nextInt();
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.INT_ARRAY, values));
+  }
+
+  @Test
+  public void testLongArrayLiteral() {
+    long[] values = new long[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = RANDOM.nextLong();
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.LONG_ARRAY, values));
+  }
+
+  @Test
+  public void testFloatArrayLiteral() {
+    float[] values = new float[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = RANDOM.nextFloat();
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.FLOAT_ARRAY, values));
+  }
+
+  @Test
+  public void testDoubleArrayLiteral() {
+    double[] values = new double[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = RANDOM.nextDouble();
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.DOUBLE_ARRAY, values));
+  }
+
+  @Test
+  public void testBooleanArrayLiteral() {
+    // BOOLEAN_ARRAY literal values are represented as int[] here (via BooleanUtils.toInt).
+    int[] values = new int[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = BooleanUtils.toInt(RANDOM.nextBoolean());
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BOOLEAN_ARRAY, values));
+  }
+
+  @Test
+  public void testTimestampArrayLiteral() {
+    // TIMESTAMP_ARRAY literal values are represented as long[] here.
+    long[] values = new long[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = RANDOM.nextLong();
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.TIMESTAMP_ARRAY, values));
+  }
+
+  @Test
+  public void testStringArrayLiteral() {
+    String[] values = new String[RANDOM.nextInt(10)];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = RandomStringUtils.random(RANDOM.nextInt(10));
+    }
+    verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.STRING_ARRAY, values));
+  }
+
+  /**
+   * Serializes {@code literal} to protobuf, deserializes it back, and asserts the round-tripped literal equals the
+   * input. Array-valued literals compare element-wise because Literal.equals uses Objects.deepEquals on the value.
+   */
+  private void verifyLiteralSerDe(RexExpression.Literal literal) {
+    // TestNG's assertEquals takes (actual, expected); keep that order so failure messages are not inverted.
+    assertEquals(
+        ProtoExpressionToRexExpression.convertLiteral(RexExpressionToProtoExpression.convertLiteral(literal)),
+        literal);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to