This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 48490933de Support parsing ARRAY literal in multistage query engine 
(#11268)
48490933de is described below

commit 48490933ded6f20d94139eedbb57971fe0cbcb61
Author: Xiang Fu <xiangfu.1...@gmail.com>
AuthorDate: Mon Aug 7 13:10:35 2023 -0700

    Support parsing ARRAY literal in multistage query engine (#11268)
---
 .../common/function/TransformFunctionType.java     |   2 +
 .../org/apache/pinot/common/utils/DataSchema.java  |  54 +++-
 .../function/ArrayLiteralTransformFunction.java    | 291 +++++++++++++++++++++
 .../function/TransformFunctionFactory.java         |   8 +
 .../core/data/function/VectorFunctionsTest.java    |  19 ++
 .../ArrayLiteralTransformFunctionTest.java         | 167 ++++++++++++
 .../function/VectorTransformFunctionTest.java      |  12 +-
 .../integration/tests/VectorIntegrationTest.java   |  86 +++++-
 .../apache/calcite/sql/fun/PinotOperatorTable.java |   8 +-
 .../planner/logical/RelToPlanNodeConverter.java    |  21 +-
 .../local/function/InbuiltFunctionEvaluator.java   |  39 ++-
 11 files changed, 677 insertions(+), 30 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
 
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 471f6b128a..f741ff223e 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -221,6 +221,8 @@ public enum TransformFunctionType {
   VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
       OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), 
"vector_norm"),
 
+  ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor"),
+
   // Trigonometry
   SIN("sin"),
   COS("cos"),
diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java 
b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
index 354ba8cd3c..282a3d7416 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java
@@ -24,6 +24,8 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonPropertyOrder;
 import com.google.common.collect.Ordering;
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
 import it.unimi.dsi.fastutil.longs.LongArrayList;
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
@@ -358,11 +360,11 @@ public class DataSchema {
         case BYTES:
           return ((ByteArray) value).getBytes();
         case INT_ARRAY:
-          return (int[]) value;
+          return toIntArray(value);
         case LONG_ARRAY:
           return toLongArray(value);
         case FLOAT_ARRAY:
-          return (float[]) value;
+          return toFloatArray(value);
         case DOUBLE_ARRAY:
           return toDoubleArray(value);
         case STRING_ARRAY:
@@ -475,6 +477,46 @@ public class DataSchema {
       }
     }
 
+    /**
+     * Converts a FLOAT_ARRAY value to a {@code float[]}.
+     *
+     * <p>Accepts {@code float[]}, {@link FloatArrayList}, {@code int[]}, {@code long[]} and {@code double[]} values;
+     * any other non-null type fails with a {@link ClassCastException}. Converting from {@code int}, {@code long} or
+     * {@code double} may lose precision.
+     */
+    private static float[] toFloatArray(Object value) {
+      if (value instanceof float[]) {
+        return (float[]) value;
+      } else if (value instanceof FloatArrayList) {
+        // elements() exposes the backing array, which may be longer than size(); copy only the live elements.
+        return ((FloatArrayList) value).toFloatArray();
+      } else if (value instanceof int[]) {
+        int[] intValues = (int[]) value;
+        int length = intValues.length;
+        float[] floatValues = new float[length];
+        for (int i = 0; i < length; i++) {
+          floatValues[i] = intValues[i];
+        }
+        return floatValues;
+      } else if (value instanceof long[]) {
+        long[] longValues = (long[]) value;
+        int length = longValues.length;
+        float[] floatValues = new float[length];
+        for (int i = 0; i < length; i++) {
+          floatValues[i] = longValues[i];
+        }
+        return floatValues;
+      } else {
+        double[] doubleValues = (double[]) value;
+        int length = doubleValues.length;
+        float[] floatValues = new float[length];
+        for (int i = 0; i < length; i++) {
+          floatValues[i] = (float) doubleValues[i];
+        }
+        return floatValues;
+      }
+    }
+
     private static long[] toLongArray(Object value) {
       if (value instanceof long[]) {
         return (long[]) value;
@@ -491,6 +533,30 @@ public class DataSchema {
       }
     }
 
+    /**
+     * Converts an INT_ARRAY value to an {@code int[]}.
+     *
+     * <p>Accepts {@code int[]}, {@link IntArrayList} and {@code long[]} values; {@code long} elements are narrowed
+     * with a cast, which silently wraps values outside the {@code int} range. Any other non-null type fails with a
+     * {@link ClassCastException}.
+     */
+    private static int[] toIntArray(Object value) {
+      if (value instanceof int[]) {
+        return (int[]) value;
+      } else if (value instanceof IntArrayList) {
+        // elements() exposes the backing array, which may be longer than size(); copy only the live elements.
+        return ((IntArrayList) value).toIntArray();
+      } else {
+        long[] longValues = (long[]) value;
+        int length = longValues.length;
+        int[] intValues = new int[length];
+        for (int i = 0; i < length; i++) {
+          intValues[i] = (int) longValues[i];
+        }
+        return intValues;
+      }
+    }
+
     private static boolean[] toBooleanArray(Object value) {
       int[] ints = (int[]) value;
       boolean[] booleans = new boolean[ints.length];
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
new file mode 100644
index 0000000000..6208ee1966
--- /dev/null
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java
@@ -0,0 +1,439 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.base.Preconditions;
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.operator.ColumnContext;
+import org.apache.pinot.core.operator.blocks.ValueBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The <code>ArrayLiteralTransformFunction</code> class is a special transform function which wraps the LITERAL
+ * arguments of an {@code arrayValueConstructor} call (e.g. {@code ARRAY[1, 2, 3]}) and returns that array as the
+ * multi-value result of every document.
+ *
+ * <p>The element type is taken from the first literal. An empty array has type {@link DataType#UNKNOWN} and all of
+ * its rows are reported as null by {@link #getNullBitmap(ValueBlock)}. When an INT/LONG/FLOAT/DOUBLE/STRING
+ * multi-value result of another type is requested, the literal values are converted rather than returning nulls.
+ */
+public class ArrayLiteralTransformFunction implements TransformFunction {
+  public static final String FUNCTION_NAME = "arrayValueConstructor";
+
+  private final DataType _dataType;
+
+  // For a non-empty array only the field matching _dataType is non-null; for an empty array all fields are empty.
+  private final int[] _intArrayLiteral;
+  private final long[] _longArrayLiteral;
+  private final float[] _floatArrayLiteral;
+  private final double[] _doubleArrayLiteral;
+  private final String[] _stringArrayLiteral;
+
+  // Cached per-document results. Literals may be shared across threads, so the fields are volatile: a thread that
+  // observes a cached array is then guaranteed to observe its filled rows. Concurrent cache rebuilds are harmless.
+  private volatile int[][] _intArrayResult;
+  private volatile long[][] _longArrayResult;
+  private volatile float[][] _floatArrayResult;
+  private volatile double[][] _doubleArrayResult;
+  private volatile String[][] _stringArrayResult;
+
+  /**
+   * Builds the function from the arguments of an {@code arrayValueConstructor} call.
+   *
+   * @param literalContexts the array elements; each one must be a LITERAL expression
+   * @throws NullPointerException if {@code literalContexts} is null
+   * @throws IllegalStateException if an element is not a literal, or the first literal's type is not INT, LONG,
+   *     FLOAT, DOUBLE or STRING
+   */
+  public ArrayLiteralTransformFunction(List<ExpressionContext> literalContexts) {
+    Preconditions.checkNotNull(literalContexts);
+    if (literalContexts.isEmpty()) {
+      _dataType = DataType.UNKNOWN;
+      _intArrayLiteral = new int[0];
+      _longArrayLiteral = new long[0];
+      _floatArrayLiteral = new float[0];
+      _doubleArrayLiteral = new double[0];
+      _stringArrayLiteral = new String[0];
+      return;
+    }
+    for (ExpressionContext literalContext : literalContexts) {
+      Preconditions.checkState(literalContext.getType() == ExpressionContext.Type.LITERAL,
+          "ArrayLiteralTransformFunction only takes literals as arguments, found: %s", literalContext);
+    }
+    // NOTE(review): the element type is taken from the first literal only; the remaining literals are assumed to
+    // share it (e.g. after type coercion by the planner) - TODO confirm.
+    _dataType = literalContexts.get(0).getLiteral().getType();
+    switch (_dataType) {
+      case INT:
+        _intArrayLiteral = new int[literalContexts.size()];
+        for (int i = 0; i < _intArrayLiteral.length; i++) {
+          _intArrayLiteral[i] = literalContexts.get(i).getLiteral().getIntValue();
+        }
+        _longArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case LONG:
+        _longArrayLiteral = new long[literalContexts.size()];
+        for (int i = 0; i < _longArrayLiteral.length; i++) {
+          _longArrayLiteral[i] = Long.parseLong(literalContexts.get(i).getLiteral().getStringValue());
+        }
+        _intArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case FLOAT:
+        _floatArrayLiteral = new float[literalContexts.size()];
+        for (int i = 0; i < _floatArrayLiteral.length; i++) {
+          _floatArrayLiteral[i] = Float.parseFloat(literalContexts.get(i).getLiteral().getStringValue());
+        }
+        _intArrayLiteral = null;
+        _longArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case DOUBLE:
+        _doubleArrayLiteral = new double[literalContexts.size()];
+        for (int i = 0; i < _doubleArrayLiteral.length; i++) {
+          _doubleArrayLiteral[i] = Double.parseDouble(literalContexts.get(i).getLiteral().getStringValue());
+        }
+        _intArrayLiteral = null;
+        _longArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _stringArrayLiteral = null;
+        break;
+      case STRING:
+        _stringArrayLiteral = new String[literalContexts.size()];
+        for (int i = 0; i < _stringArrayLiteral.length; i++) {
+          _stringArrayLiteral[i] = literalContexts.get(i).getLiteral().getStringValue();
+        }
+        _intArrayLiteral = null;
+        _longArrayLiteral = null;
+        _floatArrayLiteral = null;
+        _doubleArrayLiteral = null;
+        break;
+      default:
+        throw new IllegalStateException(
+            "Illegal data type for ArrayLiteralTransformFunction: " + _dataType + ", literal contexts: "
+                + Arrays.toString(literalContexts.toArray()));
+    }
+  }
+
+  /** Returns the INT literal values; {@code null} unless the type is INT or the array is empty. */
+  public int[] getIntArrayLiteral() {
+    return _intArrayLiteral;
+  }
+
+  /** Returns the LONG literal values; {@code null} unless the type is LONG or the array is empty. */
+  public long[] getLongArrayLiteral() {
+    return _longArrayLiteral;
+  }
+
+  /** Returns the FLOAT literal values; {@code null} unless the type is FLOAT or the array is empty. */
+  public float[] getFloatArrayLiteral() {
+    return _floatArrayLiteral;
+  }
+
+  /** Returns the DOUBLE literal values; {@code null} unless the type is DOUBLE or the array is empty. */
+  public double[] getDoubleArrayLiteral() {
+    return _doubleArrayLiteral;
+  }
+
+  /** Returns the STRING literal values; {@code null} unless the type is STRING or the array is empty. */
+  public String[] getStringArrayLiteral() {
+    return _stringArrayLiteral;
+  }
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
+    // Nothing to do: all values come from the literals passed to the constructor.
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return new TransformResultMetadata(_dataType, false, false);
+  }
+
+  @Override
+  public Dictionary getDictionary() {
+    return null;
+  }
+
+  @Override
+  public int[] transformToDictIdsSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public int[][] transformToDictIdsMV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public int[] transformToIntValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public long[] transformToLongValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public float[] transformToFloatValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public double[] transformToDoubleValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public BigDecimal[] transformToBigDecimalValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String[] transformToStringValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte[][] transformToBytesValuesSV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public int[][] transformToIntValuesMV(ValueBlock valueBlock) {
+    int numDocs = valueBlock.getNumDocs();
+    int[][] intArrayResult = _intArrayResult;
+    if (intArrayResult == null || intArrayResult.length < numDocs) {
+      intArrayResult = new int[numDocs][];
+      Arrays.fill(intArrayResult, toIntValues());
+      _intArrayResult = intArrayResult;
+    }
+    return intArrayResult;
+  }
+
+  @Override
+  public long[][] transformToLongValuesMV(ValueBlock valueBlock) {
+    int numDocs = valueBlock.getNumDocs();
+    long[][] longArrayResult = _longArrayResult;
+    if (longArrayResult == null || longArrayResult.length < numDocs) {
+      longArrayResult = new long[numDocs][];
+      Arrays.fill(longArrayResult, toLongValues());
+      _longArrayResult = longArrayResult;
+    }
+    return longArrayResult;
+  }
+
+  @Override
+  public float[][] transformToFloatValuesMV(ValueBlock valueBlock) {
+    int numDocs = valueBlock.getNumDocs();
+    float[][] floatArrayResult = _floatArrayResult;
+    if (floatArrayResult == null || floatArrayResult.length < numDocs) {
+      floatArrayResult = new float[numDocs][];
+      Arrays.fill(floatArrayResult, toFloatValues());
+      _floatArrayResult = floatArrayResult;
+    }
+    return floatArrayResult;
+  }
+
+  @Override
+  public double[][] transformToDoubleValuesMV(ValueBlock valueBlock) {
+    int numDocs = valueBlock.getNumDocs();
+    double[][] doubleArrayResult = _doubleArrayResult;
+    if (doubleArrayResult == null || doubleArrayResult.length < numDocs) {
+      doubleArrayResult = new double[numDocs][];
+      Arrays.fill(doubleArrayResult, toDoubleValues());
+      _doubleArrayResult = doubleArrayResult;
+    }
+    return doubleArrayResult;
+  }
+
+  @Override
+  public String[][] transformToStringValuesMV(ValueBlock valueBlock) {
+    int numDocs = valueBlock.getNumDocs();
+    String[][] stringArrayResult = _stringArrayResult;
+    if (stringArrayResult == null || stringArrayResult.length < numDocs) {
+      stringArrayResult = new String[numDocs][];
+      Arrays.fill(stringArrayResult, toStringValues());
+      _stringArrayResult = stringArrayResult;
+    }
+    return stringArrayResult;
+  }
+
+  @Override
+  public byte[][][] transformToBytesValuesMV(ValueBlock valueBlock) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public RoaringBitmap getNullBitmap(ValueBlock valueBlock) {
+    // Treat all unknown type values as null regardless of the value.
+    if (_dataType != DataType.UNKNOWN) {
+      return null;
+    }
+    int length = valueBlock.getNumDocs();
+    RoaringBitmap bitmap = new RoaringBitmap();
+    bitmap.add(0L, length);
+    return bitmap;
+  }
+
+  /** Returns the number of elements in the array literal. */
+  private int numValues() {
+    switch (_dataType) {
+      case INT:
+        return _intArrayLiteral.length;
+      case LONG:
+        return _longArrayLiteral.length;
+      case FLOAT:
+        return _floatArrayLiteral.length;
+      case DOUBLE:
+        return _doubleArrayLiteral.length;
+      default:
+        // STRING, or UNKNOWN (empty array, where every literal field is an empty array)
+        return _stringArrayLiteral.length;
+    }
+  }
+
+  /** Returns element {@code index} as a long; FLOAT/DOUBLE values are truncated and STRING values are parsed. */
+  private long longValueAt(int index) {
+    switch (_dataType) {
+      case INT:
+        return _intArrayLiteral[index];
+      case LONG:
+        return _longArrayLiteral[index];
+      case FLOAT:
+        return (long) _floatArrayLiteral[index];
+      case DOUBLE:
+        return (long) _doubleArrayLiteral[index];
+      default:
+        return Long.parseLong(_stringArrayLiteral[index]);
+    }
+  }
+
+  /** Returns element {@code index} as a double; STRING values are parsed. */
+  private double doubleValueAt(int index) {
+    switch (_dataType) {
+      case INT:
+        return _intArrayLiteral[index];
+      case LONG:
+        return _longArrayLiteral[index];
+      case FLOAT:
+        return _floatArrayLiteral[index];
+      case DOUBLE:
+        return _doubleArrayLiteral[index];
+      default:
+        return Double.parseDouble(_stringArrayLiteral[index]);
+    }
+  }
+
+  /** Returns element {@code index} as a String. */
+  private String stringValueAt(int index) {
+    switch (_dataType) {
+      case INT:
+        return Integer.toString(_intArrayLiteral[index]);
+      case LONG:
+        return Long.toString(_longArrayLiteral[index]);
+      case FLOAT:
+        return Float.toString(_floatArrayLiteral[index]);
+      case DOUBLE:
+        return Double.toString(_doubleArrayLiteral[index]);
+      default:
+        return _stringArrayLiteral[index];
+    }
+  }
+
+  /** Returns the literal as INT values, converting (and narrowing) from the stored type when needed. */
+  private int[] toIntValues() {
+    if (_intArrayLiteral != null) {
+      return _intArrayLiteral;
+    }
+    int[] values = new int[numValues()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = (int) longValueAt(i);
+    }
+    return values;
+  }
+
+  /** Returns the literal as LONG values, converting from the stored type when needed. */
+  private long[] toLongValues() {
+    if (_longArrayLiteral != null) {
+      return _longArrayLiteral;
+    }
+    long[] values = new long[numValues()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = longValueAt(i);
+    }
+    return values;
+  }
+
+  /** Returns the literal as FLOAT values, converting (possibly losing precision) from the stored type. */
+  private float[] toFloatValues() {
+    if (_floatArrayLiteral != null) {
+      return _floatArrayLiteral;
+    }
+    float[] values = new float[numValues()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = (float) doubleValueAt(i);
+    }
+    return values;
+  }
+
+  /** Returns the literal as DOUBLE values, converting from the stored type when needed. */
+  private double[] toDoubleValues() {
+    if (_doubleArrayLiteral != null) {
+      return _doubleArrayLiteral;
+    }
+    double[] values = new double[numValues()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = doubleValueAt(i);
+    }
+    return values;
+  }
+
+  /** Returns the literal as STRING values, converting from the stored type when needed. */
+  private String[] toStringValues() {
+    if (_stringArrayLiteral != null) {
+      return _stringArrayLiteral;
+    }
+    String[] values = new String[numValues()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = stringValueAt(i);
+    }
+    return values;
+  }
+}
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 4e3ff24119..be2b54d128 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -153,6 +153,7 @@ public class TransformFunctionFactory {
     typeToImplementation.put(TransformFunctionType.ARRAYMAX, 
ArrayMaxTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.ARRAYMIN, 
ArrayMinTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.ARRAYSUM, 
ArraySumTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.ARRAY_VALUE_CONSTRUCTOR, 
ArrayLiteralTransformFunction.class);
 
     typeToImplementation.put(TransformFunctionType.GROOVY, 
GroovyTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.CASE, 
CaseTransformFunction.class);
@@ -281,6 +282,13 @@ public class TransformFunctionFactory {
         List<ExpressionContext> arguments = function.getArguments();
         int numArguments = arguments.size();
 
+        // Check if the function is ArrayLiteraltransform function
+        if 
(functionName.equalsIgnoreCase(ArrayLiteralTransformFunction.FUNCTION_NAME)) {
+          return 
queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class,
+              expression.getFunction().getArguments(),
+              ArrayLiteralTransformFunction::new);
+        }
+
         TransformFunction transformFunction;
         Class<? extends TransformFunction> transformFunctionClass = 
TRANSFORM_FUNCTION_MAP.get(functionName);
         if (transformFunctionClass != null) {
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
index 972c33ee43..6600b5c10f 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java
@@ -108,6 +108,25 @@ public class VectorFunctionsTest {
     inputs.add(new Object[]{"vectorDims(vector2)", 
Lists.newArrayList("vector2"), row, 5});
     inputs.add(new Object[]{"vectorNorm(vector1)", 
Lists.newArrayList("vector1"), row, 0.741619857751291});
     inputs.add(new Object[]{"vectorNorm(vector2)", 
Lists.newArrayList("vector2"), row, 0.0});
+
+    inputs.add(new Object[]{
+        "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", 
Lists.newArrayList("vector1"), row, Double.NaN
+    });
+    inputs.add(new Object[]{
+        "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 0.0)", 
Lists.newArrayList("vector1"), row, 0.0
+    });
+    inputs.add(new Object[]{
+        "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 1.0)", 
Lists.newArrayList("vector1"), row, 1.0
+    });
+    inputs.add(new Object[]{
+        "innerProduct(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", 
Lists.newArrayList("vector1"), row, 0.0
+    });
+    inputs.add(new Object[]{
+        "l2Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", 
Lists.newArrayList("vector1"), row, 0.741619857751291
+    });
+    inputs.add(new Object[]{
+        "l1Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", 
Lists.newArrayList("vector1"), row, 1.5000000223517418
+    });
     return inputs.toArray(new Object[0][]);
   }
 }
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java
new file mode 100644
index 0000000000..005b8c3eeb
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java
@@ -0,0 +1,205 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.pinot.common.request.Literal;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testng.Assert;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import static org.mockito.Mockito.when;
+
+
+/**
+ * Tests for {@link ArrayLiteralTransformFunction}: element parsing per literal type, and the multi-value rows
+ * returned for a mocked block of {@code NUM_DOCS} documents.
+ */
+public class ArrayLiteralTransformFunctionTest {
+  private static final int NUM_DOCS = 100;
+  private AutoCloseable _mocks;
+
+  @Mock
+  private ProjectionBlock _projectionBlock;
+
+  @BeforeMethod
+  public void setUp() {
+    _mocks = MockitoAnnotations.openMocks(this);
+    when(_projectionBlock.getNumDocs()).thenReturn(NUM_DOCS);
+  }
+
+  @AfterMethod
+  public void tearDown()
+      throws Exception {
+    _mocks.close();
+  }
+
+  @Test
+  public void testIntArrayLiteralTransformFunction() {
+    List<ExpressionContext> arrayExpressions = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.INT, i));
+    }
+
+    ArrayLiteralTransformFunction intArray = new ArrayLiteralTransformFunction(arrayExpressions);
+    Assert.assertEquals(intArray.getResultMetadata().getDataType(), DataType.INT);
+    int[] expected = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+    Assert.assertEquals(intArray.getIntArrayLiteral(), expected);
+
+    // Every document of the block gets the same literal array
+    int[][] intValues = intArray.transformToIntValuesMV(_projectionBlock);
+    Assert.assertEquals(intValues.length, NUM_DOCS);
+    for (int[] row : intValues) {
+      Assert.assertEquals(row, expected);
+    }
+  }
+
+  @Test
+  public void testLongArrayLiteralTransformFunction() {
+    List<ExpressionContext> arrayExpressions = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.LONG, (long) i));
+    }
+
+    ArrayLiteralTransformFunction longArray = new ArrayLiteralTransformFunction(arrayExpressions);
+    Assert.assertEquals(longArray.getResultMetadata().getDataType(), DataType.LONG);
+    long[] expected = new long[]{0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L};
+    Assert.assertEquals(longArray.getLongArrayLiteral(), expected);
+
+    // Every document of the block gets the same literal array
+    long[][] longValues = longArray.transformToLongValuesMV(_projectionBlock);
+    Assert.assertEquals(longValues.length, NUM_DOCS);
+    for (long[] row : longValues) {
+      Assert.assertEquals(row, expected);
+    }
+  }
+
+  @Test
+  public void testFloatArrayLiteralTransformFunction() {
+    List<ExpressionContext> arrayExpressions = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.FLOAT, (double) i));
+    }
+
+    ArrayLiteralTransformFunction floatArray = new ArrayLiteralTransformFunction(arrayExpressions);
+    Assert.assertEquals(floatArray.getResultMetadata().getDataType(), DataType.FLOAT);
+    float[] expected = new float[]{0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
+    Assert.assertEquals(floatArray.getFloatArrayLiteral(), expected);
+
+    // Every document of the block gets the same literal array
+    float[][] floatValues = floatArray.transformToFloatValuesMV(_projectionBlock);
+    Assert.assertEquals(floatValues.length, NUM_DOCS);
+    for (float[] row : floatValues) {
+      Assert.assertEquals(row, expected);
+    }
+  }
+
+  @Test
+  public void testDoubleArrayLiteralTransformFunction() {
+    List<ExpressionContext> arrayExpressions = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.DOUBLE, (double) i));
+    }
+
+    ArrayLiteralTransformFunction doubleArray = new ArrayLiteralTransformFunction(arrayExpressions);
+    Assert.assertEquals(doubleArray.getResultMetadata().getDataType(), DataType.DOUBLE);
+    double[] expected = new double[]{0d, 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d};
+    Assert.assertEquals(doubleArray.getDoubleArrayLiteral(), expected);
+
+    // Every document of the block gets the same literal array
+    double[][] doubleValues = doubleArray.transformToDoubleValuesMV(_projectionBlock);
+    Assert.assertEquals(doubleValues.length, NUM_DOCS);
+    for (double[] row : doubleValues) {
+      Assert.assertEquals(row, expected);
+    }
+  }
+
+  @Test
+  public void testStringArrayLiteralTransformFunction() {
+    List<ExpressionContext> arrayExpressions = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      arrayExpressions.add(
+          ExpressionContext.forLiteralContext(new Literal(Literal._Fields.STRING_VALUE, String.valueOf(i))));
+    }
+
+    ArrayLiteralTransformFunction stringArray = new ArrayLiteralTransformFunction(arrayExpressions);
+    Assert.assertEquals(stringArray.getResultMetadata().getDataType(), DataType.STRING);
+    String[] expected = new String[]{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};
+    Assert.assertEquals(stringArray.getStringArrayLiteral(), expected);
+
+    // Every document of the block gets the same literal array
+    String[][] stringValues = stringArray.transformToStringValuesMV(_projectionBlock);
+    Assert.assertEquals(stringValues.length, NUM_DOCS);
+    for (String[] row : stringValues) {
+      Assert.assertEquals(row, expected);
+    }
+  }
+
+  @Test
+  public void testEmptyArrayTransform() {
+    List<ExpressionContext> arrayExpressions = new ArrayList<>();
+    ArrayLiteralTransformFunction emptyLiteral = new ArrayLiteralTransformFunction(arrayExpressions);
+    Assert.assertEquals(emptyLiteral.getIntArrayLiteral(), new int[0]);
+    Assert.assertEquals(emptyLiteral.getLongArrayLiteral(), new long[0]);
+    Assert.assertEquals(emptyLiteral.getFloatArrayLiteral(), new float[0]);
+    Assert.assertEquals(emptyLiteral.getDoubleArrayLiteral(), new double[0]);
+    Assert.assertEquals(emptyLiteral.getStringArrayLiteral(), new String[0]);
+
+    int[][] ints = emptyLiteral.transformToIntValuesMV(_projectionBlock);
+    Assert.assertEquals(ints.length, NUM_DOCS);
+    for (int i = 0; i < NUM_DOCS; i++) {
+      Assert.assertEquals(ints[i].length, 0);
+    }
+
+    long[][] longs = emptyLiteral.transformToLongValuesMV(_projectionBlock);
+    Assert.assertEquals(longs.length, NUM_DOCS);
+    for (int i = 0; i < NUM_DOCS; i++) {
+      Assert.assertEquals(longs[i].length, 0);
+    }
+
+    float[][] floats = emptyLiteral.transformToFloatValuesMV(_projectionBlock);
+    Assert.assertEquals(floats.length, NUM_DOCS);
+    for (int i = 0; i < NUM_DOCS; i++) {
+      Assert.assertEquals(floats[i].length, 0);
+    }
+
+    double[][] doubles = emptyLiteral.transformToDoubleValuesMV(_projectionBlock);
+    Assert.assertEquals(doubles.length, NUM_DOCS);
+    for (int i = 0; i < NUM_DOCS; i++) {
+      Assert.assertEquals(doubles[i].length, 0);
+    }
+
+    String[][] strings = emptyLiteral.transformToStringValuesMV(_projectionBlock);
+    Assert.assertEquals(strings.length, NUM_DOCS);
+    for (int i = 0; i < NUM_DOCS; i++) {
+      Assert.assertEquals(strings[i].length, 0);
+    }
+
+    // An empty array has UNKNOWN type, so every row of the block is reported as null
+    Assert.assertEquals(emptyLiteral.getResultMetadata().getDataType(), DataType.UNKNOWN);
+    Assert.assertEquals(emptyLiteral.getNullBitmap(_projectionBlock).getCardinality(), NUM_DOCS);
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
index 8aed6e4698..23b3213f3e 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java
@@ -59,6 +59,9 @@ public class VectorTransformFunctionTest extends 
BaseTransformFunctionTest {
 
   @DataProvider(name = "testVectorTransformFunctionDataProvider")
   public Object[][] testVectorTransformFunctionDataProvider() {
+    String zeroVectorLiteral = "ARRAY[0.0"
+        + ",0.0".repeat(VECTOR_DIM_SIZE - 1)
+        + "]";  // VECTOR_DIM_SIZE zero elements, e.g. ARRAY[0.0,0.0,0.0] when VECTOR_DIM_SIZE is 3
     return new Object[][]{
         new Object[]{"cosineDistance(vector1, vector2)", 0.1, 0.4},
         new Object[]{"cosineDistance(vector1, vector2, 0)", 0.1, 0.4},
@@ -67,7 +70,14 @@ public class VectorTransformFunctionTest extends 
BaseTransformFunctionTest {
         new Object[]{"l1Distance(vector1, vector2)", 140, 210},
         new Object[]{"l2Distance(vector1, vector2)", 8, 11},
         new Object[]{"vectorNorm(vector1)", 10, 16},
-        new Object[]{"vectorNorm(vector2)", 10, 16}
+        new Object[]{"vectorNorm(vector2)", 10, 16},
+        // Vector functions applied to the zeroVectorLiteral ARRAY literal instead of a vector column
+        new Object[]{String.format("cosineDistance(vector1, %s, 0)", 
zeroVectorLiteral), 0.0, 0.0},
+        new Object[]{String.format("innerProduct(vector1, %s)", 
zeroVectorLiteral), 0.0, 0.0},
+        new Object[]{String.format("l1Distance(vector1, %s)", 
zeroVectorLiteral), 0, VECTOR_DIM_SIZE},
+        new Object[]{String.format("l2Distance(vector1, %s)", 
zeroVectorLiteral), 0, VECTOR_DIM_SIZE},
+        new Object[]{String.format("vectorDims(%s)", zeroVectorLiteral), 
VECTOR_DIM_SIZE, VECTOR_DIM_SIZE},
+        new Object[]{String.format("vectorNorm(%s)", zeroVectorLiteral), 0.0, 
0.0},
     };
   }
 }
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
index 48efe20490..dbfcd5a347 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java
@@ -100,32 +100,96 @@ public class VectorIntegrationTest extends 
BaseClusterIntegrationTest {
             + "vectorNorm(vector1), vectorNorm(vector2), "
             + "cosineDistance(vector1, zeroVector), "
             + "cosineDistance(vector1, zeroVector, 0) "
-            + "FROM %s", DEFAULT_TABLE_NAME);
+            + "FROM %s LIMIT %d", DEFAULT_TABLE_NAME, getCountStarResult());
     JsonNode jsonNode = postQuery(query);
     for (int i = 0; i < getCountStarResult(); i++) {
-      double cosineDistance = 
jsonNode.get("resultTable").get("rows").get(0).get(0).asDouble();
+      double cosineDistance = 
jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble();
       assertTrue(cosineDistance > 0.1 && cosineDistance < 0.4);
-      double innerProduce = 
jsonNode.get("resultTable").get("rows").get(0).get(1).asDouble();
+      double innerProduce = 
jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
       assertTrue(innerProduce > 100 && innerProduce < 160);
-      double l1Distance = 
jsonNode.get("resultTable").get("rows").get(0).get(2).asDouble();
+      double l1Distance = 
jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
       assertTrue(l1Distance > 140 && l1Distance < 210);
-      double l2Distance = 
jsonNode.get("resultTable").get("rows").get(0).get(3).asDouble();
+      double l2Distance = 
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
       assertTrue(l2Distance > 8 && l2Distance < 11);
-      int vectorDimsVector1 = 
jsonNode.get("resultTable").get("rows").get(0).get(4).asInt();
+      int vectorDimsVector1 = 
jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
       assertEquals(vectorDimsVector1, VECTOR_DIM_SIZE);
-      int vectorDimsVector2 = 
jsonNode.get("resultTable").get("rows").get(0).get(5).asInt();
+      int vectorDimsVector2 = 
jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
       assertEquals(vectorDimsVector2, VECTOR_DIM_SIZE);
-      double vectorNormVector1 = 
jsonNode.get("resultTable").get("rows").get(0).get(6).asInt();
+      double vectorNormVector1 = 
jsonNode.get("resultTable").get("rows").get(i).get(6).asInt();
       assertTrue(vectorNormVector1 > 10 && vectorNormVector1 < 16);
-      double vectorNormVector2 = 
jsonNode.get("resultTable").get("rows").get(0).get(7).asInt();
+      double vectorNormVector2 = 
jsonNode.get("resultTable").get("rows").get(i).get(7).asInt();
       assertTrue(vectorNormVector2 > 10 && vectorNormVector2 < 16);
-      cosineDistance = 
jsonNode.get("resultTable").get("rows").get(0).get(8).asDouble();
+      cosineDistance = 
jsonNode.get("resultTable").get("rows").get(i).get(8).asDouble();
       assertEquals(cosineDistance, Double.NaN);
-      cosineDistance = 
jsonNode.get("resultTable").get("rows").get(0).get(9).asDouble();
+      cosineDistance = 
jsonNode.get("resultTable").get("rows").get(i).get(9).asDouble();
       assertEquals(cosineDistance, 0.0);
     }
   }
 
+  /**
+   * Verifies vector functions when one or both operands are ARRAY literals, on both query engines.
+   * Expected values are derived from {@code VECTOR_DIM_SIZE} so the test keeps working if the dimension changes.
+   */
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testQueriesWithLiterals(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    // ARRAY literals with VECTOR_DIM_SIZE elements, e.g. ARRAY[0.0, 0.0, ..., 0.0].
+    String zeroVectorStringLiteral = "ARRAY[0.0"
+        + ", 0.0".repeat(VECTOR_DIM_SIZE - 1)
+        + "]";
+    String oneVectorStringLiteral = "ARRAY[1.0"
+        + ", 1.0".repeat(VECTOR_DIM_SIZE - 1)
+        + "]";
+
+    // Column vs. literal: every row's vector1 is compared against the all-zero literal.
+    String query =
+        String.format("SELECT "
+                + "cosineDistance(vector1, %s), "
+                + "innerProduct(vector1, %s), "
+                + "l1Distance(vector1, %s), "
+                + "l2Distance(vector1, %s), "
+                + "vectorDims(%s), "
+                + "vectorNorm(%s) "
+                + "FROM %s LIMIT %d",
+            zeroVectorStringLiteral, zeroVectorStringLiteral, zeroVectorStringLiteral, zeroVectorStringLiteral,
+            zeroVectorStringLiteral, zeroVectorStringLiteral, DEFAULT_TABLE_NAME, getCountStarResult());
+    JsonNode rows = postQuery(query).get("resultTable").get("rows");
+    for (int i = 0; i < getCountStarResult(); i++) {
+      JsonNode row = rows.get(i);
+      // Without an explicit default, cosine distance against a zero vector is expected to be NaN.
+      assertEquals(row.get(0).asDouble(), Double.NaN);
+      double innerProduct = row.get(1).asDouble();
+      assertEquals(innerProduct, 0.0);
+      double l1Distance = row.get(2).asDouble();
+      assertTrue(l1Distance > 100 && l1Distance < 300);
+      double l2Distance = row.get(3).asDouble();
+      assertTrue(l2Distance > 10 && l2Distance < 16);
+      int vectorDims = row.get(4).asInt();
+      assertEquals(vectorDims, VECTOR_DIM_SIZE);
+      // vectorNorm yields a double; read it as a double rather than truncating through asInt().
+      double vectorNorm = row.get(5).asDouble();
+      assertEquals(vectorNorm, 0.0);
+    }
+
+    // Literal vs. literal: the all-zero vector against the all-ones vector.
+    // Note the trailing space after the last projection so the generated SQL reads "... ) FROM ...".
+    query =
+        String.format("SELECT "
+                + "cosineDistance(%s, %s), "
+                + "cosineDistance(%s, %s, 0.0), "
+                + "innerProduct(%s, %s), "
+                + "l1Distance(%s, %s), "
+                + "l2Distance(%s, %s) "
+                + "FROM %s LIMIT 1",
+            zeroVectorStringLiteral, oneVectorStringLiteral,
+            zeroVectorStringLiteral, oneVectorStringLiteral,
+            zeroVectorStringLiteral, oneVectorStringLiteral,
+            zeroVectorStringLiteral, oneVectorStringLiteral,
+            zeroVectorStringLiteral, oneVectorStringLiteral,
+            DEFAULT_TABLE_NAME);
+    JsonNode firstRow = postQuery(query).get("resultTable").get("rows").get(0);
+    assertEquals(firstRow.get(0).asDouble(), Double.NaN);
+    // With the third argument supplied, the zero-norm case yields that value (0.0) instead of NaN.
+    assertEquals(firstRow.get(1).asDouble(), 0.0);
+    assertEquals(firstRow.get(2).asDouble(), 0.0);
+    // Sum of |0 - 1| over VECTOR_DIM_SIZE elements.
+    assertEquals(firstRow.get(3).asDouble(), (double) VECTOR_DIM_SIZE);
+    // sqrt of the sum of (0 - 1)^2 over VECTOR_DIM_SIZE elements; compared with a small tolerance.
+    assertEquals(firstRow.get(4).asDouble(), Math.sqrt(VECTOR_DIM_SIZE), 1e-9);
+  }
+
   private File createAvroFile(long totalNumRecords)
       throws IOException {
 
diff --git 
a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
 
b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
index 2ee178419f..1a63a6eb07 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
@@ -49,8 +49,6 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
 
   private static @MonotonicNonNull PinotOperatorTable _instance;
 
-  public static final SqlFunction COALESCE = new PinotSqlCoalesceFunction();
-
   // TODO: clean up lazy init by using 
Suppliers.memorized(this::computeInstance) and make getter wrapped around
   // supplier instance. this should replace all lazy init static objects in 
the codebase
   public static synchronized PinotOperatorTable instance() {
@@ -75,6 +73,12 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
    * which are multistage enabled.
    */
   public final void initNoDuplicate() {
+    // Pinot supports native COALESCE function, thus no need to create CASE 
WHEN conversion.
+    register(new PinotSqlCoalesceFunction());
+    // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor
+    register(ARRAY_VALUE_CONSTRUCTOR);
+
+    // TODO: reflection based registration is not ideal, we should use a 
static list of operators and register them
     // Use reflection to register the expressions stored in public fields.
     for (Field field : getClass().getFields()) {
       try {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
index a4e6be355a..b0b7545677 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
@@ -219,7 +219,7 @@ public final class RelToPlanNodeConverter {
       case BIGINT:
         return isArray ? DataSchema.ColumnDataType.LONG_ARRAY : 
DataSchema.ColumnDataType.LONG;
       case DECIMAL:
-        return resolveDecimal(relDataType);
+        return resolveDecimal(relDataType, isArray);
       case FLOAT:
       case REAL:
         return isArray ? DataSchema.ColumnDataType.FLOAT_ARRAY : 
DataSchema.ColumnDataType.FLOAT;
@@ -259,31 +259,32 @@ public final class RelToPlanNodeConverter {
   }
 
   /**
-   * Calcite uses DEMICAL type to infer data type hoisting and infer 
arithmetic result types. down casting this
-   * back to the proper primitive type for Pinot.
+   * Calcite uses DECIMAL type to infer data type hoisting and infer arithmetic result types. Down casting this back to
+   * the proper primitive type for Pinot.
    *
    * @param relDataType the DECIMAL rel data type.
+   * @param isArray whether to resolve to the array variant of the column type (e.g. {@code INT_ARRAY} instead of
+   *     {@code INT}). No BIG_DECIMAL array type is produced: large-precision DECIMAL arrays map to
+   *     {@code DOUBLE_ARRAY}, which may lose precision.
    * @return proper {@link DataSchema.ColumnDataType}.
    * @see {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#decimalOf}.
    */
-  private static DataSchema.ColumnDataType resolveDecimal(RelDataType 
relDataType) {
+  private static DataSchema.ColumnDataType resolveDecimal(RelDataType 
relDataType, boolean isArray) {
     int precision = relDataType.getPrecision();
     int scale = relDataType.getScale();
     if (scale == 0) {
+      // NOTE(review): Integer.MAX_VALUE has 10 digits and Long.MAX_VALUE has 19, so values at the upper end of the
+      // precision ranges below may not fit the chosen type; confirm these thresholds are intended.
       if (precision <= 10) {
-        return DataSchema.ColumnDataType.INT;
+        return isArray ? DataSchema.ColumnDataType.INT_ARRAY : 
DataSchema.ColumnDataType.INT;
       } else if (precision <= 38) {
-        return DataSchema.ColumnDataType.LONG;
+        return isArray ? DataSchema.ColumnDataType.LONG_ARRAY : 
DataSchema.ColumnDataType.LONG;
       } else {
-        return DataSchema.ColumnDataType.BIG_DECIMAL;
+        // No BIG_DECIMAL array type is used here; arrays fall back to DOUBLE_ARRAY (may lose precision).
+        return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY : 
DataSchema.ColumnDataType.BIG_DECIMAL;
       }
     } else {
       if (precision <= 14) {
-        return DataSchema.ColumnDataType.FLOAT;
+        return isArray ? DataSchema.ColumnDataType.FLOAT_ARRAY : 
DataSchema.ColumnDataType.FLOAT;
       } else if (precision <= 30) {
-        return DataSchema.ColumnDataType.DOUBLE;
+        return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY : 
DataSchema.ColumnDataType.DOUBLE;
       } else {
-        return DataSchema.ColumnDataType.BIG_DECIMAL;
+        // No BIG_DECIMAL array type is used here; arrays fall back to DOUBLE_ARRAY (may lose precision).
+        return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY : 
DataSchema.ColumnDataType.BIG_DECIMAL;
       }
     }
   }
diff --git 
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
 
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
index df896d2c00..823dd23b88 100644
--- 
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
+++ 
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java
@@ -20,6 +20,7 @@ package org.apache.pinot.segment.local.function;
 
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.common.function.FunctionInfo;
@@ -78,6 +79,13 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
           case "not":
             Preconditions.checkState(numArguments == 1, "NOT function expects 
1 argument, got: %s", numArguments);
             return new NotExecutionNode(childNodes[0]);
+          case "arrayvalueconstructor":
+            Object[] values = new Object[numArguments];
+            int i = 0;
+            for (ExpressionContext literal : arguments) {
+              values[i++] = literal.getLiteral().getValue();
+            }
+            return new ArrayConstantExecutionNode(values);
           default:
             FunctionInfo functionInfo = 
FunctionRegistry.getFunctionInfo(functionName, numArguments);
             if (functionInfo == null) {
@@ -145,7 +153,7 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
 
     @Override
     public Object execute(GenericRow row) {
-      for (ExecutableNode executableNode :_argumentNodes) {
+      for (ExecutableNode executableNode : _argumentNodes) {
         Boolean res = (Boolean) executableNode.execute(row);
         if (res) {
           return true;
@@ -156,7 +164,7 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
 
     @Override
     public Object execute(Object[] values) {
-      for (ExecutableNode executableNode :_argumentNodes) {
+      for (ExecutableNode executableNode : _argumentNodes) {
         Boolean res = (Boolean) executableNode.execute(values);
         if (res) {
           return true;
@@ -175,7 +183,7 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
 
     @Override
     public Object execute(GenericRow row) {
-      for (ExecutableNode executableNode :_argumentNodes) {
+      for (ExecutableNode executableNode : _argumentNodes) {
         Boolean res = (Boolean) executableNode.execute(row);
         if (!res) {
           return false;
@@ -186,7 +194,7 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
 
     @Override
     public Object execute(Object[] values) {
-      for (ExecutableNode executableNode :_argumentNodes) {
+      for (ExecutableNode executableNode : _argumentNodes) {
         Boolean res = (Boolean) executableNode.execute(values);
         if (!res) {
           return false;
@@ -284,6 +292,29 @@ public class InbuiltFunctionEvaluator implements 
FunctionEvaluator {
     }
   }
 
+  /**
+   * Executable node backing an {@code arrayvalueconstructor} call: it holds an array built from the call's literal
+   * argument values and yields that array regardless of the input row.
+   */
+  private static class ArrayConstantExecutionNode implements ExecutableNode {
+    final Object[] _value;
+
+    ArrayConstantExecutionNode(Object[] literalValues) {
+      _value = literalValues;
+    }
+
+    /**
+     * Returns the constant array without looking at the row.
+     * NOTE(review): the same array instance is handed out on every call, so a caller that mutates it would affect
+     * all later evaluations.
+     */
+    @Override
+    public Object[] execute(GenericRow row) {
+      return _value;
+    }
+
+    /** Returns the constant array without looking at the input values. */
+    @Override
+    public Object[] execute(Object[] values) {
+      return _value;
+    }
+
+    /** Renders the array wrapped in single quotes, e.g. {@code '[1.0, 2.0]'}. */
+    @Override
+    public String toString() {
+      return "'" + Arrays.toString(_value) + "'";
+    }
+  }
+
   private static class ColumnExecutionNode implements ExecutableNode {
     final String _column;
     final int _id;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org


Reply via email to